forked from dheerajrhegde/PrecisionFarming
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRetrievalGraph.py
executable file
·322 lines (250 loc) · 11.5 KB
/
RetrievalGraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
from typing import List
from typing_extensions import TypedDict
import pprint
import os
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.schema import Document
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_upstage import UpstageGroundednessCheck
from langchain.chains.query_constructor.base import AttributeInfo
from langgraph.graph import END, START, StateGraph
from trulens.apps.langchain import WithFeedbackFilterDocuments
from trulens.core import Feedback, TruSession
from trulens.providers.openai import OpenAI
from trulens.apps.langchain import TruChain
from langchain.load import dumps, loads
class GradeDocuments(BaseModel):
    """Schema for a yes/no relevance grade on retrieved documents."""

    # Required string field; the description states the expected values.
    binary_score: str = Field(
        ...,
        description="Documents are relevant to the question, 'yes' or 'no'",
    )
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        crop: crop the question is about
        question: question (replaced by the re-written one after ``transform_query``)
        generation: LLM generation
        web_search: whether to add search (declared here; no node in this file writes it)
        documents: list of documents
        groundedness: groundedness verdict for ``generation``; it selects the
            edge taken after the ``generate`` node
    """

    crop: str
    question: str
    generation: str
    web_search: str
    # NOTE(review): the retrieve / web-search nodes store Document objects in
    # this list, not plain strings -- confirm and tighten the annotation.
    documents: List[str]
    groundedness: str
class RetrievalGraph:
    """Question-answering workflow over a crop-guide index, built as a LangGraph graph.

    Flow of the compiled graph:

    * ``retrieve`` runs first.
    * If nothing was retrieved, ``web_search_node`` runs before ``generate``;
      otherwise ``generate`` runs directly.
    * After ``generate``, a "grounded" verdict ends the run; "notGrounded" and
      "notSure" go through ``transform_query`` and back to ``retrieve``.
    """

    # Threshold shared by the vector-store similarity search and the
    # context-relevance feedback filter.
    _SCORE_THRESHOLD = 0.75

    # Prompt used to split the question into sub-questions for the vector store.
    _VECTOR_SUBQUESTION_TEMPLATE = """You are an AI language model assistant. Your task is to break down the larger question
you get into smaller subquestions to do a vector store retrieval on.
Provide a list of subquestions that can be used to search the web for more information.
Original question: {question}
Crop: {crop}
"""

    # Prompt used to split the question into sub-questions for the web search.
    _WEB_SUBQUESTION_TEMPLATE = """You are an AI language model assistant. Your task is to break down the larger question
you get into smaller subquestions to do a web search on.
Provide a list of subquestions that can be used to search the web for more information.
Original question: {question}"""

    def __init__(self):
        """Create the tools, vector store and chains, then compile the graph.

        Reads ``OPENAI_API_KEY``, ``AZURE_SEARCH_ENDPOINT`` and
        ``AZURE_SEARCH_ADMIN_KEY`` from the environment.
        """
        # Initialize Tavily. The tool's result-count option is ``max_results``
        # (the previous ``k=3`` keyword is not one of its options).
        self.web_search_tool = TavilySearchResults(max_results=3)
        self.llm = ChatOpenAI(model_name="gpt-4o", temperature=0)

        # Get access to the Azure Search index that backs retrieval.
        openai_api_key = os.getenv("OPENAI_API_KEY")
        openai_api_version = "2023-05-15"
        model = "text-embedding-ada-002"
        vector_store_address = os.getenv("AZURE_SEARCH_ENDPOINT")
        # Secret: never print or log the admin key.
        vector_store_password = os.getenv("AZURE_SEARCH_ADMIN_KEY")
        embeddings: OpenAIEmbeddings = OpenAIEmbeddings(
            openai_api_key=openai_api_key,
            openai_api_version=openai_api_version,
            model=model,
        )
        from langchain_community.vectorstores.azuresearch import AzureSearch

        index_name: str = "crop_guide"
        self.vectorstore = AzureSearch(
            azure_search_endpoint=vector_store_address,
            azure_search_key=vector_store_password,
            index_name=index_name,
            embedding_function=embeddings.embed_query,
        )

        # RAG chain that produces the answer from the context and the question
        # (invoked by ``generate``).
        prompt = hub.pull("rlm/rag-prompt")
        print(prompt)
        self.rag_chain = prompt | self.llm | StrOutputParser()

        # Chain that re-writes the question (invoked by ``transform_query``).
        system = """You a question re-writer that converts an input question to a better version that is optimized \n
