irt_app.py
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables import RunnableLambda, Runnable
from prompt_templates import router_template, recording_template, rewriting_template, summary_template
import os
from dotenv import load_dotenv

# Load API keys from .env. Note: os.environ assignment raises a TypeError if
# os.getenv returns None, so the .env file must define all of these keys.
load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
os.environ["LANGCHAIN_TRACING_V2"] = "false"  # set to "true" to enable LangChain tracing (LangSmith)
os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")
os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
os.environ["LANGCHAIN_PROJECT"] = "DreamGuardTest"
# Per-process, in-memory chat histories shared by the router and stage chains.
ephemeral_chat_history_for_chain = ChatMessageHistory()
ephemeral_chat_history_2 = ChatMessageHistory()
transcript = ""

# Models: Groq's Llama 3 8B handles fast stage routing; GPT-4o generates the
# stage responses.
openAI = ChatOpenAI(model="gpt-4o")
groq = ChatGroq(temperature=0, model_name="llama3-8b-8192")
llama3 = ChatGroq(temperature=0, model_name="llama3-70b-8192")
output_parser = StrOutputParser()
# Define the chains: each pipes a prompt template into a model and parses the
# output to a plain string.
router_chain = router_template | groq | output_parser
recording_chain = recording_template | openAI | output_parser
rewriting_chain = rewriting_template | openAI | output_parser
summary_chain = summary_template | openAI | output_parser
# Routing function: dispatch to the stage chain named by the router's output.
def route(info):
    if "recording" in info["stage"].lower():
        return recording_chain
    elif "rewriting" in info["stage"].lower():
        return rewriting_chain
    else:
        return summary_chain

def clean_history(chat_history):
    # Drop the last human/AI message pair so the router's turn is not kept in
    # the shared history alongside the stage chain's turn.
    if len(chat_history.messages) > 1:
        chat_history.messages = chat_history.messages[:-2]
    return chat_history
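
# Illustration (a hedged sketch, not executed by the app): given a history of
# [Human1, AI1, Human2, AI2], clean_history trims the trailing pair, leaving
# [Human1, AI1]. Since the router and the stage chain both append to the same
# shared history, this keeps each user turn to a single recorded exchange.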
class PostProcessingRunnable(Runnable):
    """Wraps a chain and runs a post-processing function on the shared chat
    history after each invocation."""

    def __init__(self, base_chain, post_processing_fn):
        self.base_chain = base_chain
        self.post_processing_fn = post_processing_fn

    def invoke(self, inputs, config=None):
        result = self.base_chain.invoke(inputs, config)
        # Trim the router's turn out of the shared history (see clean_history).
        self.post_processing_fn(ephemeral_chat_history_for_chain)
        return result
# Wrap the router chain with message history so it sees prior turns.
router_chain_chat_history = RunnableWithMessageHistory(
    router_chain,
    lambda session_id: ephemeral_chat_history_for_chain,
    input_messages_key="input",
    history_messages_key="chat_history",
)
# Wrap the routing step with the same message history, so whichever stage
# chain route() selects also records its turn.
router_with_history = RunnableWithMessageHistory(
    RunnableLambda(route),
    lambda session_id: ephemeral_chat_history_for_chain,
    input_messages_key="input",
    history_messages_key="chat_history",
)
# Create the PostProcessingRunnable so the router's turn is trimmed from the
# shared history after each call.
post_processing_runnable = PostProcessingRunnable(router_chain_chat_history, clean_history)

# Full chain: the router determines the "stage", the original input is passed
# through unchanged, and router_with_history dispatches to the matching stage chain.
full_chain = {"stage": post_processing_runnable, "input": lambda x: x["input"]} | router_with_history
# Chatbot loop - change "while False" to "while True" to run irt_app.py
# standalone instead of behind the server.
while False:
    user_in = input("User: ")
    response = full_chain.invoke({"input": user_in},
                                 {"configurable": {"session_id": "unused"}})
    print("-----------------------------------------")
    print("AI: " + response)
    print("-----------------------------------------")