From ad11f8053f5be6f82f222d49573fb7470e3b6a5c Mon Sep 17 00:00:00 2001
From: Zaid Sheikh
Date: Fri, 25 Oct 2024 08:35:00 -0400
Subject: [PATCH] feat: add constrained decoding to HuggingFaceLM (#78)

* feat: add constrained decoding to HuggingFaceLM

* update generate() signature for RAGLanguageModel, DatasetLM
---
 llments/lm/base/api.py          |  9 +++++
 llments/lm/base/dataset_lm.py   |  9 +++++
 llments/lm/base/hugging_face.py |  8 +++-
 llments/lm/lm.py                |  7 ++++
 llments/lm/rag.py               | 66 +++++++++++++++++++--------------
 pyproject.toml                  |  1 +
 6 files changed, 71 insertions(+), 29 deletions(-)

diff --git a/llments/lm/base/api.py b/llments/lm/base/api.py
index 525cf54..68d1b7c 100644
--- a/llments/lm/base/api.py
+++ b/llments/lm/base/api.py
@@ -2,9 +2,11 @@
 
 import abc
 import os
+from typing import Callable
 import warnings
 
 from litellm import ModelResponse, batch_completion, completion
+import torch
 
 from llments.lm.lm import LanguageModel
 
@@ -58,6 +60,8 @@ def generate(
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
         num_return_sequences: int = 1,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
+        | None = None,
     ) -> list[str]:
         """Generate a response based on the given prompt.
 
@@ -73,6 +77,7 @@ def generate(
             max_new_tokens (float): The maximum number of tokens to generate in the chat completion.
             temperature (float): The sampling temperature to be used, between 0 and 2.
             num_return_sequences (int): The number of chat completion choices to generate for each input message.
+            prefix_allowed_tokens_fn: This argument is not supported for API-based language models.
 
         Returns:
             str: Sampled output sequences from the language model.
@@ -89,6 +94,10 @@ def generate(
             warnings.warn(
                 "A non-default value for 'max_length' was provided.", UserWarning
             )
+        if prefix_allowed_tokens_fn is not None:
+            raise NotImplementedError(
+                "The 'prefix_allowed_tokens_fn' argument is not supported for API-based language models."
+            )
 
         responses = []
         response = completion(
diff --git a/llments/lm/base/dataset_lm.py b/llments/lm/base/dataset_lm.py
index 1afe835..cc9342a 100644
--- a/llments/lm/base/dataset_lm.py
+++ b/llments/lm/base/dataset_lm.py
@@ -2,8 +2,10 @@
 
 import json
 import random
+from typing import Callable
 
 import pandas as pd
+import torch
 
 from llments.lm.lm import LanguageModel
 
@@ -31,8 +33,15 @@ def generate(
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
         num_return_sequences: int = 1,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
+        | None = None,
     ) -> list[str]:
         """See base class."""
+        if prefix_allowed_tokens_fn is not None:
+            raise NotImplementedError(
+                "The 'prefix_allowed_tokens_fn' argument is not supported for DatasetLM."
+            )
+
         filtered_df = self.data
         # Adjust distribution
         if condition:
diff --git a/llments/lm/base/hugging_face.py b/llments/lm/base/hugging_face.py
index 7b0af73..e705934 100644
--- a/llments/lm/base/hugging_face.py
+++ b/llments/lm/base/hugging_face.py
@@ -2,7 +2,8 @@
 
 import json
 import os
-from typing import Any
+from typing import Any, Callable
+import torch
 
 from llments.lm.lm import LanguageModel
 
@@ -72,6 +73,8 @@ def generate(
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
         num_return_sequences: int = 1,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
+        | None = None,
     ) -> list[str]:
         """Generate an output given the language model.
 
@@ -85,6 +88,8 @@ def generate(
             temperature: The value used to module the next token probabilities.
             num_return_sequences: The number of independently computed returned
                 sequences for each element in the batch.
+            prefix_allowed_tokens_fn: This function constrains the beam search to allowed tokens only at each step.
+                It takes two arguments, the batch ID and input_ids, and returns a list of the allowed tokens for the next generation step.
 
         Returns:
             str: A sampled output sequence from the language model.
@@ -102,6 +107,7 @@ def generate(
             num_return_sequences=num_return_sequences,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
         )
         return [
             self.tokenizer.decode(output, skip_special_tokens=True)
diff --git a/llments/lm/lm.py b/llments/lm/lm.py
index d421923..e37fc96 100644
--- a/llments/lm/lm.py
+++ b/llments/lm/lm.py
@@ -1,6 +1,9 @@
 """Base class for language models."""
 
 import abc
+from typing import Callable
+
+import torch
 
 
 class LanguageModel:
@@ -29,6 +32,8 @@ def generate(
         max_new_tokens: int | None = None,
         temperature: float = 1.0,
         num_return_sequences: int = 1,
+        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
+        | None = None,
     ) -> list[str]:
         """Generate an output given the language model.
 
@@ -42,6 +47,8 @@ def generate(
             temperature: The value used to module the next token probabilities.
             num_return_sequences: The number of independently computed returned
                 sequences for each element in the batch.
+            prefix_allowed_tokens_fn: This function constrains the beam search to allowed tokens only at each step.
+                It takes two arguments, the batch ID and input_ids, and returns a list of the allowed tokens for the next generation step.
 
         Returns:
             str: Sampled output sequences from the language model.
diff --git a/llments/lm/rag.py b/llments/lm/rag.py
index 82e4f07..b3cd47a 100644
--- a/llments/lm/rag.py
+++ b/llments/lm/rag.py
@@ -2,9 +2,13 @@
 
 import json
 import os
+from typing import Callable
+
+import torch
 
 from llments.datastore.datastore import Datastore
 from llments.lm.lm import LanguageModel
 
+
 class RAGLanguageModel(LanguageModel):
     """RAGLanguageModel class for performing Retrieval Augmented Generation."""
 
@@ -30,29 +34,27 @@ def __init__(
         self.base = base
         self.datastore = datastore
         print("Loading the index...")
-        self.index = faiss.read_index(os.path.join(datastore.index_path, 'index'), faiss.IO_FLAG_MMAP)
-        self.docids = RAGLanguageModel.load_docids(os.path.join(datastore.index_path, 'docid'))
+        self.index = faiss.read_index(
+            os.path.join(datastore.index_path, "index"), faiss.IO_FLAG_MMAP
+        )
+        self.docids = RAGLanguageModel.load_docids(
+            os.path.join(datastore.index_path, "docid")
+        )
         print("Index loaded successfully!")
         print("Loading the document file...")
         self.doc_dict = self.read_jsonl_to_dict(datastore.document_path)
         print("Documents loaded successfully!")
         self.max_results = max_results
 
-    def set_max_results(
-        self,
-        max_results: int
-    ) -> None:
+    def set_max_results(self, max_results: int) -> None:
        """Set the max retrieval results for RAG.
 
         Args:
             max_results (int): The maximum retrieved results for RAG.
-        """ 
+        """
         self.max_results = max_results
 
-    def read_jsonl_to_dict(
-        self,
-        file_path: str
-    ) -> dict[str, str]:
+    def read_jsonl_to_dict(self, file_path: str) -> dict[str, str]:
         """Read JSONL file and convert it into a dictionary with document ID as keys and contents as values.
 
         Args:
@@ -62,16 +64,16 @@ def read_jsonl_to_dict(
         Returns:
             dict: Dictionary containing document contents with document ID as keys.
""" data_dict = {} - with open(file_path, 'r', encoding='utf-8') as file: + with open(file_path, "r", encoding="utf-8") as file: for line in file: json_data = json.loads(line) - data_dict[str(json_data[self.datastore.docid_field])] = json_data[self.datastore.fields[0]] + data_dict[str(json_data[self.datastore.docid_field])] = json_data[ + self.datastore.fields[0] + ] return data_dict - + @staticmethod - def load_docids( - file_path: str - ) -> list[str]: + def load_docids(file_path: str) -> list[str]: """Read docids and convert it into a list. Args: @@ -80,10 +82,10 @@ def load_docids( Returns: dict: List containing document IDs. """ - with open(file_path, 'r') as file: + with open(file_path, "r") as file: docids = [line.rstrip() for line in file.readlines()] return docids - + def generate( self, condition: str | None, @@ -92,6 +94,8 @@ def generate( max_new_tokens: int | None = None, temperature: float = 1, num_return_sequences: int = 1, + prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] + | None = None, ) -> list[str]: """Generate an output given the language model. @@ -105,10 +109,16 @@ def generate( temperature: The value used to module the next token probabilities. num_return_sequences: The number of independently computed returned sequences for each element in the batch. + prefix_allowed_tokens_fn: This argument is not supported for RAGLanguageModel. Returns: str: output sequence from the language model. """ + if prefix_allowed_tokens_fn is not None: + raise NotImplementedError( + "The 'prefix_allowed_tokens_fn' argument is not supported for RAGLanguageModel." + ) + top_docs = self.datastore.retrieve( condition, index=self.index, @@ -116,11 +126,17 @@ def generate( max_results=self.max_results, ) - context = '\n'.join([self.doc_dict[str(key.docid)] for key in top_docs]) + context = "\n".join([self.doc_dict[str(key.docid)] for key in top_docs]) prompt = None if condition is not None: - prompt = "\nContext: " + context + "\nPlease answer the following question.\nQuestion: " + condition + "\nAnswer: " - + prompt = ( + "\nContext: " + + context + + "\nPlease answer the following question.\nQuestion: " + + condition + + "\nAnswer: " + ) + lm_response = self.base.generate( condition=prompt, do_sample=do_sample, @@ -131,9 +147,3 @@ def generate( ) return lm_response - - - - - - diff --git a/pyproject.toml b/pyproject.toml index ccfd3e1..3089754 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ classifiers = [ dependencies = [ "pandas", "tqdm", + "litellm", ] dynamic = ["version"]