"""
Python script that queries the RAG model with a given query and returns the response.
The script connects to the database, searches for the query in the database, builds the prompt, and generates the response using the Ollama model.
The response includes the generated text, sources, original query, and context.
The script can be run from the command line with the query_text argument.
Example:
python query_rag.py --query_text "What is the capital of France?"
The script can also include sources and context in the response using the include_sources and include_context arguments.
Example:
python query_rag.py --query_text "What is the capital of France?" --include_sources --include_context
"""
import argparse
import os
from typing import Dict, List, Union

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.llms.ollama import Ollama

from embedding_function import get_embedding_function
from templates.load_jinja_template import load_jinja2_prompt

# Load the environment variables
load_dotenv()
CHROMA_PATH = os.getenv("CHROMA_PATH")
GENERATOR_MODEL = os.getenv("GENERATOR_MODEL")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama:11434")
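
# A minimal example .env for this script (the values below are illustrative
# assumptions, not defaults shipped with the repo; note that CHROMA_PATH and
# GENERATOR_MODEL have no fallback and must be set):
#
#   CHROMA_PATH=/data/chroma
#   GENERATOR_MODEL=llama3
#   OLLAMA_URL=http://ollama:11434
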
def query_rag(query_text: str) -> Dict[str, Union[str, List[str]]]:
    """
    Query the RAG model with the given query and return the response.

    Args:
        query_text (str): The query to be passed to the RAG model.

    Returns:
        dict: The response text, sources, original query, context, and prompt.
    """
    # Connect to the database
    print("🔗 Connecting to the database...")
    embedding_function = get_embedding_function()
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)

    # Search in the database
    print("🔍 Searching in the database...")
    results = db.similarity_search_with_score(query_text, k=5)

    # Build the prompt
    print("🔮 Building the prompt...")
    context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
    prompt = load_jinja2_prompt(context=context_text, question=query_text)

    # Generate the response
    print("🍳 Generating the response...")
    model = Ollama(model=GENERATOR_MODEL, base_url=OLLAMA_URL)
    response_text = model.invoke(prompt)

    sources = [doc.metadata.get("id", None) for doc, _score in results]
    response = {
        "response_text": response_text,
        "sources": sources,
        "original_query": query_text,
        "context": context_text,
        "prompt": prompt,
    }
    return response
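

# Example of calling query_rag programmatically rather than via the CLI (a
# minimal sketch; it assumes the Chroma index has already been built and an
# Ollama server is reachable at OLLAMA_URL):
#
#   from query import query_rag
#   result = query_rag("What is the capital of France?")
#   print(result["response_text"])
#   print(result["sources"])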


def main():
    parser = argparse.ArgumentParser(
        description="Query the RAG model with a given query."
    )
    parser.add_argument(
        "--query_text",
        type=str,
        required=True,
        help="The query to be passed to the RAG model.",
    )
    parser.add_argument(
        "--include_sources",
        action=argparse.BooleanOptionalAction,
        help="Include sources in the response.",
    )
    parser.add_argument(
        "--include_context",
        action=argparse.BooleanOptionalAction,
        help="Include context in the response.",
    )
    args = parser.parse_args()

    query_text = args.query_text
    include_sources = args.include_sources
    include_context = args.include_context
    response = query_rag(query_text)

    response_text = f"🤖 Response: {response['response_text']}"
    if include_sources:
        response_text += f"\n\n\n 📜 Sources: {response['sources']}"
    if include_context:
        response_text += f"\n\n\n 🌄 Context: {response['context']}"
    print(response_text)


if __name__ == "__main__":
    main()