Skip to content

Commit

Permalink
Integrating structured query llm and database filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
Taniya-Das committed Jul 25, 2024
2 parents 44c30c9 + c3db9d0 commit bf5bb6e
Show file tree
Hide file tree
Showing 66 changed files with 856 additions and 596 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@


## Example usage
- Note that this screenshot was produced with a very small model for demonstration purposes. Results with the actual model would be a lot better :)
- ![Example usage](./images/search_ui.png)
- ![Example usage](./docs/images/search_ui.png)

## Where do I go from here?
### I am a developer and I want to contribute to the project
- Hello! We are glad you are here. To get started, refer to the tutorials in the [developer tutorial](./developer%20tutorials/index.md) section.
- Please refer to the documentation for more information.
- If you have any questions, feel free to ask or post an issue.


Expand Down
4 changes: 2 additions & 2 deletions backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.1.0'
__version__ = "0.1.0"
from .modules.rag_llm import *
from .modules.vector_store_utils import *
from .modules.utils import *
from .modules.utils import *
1 change: 1 addition & 0 deletions backend/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@
"long_context_reorder" : false,
"structured_query": false,
"use_chroma_for_saving_metadata": false,
"chunk_size": 1000,
"chroma_metadata_dir": "../data/chroma_db_metadata"
}
2 changes: 0 additions & 2 deletions backend/modules/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
def find_device(training: bool = False) -> str:
"""
Description: Find the device to use for the pipeline. If cuda is available, use it. If not, check if MPS is available and use it. If not, use CPU.
"""
print("[INFO] Finding device.")
if torch.cuda.is_available():
Expand All @@ -23,7 +22,6 @@ def find_device(training: bool = False) -> str:
def load_config_and_device(config_file: str, training: bool = False) -> dict:
"""
Description: Load the config file and find the device to use for the pipeline.
"""
# Check if the config file exists and load it
if not os.path.exists(config_file):
Expand Down
73 changes: 40 additions & 33 deletions backend/modules/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import pickle

# from pqdm.processes import pqdm
from typing import Sequence, Tuple, Union

Expand All @@ -17,41 +18,32 @@

class OpenMLObjectHandler:
"""
Description: The base class for handling OpenML objects.
Description: The base class for handling OpenML objects. The logic for handling datasets/flows is implemented in subclasses of this class.
"""

def __init__(self, config):
self.config = config

def get_description(self, data_id: int):
"""
Description: Get the description of the OpenML object.
"""
raise NotImplementedError

def get_openml_objects(self):
"""
Description: Get the OpenML objects.
"""
raise NotImplementedError

def initialize_cache(self, data_id: Sequence[int]) -> None:
"""
Description: Initialize the cache for the OpenML objects.
"""
self.get_description(data_id[0])

def get_metadata(self, data_id: Sequence[int]):
"""
Description: Get metadata from OpenML using parallel processing.
"""
return pqdm(
data_id, self.get_description, n_jobs=self.config["data_download_n_jobs"]
Expand All @@ -67,16 +59,13 @@ def process_metadata(
):
"""
Description: Process the metadata.
"""
raise NotImplementedError

def load_metadata(self, file_path: str):
@staticmethod
def load_metadata(file_path: str):
"""
Description: Load metadata from a file.
"""
try:
return pd.read_csv(file_path)
Expand All @@ -85,13 +74,15 @@ def load_metadata(self, file_path: str):
"Metadata files do not exist. Please run the training pipeline first."
)

def extract_attribute(self, attribute: object, attr_name: str) -> str:
@staticmethod
def extract_attribute(attribute: object, attr_name: str) -> str:
"""
Description: Extract an attribute from the OpenML object.
"""
return getattr(attribute, attr_name, "")

def join_attributes(self, attribute: object, attr_name: str) -> str:
@staticmethod
def join_attributes(attribute: object, attr_name: str) -> str:
"""
Description: Join the attributes of the OpenML object.
"""
Expand All @@ -103,8 +94,8 @@ def join_attributes(self, attribute: object, attr_name: str) -> str:
else ""
)