for web search. Look at the input and try to reason about the underlying semantic intent / meaning."""
        re_write_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system),
                (
                    "human",
                    "Here is the initial question: \n\n {question} \n Formulate an improved question.",
                ),
            ]
        )
        self.question_rewriter = re_write_prompt | self.llm | StrOutputParser()

        workflow = StateGraph(GraphState)

        # Define the nodes
        workflow.add_node("retrieve", self.retrieve)  # retrieve with content relevance score
        workflow.add_node("generate", self.generate)  # generate
        workflow.add_node("transform_query", self.transform_query)  # transform_query
        workflow.add_node("web_search_node", self.web_search)  # web search

        # Build graph
        workflow.add_edge(START, "retrieve")
        workflow.add_conditional_edges(
            "retrieve",
            self.nothing_retrieved,
            {
                "web_search": "web_search_node",
                "generate": "generate",
            },
        )
        workflow.add_edge("web_search_node", "generate")
        workflow.add_conditional_edges(
            "generate",
            self.not_grounded,
            {
                "notGrounded": "transform_query",
                "notSure": "transform_query",
                "grounded": END,
            },
        )
        workflow.add_edge("transform_query", "retrieve")

        # Compile
        self.app = workflow.compile()
        # ``print`` (not ``pprint``) so the multi-line drawing is rendered
        # instead of shown as an escaped string repr.
        print(self.app.get_graph().draw_ascii())

    def invoke(self, question, crop):
        """Run the compiled graph and return the generated answer.

        Args:
            question: the question to answer.
            crop: the crop the question is about; stored in the state and used
                by ``retrieve`` when building sub-questions.

        Returns:
            The value of the ``generation`` key of the final graph state.
        """
        os.environ["LANGCHAIN_TRACING_V2"] = "True"
        os.environ["LANGCHAIN_PROJECT"] = "RetrievalGraph"
        # Pass the caller's crop through (not the literal string "crop").
        final_state = self.app.invoke({"question": question, "crop": crop})
        return final_state["generation"]

    @staticmethod
    def _split_subquestions(text):
        """Split LLM output into one sub-question per line, dropping blank lines."""
        stripped_lines = (line.strip() for line in text.split("\n"))
        return [line for line in stripped_lines if line]

    def _subquestion_chain(self, template):
        """Build a chain: prompt from ``template`` -> chat model -> list of sub-questions."""
        return (
            ChatPromptTemplate.from_template(template)
            | ChatOpenAI(temperature=0)
            | StrOutputParser()
            | self._split_subquestions
        )

    def retrieve(self, state):
        """Retrieve documents from the vector store for each sub-question.

        Args:
            state (dict): The current graph state; reads ``question`` and ``crop``.

        Returns:
            state (dict): ``documents`` set to the de-duplicated union of the
            documents returned for all sub-questions.
        """
        question = state["question"]
        print(question)

        provider = OpenAI()
        f_context_relevance_score = Feedback(provider.context_relevance)
        retriever = self.vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"score_threshold": self._SCORE_THRESHOLD},
        )
        # Wrap the retriever so documents are additionally filtered by the
        # context-relevance feedback score.
        filtered_retriever = WithFeedbackFilterDocuments.of_retriever(
            retriever=retriever,
            feedback=f_context_relevance_score,
            threshold=self._SCORE_THRESHOLD,
        )

        generate_queries = self._subquestion_chain(self._VECTOR_SUBQUESTION_TEMPLATE)
        sub_questions = generate_queries.invoke(
            {"question": question, "crop": state["crop"]}
        )
        print("questions asked ", sub_questions)

        retrieved_docs = []
        for sub_question in sub_questions:
            docs = filtered_retriever.get_relevant_documents(sub_question)
            print("question", sub_question)
            print("docs", docs)
            retrieved_docs.append(docs)
        print("retrieved documents ...", retrieved_docs)

        unique_docs = self.get_unique_union(retrieved_docs)
        print("documents retrieved from the vector store are", unique_docs)
        return {"documents": unique_docs}

    def generate(self, state):
        """Generate an answer from the documents and check its groundedness.

        Args:
            state (dict): The current graph state; reads ``question`` and ``documents``.

        Returns:
            state (dict): ``generation`` holds the answer and ``groundedness``
            holds the groundedness-check response (used by ``not_grounded``
            for routing); ``documents`` and ``question`` are passed through.
        """
        question = state["question"]
        documents = state["documents"]

        generation = self.rag_chain.invoke({"context": documents, "question": question})

        groundedness_check = UpstageGroundednessCheck()
        response = groundedness_check.invoke(
            {
                "context": documents,
                "answer": generation,
            }
        )
        print("Groundedness response: ", response)
        return {
            "documents": documents,
            "question": question,
            "generation": generation,
            "groundedness": response,
        }

    def transform_query(self, state):
        """
        Transform the query to produce a better question.

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Updates question key with a re-phrased question
        """
        question = state["question"]
        documents = state["documents"]

        # Re-write question
        better_question = self.question_rewriter.invoke({"question": question})
        return {"documents": documents, "question": better_question}

    def get_unique_union(self, documents: list[list]):
        """Return the unique union of retrieved docs, in first-seen order.

        Args:
            documents: a list of result lists (one list per query).

        Returns:
            A flat list with duplicates removed. Items are compared through
            their serialized form.
        """
        # Flatten list of lists, and convert each Document to string
        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
        # dict.fromkeys removes duplicates but, unlike set(), keeps a
        # deterministic (first-seen) order.
        unique_docs = dict.fromkeys(flattened_docs)
        return [loads(doc) for doc in unique_docs]

    def web_search(self, state):
        """Search the web for each sub-question and add the results as one document.

        Args:
            state (dict): The current graph state; reads ``question`` and ``documents``.

        Returns:
            state (dict): ``documents`` extended with a single Document whose
            content is the newline-joined ``content`` of the search results;
            ``question`` is passed through.
        """
        question = state["question"]

        retrieval_chain = (
            self._subquestion_chain(self._WEB_SUBQUESTION_TEMPLATE)
            | self.web_search_tool.map()
            | self.get_unique_union
        )
        docs = retrieval_chain.invoke({"question": question})

        # Web search
        print("Web search for: ", question)
        print(type(docs), docs)

        # Only dict results that carry a "content" entry contribute text.
        web_results = "\n".join(
            d["content"] for d in docs if isinstance(d, dict) and "content" in d
        )
        # Build a new list rather than mutating the list held by the state.
        documents = [*state["documents"], Document(page_content=web_results)]
        return {"documents": documents, "question": question}

    def nothing_retrieved(self, state):
        """Route after ``retrieve``: "web_search" when no documents, else "generate"."""
        if not state["documents"]:
            return "web_search"
        return "generate"

    def not_grounded(self, state):
        """Route after ``generate``: return the stored groundedness verdict.

        The graph maps "grounded" to END and "notGrounded" / "notSure" to
        ``transform_query``.
        """
        return state["groundedness"]
if __name__ == "__main__":
    # Sample run: ask the graph a pest-management question about corn.
    farmer_prompt = """
You are an agricultural pest management expert is a professional with specialized knowledge in entomology,
plant pathology, and crop protection.
A farmer has come to you with a disease effeecting his/her crop.
The farmer is growing corn.
The farmer has noticed caterpillar insect on the crop.
His farm's current and next few days weather is sunny.
His farm's soil moisture is 30. And his irrigation plan is none.
You need to provide the farmer with the following information:
1. Insights on the insect, how it effects the plant and its yield
2. What factors support insect habitation in your crop field
3. Now that the insects are present, how to remediate it? Include specific informaiton
- On what pesticides to use, when to apply given the weather, moisture and irrigation plan
- explain your reasoning for the timing. Provide reference to the weather and moisture levels and you used it in your reasoning
- give dates when the pesticides should be applied
- Where to get the pesticides from
- Give the websites where the farmer can buy the pesticides
"""
    retrieval_graph = RetrievalGraph()
    # ``invoke`` returns the generated answer, which is printed as-is.
    answer = retrieval_graph.invoke(farmer_prompt, crop="corn")
    print(answer)