diff --git a/llments/lm/base/dataset_lm.py b/llments/lm/base/dataset_lm.py index 1afe835..cc9342a 100644 --- a/llments/lm/base/dataset_lm.py +++ b/llments/lm/base/dataset_lm.py @@ -2,8 +2,10 @@ import json import random +from typing import Callable import pandas as pd +import torch from llments.lm.lm import LanguageModel @@ -31,8 +33,15 @@ def generate( max_new_tokens: int | None = None, temperature: float = 1.0, num_return_sequences: int = 1, + prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] + | None = None, ) -> list[str]: """See base class.""" + if prefix_allowed_tokens_fn is not None: + raise NotImplementedError( + "The 'prefix_allowed_tokens_fn' argument is not supported for DatasetLM." + ) + filtered_df = self.data # Adjust distribution if condition: diff --git a/llments/lm/rag.py b/llments/lm/rag.py index 82e4f07..b3cd47a 100644 --- a/llments/lm/rag.py +++ b/llments/lm/rag.py @@ -2,9 +2,13 @@ import json import os +from typing import Callable + +import torch from llments.datastore.datastore import Datastore from llments.lm.lm import LanguageModel + class RAGLanguageModel(LanguageModel): """RAGLanguageModel class for performing Retrieval Augmented Generation.""" @@ -30,29 +34,27 @@ def __init__( self.base = base self.datastore = datastore print("Loading the index...") - self.index = faiss.read_index(os.path.join(datastore.index_path, 'index'), faiss.IO_FLAG_MMAP) - self.docids = RAGLanguageModel.load_docids(os.path.join(datastore.index_path, 'docid')) + self.index = faiss.read_index( + os.path.join(datastore.index_path, "index"), faiss.IO_FLAG_MMAP + ) + self.docids = RAGLanguageModel.load_docids( + os.path.join(datastore.index_path, "docid") + ) print("Index loaded successfully!") print("Loading the document file...") self.doc_dict = self.read_jsonl_to_dict(datastore.document_path) print("Documents loaded successfully!") self.max_results = max_results - def set_max_results( - self, - max_results: int - ) -> None: + def set_max_results(self, max_results: int) -> None: """Set the max retrieval results for RAG. Args: max_results (int): The maximum retrieved results for RAG. - """ + """ self.max_results = max_results - def read_jsonl_to_dict( - self, - file_path: str - ) -> dict[str, str]: + def read_jsonl_to_dict(self, file_path: str) -> dict[str, str]: """Read JSONL file and convert it into a dictionary with document ID as keys and contents as values. Args: @@ -62,16 +64,16 @@ def read_jsonl_to_dict( dict: Dictionary containing document contents with document ID as keys. """ data_dict = {} - with open(file_path, 'r', encoding='utf-8') as file: + with open(file_path, "r", encoding="utf-8") as file: for line in file: json_data = json.loads(line) - data_dict[str(json_data[self.datastore.docid_field])] = json_data[self.datastore.fields[0]] + data_dict[str(json_data[self.datastore.docid_field])] = json_data[ + self.datastore.fields[0] + ] return data_dict - + @staticmethod - def load_docids( - file_path: str - ) -> list[str]: + def load_docids(file_path: str) -> list[str]: """Read docids and convert it into a list. Args: @@ -80,10 +82,10 @@ def load_docids( Returns: dict: List containing document IDs. """ - with open(file_path, 'r') as file: + with open(file_path, "r") as file: docids = [line.rstrip() for line in file.readlines()] return docids - + def generate( self, condition: str | None, @@ -92,6 +94,8 @@ def generate( max_new_tokens: int | None = None, temperature: float = 1, num_return_sequences: int = 1, + prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] + | None = None, ) -> list[str]: """Generate an output given the language model. @@ -105,10 +109,16 @@ def generate( temperature: The value used to module the next token probabilities. num_return_sequences: The number of independently computed returned sequences for each element in the batch. + prefix_allowed_tokens_fn: This argument is not supported for RAGLanguageModel. Returns: str: output sequence from the language model. """ + if prefix_allowed_tokens_fn is not None: + raise NotImplementedError( + "The 'prefix_allowed_tokens_fn' argument is not supported for RAGLanguageModel." + ) + top_docs = self.datastore.retrieve( condition, index=self.index, @@ -116,11 +126,17 @@ def generate( max_results=self.max_results, ) - context = '\n'.join([self.doc_dict[str(key.docid)] for key in top_docs]) + context = "\n".join([self.doc_dict[str(key.docid)] for key in top_docs]) prompt = None if condition is not None: - prompt = "\nContext: " + context + "\nPlease answer the following question.\nQuestion: " + condition + "\nAnswer: " - + prompt = ( + "\nContext: " + + context + + "\nPlease answer the following question.\nQuestion: " + + condition + + "\nAnswer: " + ) + lm_response = self.base.generate( condition=prompt, do_sample=do_sample, @@ -131,9 +147,3 @@ def generate( ) return lm_response - - - - - -