@staticmethod
def create_combined_information_df_for_datasets(
self,
data_id: int | Sequence[int],
descriptions: Sequence[str],
joined_qualities: Sequence[str],
Expand All @@ -122,7 +113,8 @@ def create_combined_information_df_for_datasets(
}
)

def merge_all_columns_to_string(self, row: pd.Series) -> str:
@staticmethod
def merge_all_columns_to_string(row: pd.Series) -> str:
"""
Description: Create a single column that has a combined string of all the metadata and the description in the form of "column - value, column - value, ... description"
"""
Expand All @@ -142,8 +134,9 @@ def combine_metadata(
)
return all_dataset_metadata

@staticmethod
def subset_metadata(
self, subset_ids: Sequence[int] | None, all_dataset_metadata: pd.DataFrame
subset_ids: Sequence[int] | None, all_dataset_metadata: pd.DataFrame
):
if subset_ids is not None:
subset_ids = [int(x) for x in subset_ids]
Expand Down Expand Up @@ -177,6 +170,11 @@ def process_metadata(
file_path: str,
subset_ids=None,
):
"""
Description: Combine the metadata attributes into a single string and save it to a CSV / ChromaDB file. Subset the data if given a list of IDs to subset by.
"""

# Metadata
descriptions = [
self.extract_attribute(attr, "description") for attr in openml_data_object
]
Expand All @@ -187,6 +185,8 @@ def process_metadata(
self.join_attributes(attr, "features") for attr in openml_data_object
]

# Combine them

