Commit

Merge pull request #15 from NERC-CEH/json-retrieval
Added JSON formatting to RAG prompt builder
matthewcoole authored Nov 22, 2024
2 parents 430ca61 + caff400 commit b9060c6
Showing 5 changed files with 37 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -165,3 +165,4 @@ metrics.txt
metrics.png
gdrive-oauth.txt
/eval
.tmp/
1 change: 1 addition & 0 deletions data/.gitignore
@@ -16,3 +16,4 @@
/eidc_rag_testset.csv
/eidc_rag_test_set.csv
/rag-pipeline.yml
/pipeline.yml
1 change: 1 addition & 0 deletions params.yaml
@@ -16,6 +16,7 @@ files:
eval-set: data/evaluation_data.csv
metrics: data/metrics.json
eval-plot: data/eval.png
pipeline: data/pipeline.yml
sub-sample: 0 # sample n datasets for testing (0 will use all datasets)
max-length: 0 # truncate longer texts for testing (0 will use all data)
test-set-size: 101 # reduce the size of the test set for faster testing
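The new pipeline entry gives the RAG script a path for a serialized copy of the assembled pipeline. A hypothetical sketch of how that path could be consumed, assuming Haystack 2.x's Pipeline.dump(); the real run_rag_pipeline.py may serialize it differently:

    import yaml
    from haystack import Pipeline

    # Read the configured output path from params.yaml (keys as shown above).
    with open("params.yaml") as f:
        params = yaml.safe_load(f)
    pipeline_path = params["files"]["pipeline"]  # -> "data/pipeline.yml"

    rag_pipe = Pipeline()  # stands in for the pipeline returned by build_rag_pipeline()
    with open(pipeline_path, "w") as f:
        rag_pipe.dump(f)  # writes the components and connections as YAML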
34 changes: 26 additions & 8 deletions scripts/run_rag_pipeline.py
@@ -1,3 +1,4 @@
import os
import shutil
import sys
from argparse import ArgumentParser
@@ -24,14 +25,29 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline:
print("Creating prompt template...")

template = """
Given the following information, answer the question.
You are part of a retrieval augmented generative pipeline.
Your task is to provide an answer to a question based on a given set of retrieved documents.
The retrieved documents will be given in JSON format.
The retrieved documents are chunks of information retrieved from datasets held in the EIDC (Environmental Information Data Centre).
The EIDC is hosted by UKCEH (UK Centre for Ecology and Hydrology).
Your answer should be as faithful as possible to the information provided by the retrieved documents.
Do not use your own knowledge to answer the question, only the information in the retrieved documents.
Do not refer to "retrieved documents" in your answer, instead use phrases like "available information".
Provide a citation to the relevant chunk_id used to generate each part of your answer.
Question: {{query}}
Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}
"retrieved_documents": [{% for document in documents %}
{
content: "{{ document.content }}",
meta: {
dataset_id: "{{ document.meta.id }}",
source: "{{ document.meta.field }}",
chunk_id: "{{ document.id }}"
}
}
{% endfor %}
]
Answer:
"""
@@ -41,7 +57,7 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline:
print(f"Setting up model ({model_name})...")
llm = OllamaGenerator(
model=model_name,
generation_kwargs={"num_ctx": 16384},
generation_kwargs={"num_ctx": 16384, "temperature": 0.0},
url="http://localhost:11434/api/generate",
)

@@ -60,6 +76,7 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline:
rag_pipe.connect("prompt_builder", "llm")

rag_pipe.connect("llm.replies", "answer_builder.replies")
rag_pipe.connect("prompt_builder.prompt", "answer_builder.query")
return rag_pipe


@@ -68,7 +85,6 @@ def run_query(query: str, pipeline: Pipeline) -> Dict[str, Any]:
{
"retriever": {"query": query},
"prompt_builder": {"query": query},
"answer_builder": {"query": query},
}
)

@@ -93,6 +109,8 @@ def main(
model: str,
pipeline_file: str,
) -> None:
if os.path.exists(TMP_DOC_PATH):
shutil.rmtree(TMP_DOC_PATH)
shutil.copytree(doc_store_path, TMP_DOC_PATH)

rag_pipe = build_rag_pipeline(model, collection_name)
@@ -109,7 +127,7 @@ def main(
df["contexts"] = contexts
df.to_csv(ouput_file, index=False)

shutil.rmtree(TMP_DOC_PATH)
# shutil.rmtree(TMP_DOC_PATH)


if __name__ == "__main__":
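To make the prompt change concrete: each retrieved chunk is now serialised into the prompt as a JSON-like object carrying its dataset_id, source field and chunk_id, so the model can cite chunks. A minimal rendering sketch, assuming Haystack 2.x's PromptBuilder and Document APIs; the content and meta values are illustrative and the template is an abridged copy of the one added in this commit:

    from haystack import Document
    from haystack.components.builders import PromptBuilder

    # Abridged copy of the JSON-formatted template added above.
    TEMPLATE = """
    Provide a citation to the relevant chunk_id used to generate each part of your answer.
    Question: {{query}}
    "retrieved_documents": [{% for document in documents %}
        {
            content: "{{ document.content }}",
            meta: {
                dataset_id: "{{ document.meta.id }}",
                source: "{{ document.meta.field }}",
                chunk_id: "{{ document.id }}"
            }
        }
    {% endfor %}
    ]
    Answer:
    """

    # Illustrative chunk only; real chunks come from the EIDC document store.
    docs = [
        Document(
            content="Soil moisture was recorded every 30 minutes at each site.",
            meta={"id": "example-dataset-id", "field": "description"},
        )
    ]

    builder = PromptBuilder(template=TEMPLATE)
    prompt = builder.run(documents=docs, query="How often was soil moisture recorded?")["prompt"]
    print(prompt)  # each chunk appears as a JSON-like object with a citable chunk_id

The commit also sets temperature to 0.0 in the generator kwargs, which makes the Ollama generations effectively deterministic for evaluation runs, and wires prompt_builder.prompt into answer_builder.query so the query no longer needs to be passed to the answer builder separately.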
8 changes: 8 additions & 0 deletions scripts/upload_to_docstore.py
@@ -4,17 +4,21 @@
import sys
import uuid
from argparse import ArgumentParser
import logging

__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
import chromadb
from chromadb.utils import embedding_functions
from chromadb.utils.batch_utils import create_batches

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)

def main(
input_file: str, output_path: str, collection_name: str, embedding_model: str
) -> None:
logger.info(f"Uploading data ({input_file}) to chromaDB ({output_path}) in collection {collection_name}.")
if os.path.exists(output_path):
shutil.rmtree(output_path)

@@ -37,16 +41,20 @@ def main(
collection = client.create_collection(
name=collection_name, embedding_function=func
)

batches = create_batches(
api=client, ids=ids, documents=docs, embeddings=embs, metadatas=metas
)
logger.info(f"Uploading {len(docs)} document(s) to chroma in {len(batches)} batch(es).")
for batch in batches:
collection.add(
documents=batch[3],
metadatas=batch[2],
embeddings=batch[1],
ids=batch[0],
)
docs_in_col = collection.count()
logger.info(f"{docs_in_col} documents(s) are now in the {collection_name} collection")


if __name__ == "__main__":
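For context on the new log lines, a minimal sketch of the batching flow they report, assuming a local chromadb install; the collection name, embedding model and documents below are illustrative, not the pipeline's real parameters:

    import chromadb
    from chromadb.utils import embedding_functions
    from chromadb.utils.batch_utils import create_batches

    client = chromadb.PersistentClient(path=".tmp/chroma-demo")
    func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
    collection = client.create_collection(name="demo-collection", embedding_function=func)

    ids = ["chunk-0", "chunk-1"]
    docs = ["First chunk of text.", "Second chunk of text."]
    metas = [{"id": "ds-1", "field": "description"}, {"id": "ds-1", "field": "lineage"}]
    embs = func(docs)  # embed up front, as the upload script does

    # create_batches returns tuples ordered (ids, embeddings, metadatas, documents),
    # which is why the loop above indexes batch[0]..batch[3] the way it does.
    batches = create_batches(api=client, ids=ids, documents=docs, embeddings=embs, metadatas=metas)
    for batch in batches:
        collection.add(ids=batch[0], embeddings=batch[1], metadatas=batch[2], documents=batch[3])

    print(collection.count())  # the count reported by the final log message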
