-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
58 lines (45 loc) · 1.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Import necessary modules
import os
import warnings

from flask import Flask, request, jsonify, render_template
from langchain._api.deprecation import LangChainDeprecationWarning

from Swisscom_chatbot.src.retrieve.VSbuilder import VectorStore
from Swisscom_chatbot.src.retrieve.retriever import Retriever
from Swisscom_chatbot.src.llm.llm_request import LLM_request
# Silence LangChain deprecation warnings for the whole process.
warnings.filterwarnings('ignore', category=LangChainDeprecationWarning)
# Settings
# Each value can be overridden through an environment variable so the app can
# run outside the original workspace. The defaults are the previously
# hard-coded values, so behaviour is unchanged when the variables are unset.
# NOTE(review): the relative defaults are presumably resolved by VectorStore
# against the current working directory -- confirm.
folder_dataset_path = os.environ.get(
    "CHATBOT_DATASET_PATH", 'dataset/processed_parsed_documents'
)
csv_path = os.environ.get("CHATBOT_CSV_PATH", "resumed_reduced.csv")
persist_directory = os.environ.get(
    "CHATBOT_PERSIST_DIRECTORY", "/teamspace/studios/this_studio"
)
# Build retriever
def create_retriever():
    """Return a new Retriever wrapping a VectorStore for the configured data.

    The VectorStore receives the module-level ``folder_dataset_path``,
    ``csv_path`` and ``persist_directory`` settings. Whether it rebuilds an
    index or reloads persisted state is decided inside VectorStore (not
    visible here).
    """
    store = VectorStore(
        folder_path=folder_dataset_path,
        csv_path=csv_path,
        persist_directory=persist_directory,
    )
    retriever = Retriever(VectorStore=store)
    return retriever
# Module-level state shared by every request handler.
# Conversation turns as {"user": ..., "llm": ...} dicts; /chat appends and
# /refresh replaces it.
# NOTE(review): a single list shared by all clients, with no lock.
history = []
# Initialize Flask app
app = Flask(__name__)
# Built once at import time (not per request); /refresh swaps in a new one.
r = create_retriever()
@app.route('/')
def home():
    """Serve the chat front-end page (the ``index.html`` template)."""
    page = render_template('index.html')
    return page
@app.route('/chat', methods=['POST'])
def chat():
    """Answer one user question and record the exchange in ``history``.

    Reads the ``query`` form field. Responses (all JSON):
      * missing, empty or whitespace-only query ->
        ``{"response": "Please enter a valid question."}``
      * "exit" / "quit" / "stop" (case-insensitive) ->
        ``{"response": "Goodbye!"}``
      * otherwise -> ``{"response": <LLM answer>, "links": <links>}``,
        where both values come from ``LLM_request.send_lmm_request()``.
    """
    # .get() rather than ['query'] so a missing field yields the friendly
    # message below instead of werkzeug's 400 Bad Request. strip() makes
    # whitespace-only input count as empty and lets "exit " still match.
    query = request.form.get('query', '').strip()
    if not query:
        return jsonify({"response": "Please enter a valid question."})
    if query.lower() in ("exit", "quit", "stop"):
        return jsonify({"response": "Goodbye!"})

    # Generate chatbot response. The shared history is passed in, presumably
    # as conversational context (LLM_request is defined elsewhere -- confirm).
    llm_req = LLM_request(Retriever=r, query=query, history=history)
    ai_msg, links = llm_req.send_lmm_request()

    # Record this turn so later requests can see it.
    # NOTE(review): `history` is one module-level list shared by all clients.
    history.append({"user": query, "llm": ai_msg})
    return jsonify({"response": ai_msg, "links": links})
@app.route('/refresh', methods=['POST'])
def refresh():
    """Replace the retriever with a new instance and start an empty history.

    The new retriever is built before the history is reset. If
    ``create_retriever()`` raises, neither global is changed.
    """
    global history
    global r
    fresh_retriever = create_retriever()
    r = fresh_retriever
    # Rebind to a new list rather than clearing the old one in place.
    history = []
    return jsonify({"response": "Chatbot memory has been refreshed."})
if __name__ == "__main__":
    # Run Flask's built-in server on port 5000, listening on all interfaces.
    # NOTE(review): '0.0.0.0' exposes the server to the whole network.
    run_options = {"host": '0.0.0.0', "port": 5000}
    app.run(**run_options)