all_data_description_df = self.create_combined_information_df_for_datasets(
data_id, descriptions, joined_qualities, joined_features
)
Expand All @@ -197,9 +197,11 @@ def process_metadata(
# subset the metadata if subset_ids is not None
all_dataset_metadata = self.subset_metadata(subset_ids, all_dataset_metadata)

# Save to a CSV
all_dataset_metadata.to_csv(file_path)

if self.config.get("use_chroma_for_saving_metadata") == True:
# Save to chroma if needed
if self.config.get("use_chroma_for_saving_metadata"):
client = chromadb.PersistentClient(
path=self.config["persist_dir"] + "metadata_db"
)
Expand Down Expand Up @@ -263,6 +265,9 @@ def process_metadata(


class OpenMLMetadataProcessor:
"""
Description: Process metadata using the OpenMLHandlers
"""
def __init__(self, config: dict):
self.config = config
self.save_filename = os.path.join(
Expand All @@ -287,7 +292,7 @@ def get_all_metadata_from_openml(self):
"Metadata files do not exist. Please run the training pipeline first."
)
print("[INFO] Loading metadata from file.")
return self.load_metadata_from_file(self.save_filename)
return load_metadata_from_file(self.save_filename)

print("[INFO] Training is set to True.")
handler = (
Expand All @@ -311,22 +316,12 @@ def get_all_metadata_from_openml(self):
openml_data_object = handler.get_metadata(data_id)

print("[INFO] Saving metadata to file.")
self.save_metadata_to_file(
save_metadata_to_file(
(openml_data_object, data_id, all_objects, handler), self.save_filename
)

return openml_data_object, data_id, all_objects, handler

def load_metadata_from_file(self, filename: str):
# Implement the function to load metadata from a file
with open(filename, "rb") as f:
return pickle.load(f)

def save_metadata_to_file(self, data: Tuple, save_filename: str):
# Implement the function to save metadata to a file
with open(save_filename, "wb") as f:
pickle.dump(data, f)

def create_metadata_dataframe(
self,
handler: Union["OpenMLDatasetHandler", "OpenMLFlowHandler"],
Expand Down Expand Up @@ -356,3 +351,15 @@ def create_metadata_dataframe(
self.description_filename,
subset_ids,
)


def save_metadata_to_file(data: Tuple, save_filename: str):
    """
    Description: Pickle the metadata tuple to save_filename.

    Input: data : Tuple - the object to serialise, save_filename : str - destination path

    Returns: None
    """
    # Create the destination directory if it is missing so a first run does not
    # fail on a path that has not been set up yet. A bare filename has no
    # directory component, so there is nothing to create in that case.
    target_dir = os.path.dirname(save_filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    with open(save_filename, "wb") as handle:
        pickle.dump(data, handle)


def load_metadata_from_file(filename: str):
    """
    Description: Read a pickled object back from filename and return it.

    Input: filename : str - path of the pickle file to read

    Returns: whatever object was pickled into the file
    """
    # SECURITY: unpickling can execute arbitrary code. Only use this on files
    # written by this project (see save_metadata_to_file), never on files from
    # an untrusted source.
    with open(filename, "rb") as handle:
        return pickle.load(handle)
29 changes: 25 additions & 4 deletions backend/modules/rag_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,29 @@


class LLMChainInitializer:
    """
    Description: Setup the vectordb (Chroma) as a retriever with parameters
    """

    @staticmethod
    def initialize_llm_chain(
        vectordb: Chroma, config: dict
    ) -> langchain.chains.retrieval_qa.base.RetrievalQA:
        """
        Description: Build a retriever from the vector store using the search settings in the config.

        Input: vectordb : Chroma - the vector store to wrap,
        config : dict - must contain "search_type" and "num_return_documents";
        may contain "score_threshold" (defaults to 0.5), which is only used when
        "search_type" is "similarity_score_threshold"

        Returns: the object returned by vectordb.as_retriever(...)
        """
        # NOTE(review): the return annotation says RetrievalQA, but the value
        # returned is whatever as_retriever() yields - confirm and adjust the
        # annotation if needed.
        search_kwargs = {"k": config["num_return_documents"]}

        # Only the threshold search type takes a score cut-off; other search
        # types get just the number of documents to return.
        if config["search_type"] == "similarity_score_threshold":
            search_kwargs["score_threshold"] = config.get("score_threshold", 0.5)

        return vectordb.as_retriever(
            search_type=config["search_type"],
            search_kwargs=search_kwargs,
        )


class QASetup:
"""
Description: Setup the VectorDB, QA and initialize the LLM for each type of data
"""
def __init__(
self, config: dict, data_type: str, client: ClientAPI, subset_ids: list = None
):
Expand Down Expand Up @@ -65,18 +77,27 @@ def setup_vector_db_and_qa(self):


class LLMChainCreator:
"""
Description: Gets Ollama, sends query, enables query caching
"""
def __init__(self, config: dict, local: bool = False):
self.config = config
self.local = local

def get_llm_chain(self) -> LLMChain | bool:
"""
Description: Send a query to Ollama using the paths.
"""
base_url = "http://127.0.0.1:11434" if self.local else "http://ollama:11434"
llm = Ollama(model=self.config["llm_model"], base_url=base_url)
map_template = self.config["llm_prompt_template"]
map_prompt = PromptTemplate.from_template(map_template)
return map_prompt | llm | StrOutputParser()

def enable_cache(self):
"""
Description: Enable a cache for queries to prevent running the same query again for no reason.
"""
set_llm_cache(
SQLiteCache(
database_path=os.path.join(self.config["data_dir"], ".langchain.db")
Expand Down
5 changes: 3 additions & 2 deletions backend/modules/results_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import pandas as pd
from flashrank import Ranker, RerankRequest
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_community.document_transformers.long_context_reorder import \
LongContextReorder
from langchain_community.document_transformers.long_context_reorder import (
LongContextReorder,
)
from langchain_core.documents import BaseDocumentTransformer, Document
from tqdm import tqdm

Expand Down
Loading

0 comments on commit bf5bb6e

Please sign in to comment.