ChatGLM3.py
import json
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import List, Optional
from utils import tool_config_from_file
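

# ChatGLM3 wrapped as a LangChain LLM: it converts structured-chat agent
# prompts into ChatGLM3's native tool-calling chat format, and converts the
# model's tool calls back into the Action JSON blobs the agent expects.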
class ChatGLM3(LLM):
    max_token: int = 8192
    do_sample: bool = False
    temperature: float = 0.8
    top_p: float = 0.8
    tokenizer: object = None
    model: object = None
    history: List = []
    tool_names: List = []
    has_search: bool = False

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3"
    def load_model(self, model_name_or_path=None):
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=model_config, trust_remote_code=True
        ).quantize(4).cuda()
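
    # Parse the structured-chat agent prompt: recover the tool list from the
    # "You have access to the following tools:" section, resolve each tool's
    # config, and build the system message plus user query for ChatGLM3.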
    def _tool_history(self, prompt: str):
        ans = []
        try:
            tool_prompts = prompt.split(
                "You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")
        except IndexError:
            tool_prompts = []  # no tool section found, fall back to an empty list
        tool_names = [tool.split(":")[0] for tool in tool_prompts]
        self.tool_names = tool_names
        tools_json = []
        for i, tool in enumerate(tool_names):
            tool_config = tool_config_from_file(tool)
            if tool_config:
                tools_json.append(tool_config)
            else:
                raise ValueError(
                    f"Tool {tool} config not found! Its description is {tool_prompts[i]}"
                )
        ans.append({
            "role": "system",
            "content": "Answer the following questions as best as you can. You have access to the following tools:",
            "tools": tools_json
        })
        query = prompt.split("Human: ")[-1].strip()
        return ans, query
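
    # Pull the latest "Observation: ..." block (a tool's return value) out of
    # the agent prompt and append it to the chat history as an observation turn.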
    def _extract_observation(self, prompt: str):
        return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
        self.history.append({
            "role": "observation",
            "content": return_json
        })
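
    # Convert the model's last reply into the Action JSON blob the agent
    # parser expects: a tool invocation if the reply carries tool_call
    # metadata, otherwise a Final Answer.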
    def _extract_tool(self):
        if len(self.history[-1]["metadata"]) > 0:
            metadata = self.history[-1]["metadata"]
            content = self.history[-1]["content"]
            if "tool_call" in content:
                for tool in self.tool_names:
                    if tool in metadata:
                        input_para = content.split("='")[-1].split("'")[0]
                        action_json = {
                            "action": tool,
                            "action_input": input_para
                        }
                        self.has_search = True
                        return f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""
        final_answer_json = {
            "action": "Final Answer",
            "action_input": self.history[-1]["content"]
        }
        self.has_search = False
        return f"""
Action:
```
{json.dumps(final_answer_json, ensure_ascii=False)}
```"""
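
    # LangChain entry point. On the first call the prompt is parsed for tools
    # and the user question; on follow-up calls (after a tool ran) only the
    # new observation is fed back into the chat history.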
    def _call(self, prompt: str, history: Optional[List] = None, stop: Optional[List[str]] = ["<|user|>"]):
        print("======")
        print(prompt)
        print("======")
        if history is None:  # avoid a shared mutable default accumulating across calls
            history = []
        if not self.has_search:
            self.history, query = self._tool_history(prompt)
        else:
            self._extract_observation(prompt)
            query = ""
        _, self.history = self.model.chat(
            self.tokenizer,
            query,
            history=self.history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
        )
        response = self._extract_tool()
        history.append((prompt, response))
        return response
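

# ---------------------------------------------------------------------------
# Usage sketch, not part of the original file: how this wrapper might slot
# into a LangChain structured-chat agent. The checkpoint path, the example
# tool, and the question are illustrative assumptions, and a matching tool
# config must be resolvable by tool_config_from_file() for the wrapper to
# accept the tool.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain.agents import AgentType, Tool, initialize_agent

    def get_weather(city: str) -> str:
        return f"Sunny in {city}"  # stand-in implementation

    llm = ChatGLM3()
    llm.load_model("THUDM/chatglm3-6b")  # placeholder checkpoint path

    tools = [
        Tool(
            name="weather",
            func=get_weather,
            description="weather: look up the current weather for a city",
        )
    ]

    agent = initialize_agent(
        tools,
        llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    agent.run("What is the weather in Beijing?")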