Skip to content

Commit

Permalink
update generate() signature for RAGLanguageModel, DatasetLM
Browse files Browse the repository at this point in the history
  • Loading branch information
zaidsheikh committed Oct 25, 2024
1 parent d49037f commit da934aa
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 28 deletions.
9 changes: 9 additions & 0 deletions llments/lm/base/dataset_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import json
import random
from typing import Callable

import pandas as pd
import torch

from llments.lm.lm import LanguageModel

Expand Down Expand Up @@ -31,8 +33,15 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""See base class."""
if prefix_allowed_tokens_fn is not None:
raise NotImplementedError(
"The 'prefix_allowed_tokens_fn' argument is not supported for DatasetLM."
)

filtered_df = self.data
# Adjust distribution
if condition:
Expand Down
66 changes: 38 additions & 28 deletions llments/lm/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

import json
import os
from typing import Callable

import torch
from llments.datastore.datastore import Datastore
from llments.lm.lm import LanguageModel


class RAGLanguageModel(LanguageModel):
"""RAGLanguageModel class for performing Retrieval Augmented Generation."""

Expand All @@ -30,29 +34,27 @@ def __init__(
self.base = base
self.datastore = datastore
print("Loading the index...")
self.index = faiss.read_index(os.path.join(datastore.index_path, 'index'), faiss.IO_FLAG_MMAP)
self.docids = RAGLanguageModel.load_docids(os.path.join(datastore.index_path, 'docid'))
self.index = faiss.read_index(
os.path.join(datastore.index_path, "index"), faiss.IO_FLAG_MMAP
)
self.docids = RAGLanguageModel.load_docids(
os.path.join(datastore.index_path, "docid")
)
print("Index loaded successfully!")
print("Loading the document file...")
self.doc_dict = self.read_jsonl_to_dict(datastore.document_path)
print("Documents loaded successfully!")
self.max_results = max_results

def set_max_results(
self,
max_results: int
) -> None:
def set_max_results(self, max_results: int) -> None:
"""Set the max retrieval results for RAG.
Args:
max_results (int): The maximum retrieved results for RAG.
"""
"""
self.max_results = max_results

def read_jsonl_to_dict(
self,
file_path: str
) -> dict[str, str]:
def read_jsonl_to_dict(self, file_path: str) -> dict[str, str]:
"""Read JSONL file and convert it into a dictionary with document ID as keys and contents as values.
Args:
Expand All @@ -62,16 +64,16 @@ def read_jsonl_to_dict(
dict: Dictionary containing document contents with document ID as keys.
"""
data_dict = {}
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
json_data = json.loads(line)
data_dict[str(json_data[self.datastore.docid_field])] = json_data[self.datastore.fields[0]]
data_dict[str(json_data[self.datastore.docid_field])] = json_data[
self.datastore.fields[0]
]
return data_dict

@staticmethod
def load_docids(
file_path: str
) -> list[str]:
def load_docids(file_path: str) -> list[str]:
"""Read docids and convert it into a list.
Args:
Expand All @@ -80,10 +82,10 @@ def load_docids(
Returns:
dict: List containing document IDs.
"""
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
docids = [line.rstrip() for line in file.readlines()]
return docids

def generate(
self,
condition: str | None,
Expand All @@ -92,6 +94,8 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""Generate an output given the language model.
Expand All @@ -105,22 +109,34 @@ def generate(
temperature: The value used to module the next token probabilities.
num_return_sequences: The number of independently computed returned
sequences for each element in the batch.
prefix_allowed_tokens_fn: This argument is not supported for RAGLanguageModel.
Returns:
str: output sequence from the language model.
"""
if prefix_allowed_tokens_fn is not None:
raise NotImplementedError(
"The 'prefix_allowed_tokens_fn' argument is not supported for RAGLanguageModel."
)

top_docs = self.datastore.retrieve(
condition,
index=self.index,
docids=self.docids,
max_results=self.max_results,
)

context = '\n'.join([self.doc_dict[str(key.docid)] for key in top_docs])
context = "\n".join([self.doc_dict[str(key.docid)] for key in top_docs])
prompt = None
if condition is not None:
prompt = "\nContext: " + context + "\nPlease answer the following question.\nQuestion: " + condition + "\nAnswer: "

prompt = (
"\nContext: "
+ context
+ "\nPlease answer the following question.\nQuestion: "
+ condition
+ "\nAnswer: "
)

lm_response = self.base.generate(
condition=prompt,
do_sample=do_sample,
Expand All @@ -131,9 +147,3 @@ def generate(
)

return lm_response






0 comments on commit da934aa

Please sign in to comment.