-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
120 lines (100 loc) · 4.42 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
# 导入必要的库
import gradio as gr
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from LLM import InternLM_LLM
from langchain.prompts import PromptTemplate
import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os
def init():
model_dir = snapshot_download('Shanghai_AI_Laboratory/internlm-chat-7b'
, cache_dir='./', revision='v1.0.3')
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# 下载模型
os.system('huggingface-cli download --resume-download sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 --local-dir sentence-transformer')
def load_chain():
# 加载问答链
# 定义 Embeddings
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformer")
# 向量数据库持久化路径
persist_directory = 'data_base/vector_db/chroma'
# 加载数据库
vectordb = Chroma(
persist_directory=persist_directory, # 允许我们将persist_directory目录保存到磁盘上
embedding_function=embeddings
)
llm = InternLM_LLM(model_path = "Shanghai_AI_Laboratory/internlm-chat-7b")
template = """使用以下上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
问题: {question}
可参考的上下文:
···
{context}
···
如果给定的上下文无法让你做出回答,请回答你不知道。
有用的回答:"""
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context","question"],
template=template)
# 运行 chain
from langchain.chains import RetrievalQA
qa_chain = RetrievalQA.from_chain_type(llm,
retriever=vectordb.as_retriever(),
return_source_documents=True,
chain_type_kwargs={"prompt":QA_CHAIN_PROMPT})
return qa_chain
class Model_center():
"""
存储问答 Chain 的对象
"""
init()
def __init__(self):
self.chain = load_chain()
def qa_chain_self_answer(self, question: str, chat_history: list = []):
"""
调用不带历史记录的问答链进行回答
"""
if question == None or len(question) < 1:
return "", chat_history
try:
chat_history.append(
(question, self.chain({"query": question})["result"]))
return "", chat_history
except Exception as e:
return e, chat_history
model_center = Model_center()
block = gr.Blocks()
with block as demo:
with gr.Row(equal_height=True):
with gr.Column(scale=15):
gr.Markdown("""<h1><center>InternLM</center></h1>
<center>书生浦语</center>
""")
# gr.Image(value=LOGO_PATH, scale=1, min_width=10,show_label=False, show_download_button=False)
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(height=450, show_copy_button=True)
# 创建一个文本框组件,用于输入 prompt。
msg = gr.Textbox(label="Prompt/问题")
with gr.Row():
# 创建提交按钮。
db_wo_his_btn = gr.Button("Chat")
with gr.Row():
# 创建一个清除按钮,用于清除聊天机器人组件的内容。
clear = gr.ClearButton(
components=[chatbot], value="Clear console")
# 设置按钮的点击事件。当点击时,调用上面定义的 qa_chain_self_answer 函数,并传入用户的消息和聊天历史记录,然后更新文本框和聊天机器人组件。
db_wo_his_btn.click(model_center.qa_chain_self_answer, inputs=[
msg, chatbot], outputs=[msg, chatbot])
gr.Markdown("""提醒:<br>
1. 初始化数据库时间可能较长,请耐心等待。
2. 使用中如果出现异常,将会在文本输入框进行展示,请不要惊慌。 <br>
""")
# threads to consume the request
gr.close_all()
# 启动新的 Gradio 应用,设置分享功能为 True,并使用环境变量 PORT1 指定服务器端口。
# demo.launch(share=True, server_port=int(os.environ['PORT1']))
# 直接启动
demo.launch()