feat: add constrained decoding to HuggingFaceLM (#78)
* feat: add constrained decoding to HuggingFaceLM

* update generate() signature for RAGLanguageModel, DatasetLM
zaidsheikh authored Oct 25, 2024
1 parent b6815e9 commit ad11f80
Showing 6 changed files with 71 additions and 29 deletions.
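The change threads a prefix_allowed_tokens_fn callback through every generate() signature: HuggingFaceLM forwards it to the underlying Hugging Face model.generate(), while the API-based, dataset, and RAG models raise NotImplementedError when it is passed. As a minimal usage sketch (not part of this commit; the model name and HuggingFaceLM constructor arguments are assumptions), constraining decoding to a small whitelist of tokens might look like this:

# Usage sketch (illustrative): restrict every decoding step to a whitelist.
# The constructor arguments below are assumptions, not the library's API.
import torch
from llments.lm.base.hugging_face import HuggingFaceLM

lm = HuggingFaceLM("gpt2")  # hypothetical constructor arguments

# Token IDs the model may emit (here, " yes" / " no").
allowed_ids = lm.tokenizer(" yes no", add_special_tokens=False)["input_ids"]

def allow_whitelist(batch_id: int, input_ids: torch.Tensor) -> list[int]:
    # Same whitelist at every step, regardless of batch ID or prefix so far.
    return allowed_ids

answers = lm.generate(
    condition="Is the sky blue? Answer:",
    do_sample=False,
    max_new_tokens=1,
    prefix_allowed_tokens_fn=allow_whitelist,
)
print(answers[0])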
9 changes: 9 additions & 0 deletions llments/lm/base/api.py
@@ -2,9 +2,11 @@

import abc
import os
from typing import Callable
import warnings

from litellm import ModelResponse, batch_completion, completion
import torch

from llments.lm.lm import LanguageModel

@@ -58,6 +60,8 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""Generate a response based on the given prompt.
@@ -73,6 +77,7 @@
max_new_tokens (int): The maximum number of tokens to generate in the chat completion.
temperature (float): The sampling temperature to be used, between 0 and 2.
num_return_sequences (int): The number of chat completion choices to generate for each input message.
prefix_allowed_tokens_fn: This argument is not supported for API-based language models.
Returns:
str: Sampled output sequences from the language model.
@@ -89,6 +94,10 @@
warnings.warn(
"A non-default value for 'max_length' was provided.", UserWarning
)
if prefix_allowed_tokens_fn is not None:
raise NotImplementedError(
"The 'prefix_allowed_tokens_fn' argument is not supported for API-based language models."
)

responses = []
response = completion(
9 changes: 9 additions & 0 deletions llments/lm/base/dataset_lm.py
@@ -2,8 +2,10 @@

import json
import random
from typing import Callable

import pandas as pd
import torch

from llments.lm.lm import LanguageModel

@@ -31,8 +33,15 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""See base class."""
if prefix_allowed_tokens_fn is not None:
raise NotImplementedError(
"The 'prefix_allowed_tokens_fn' argument is not supported for DatasetLM."
)

filtered_df = self.data
# Adjust distribution
if condition:
8 changes: 7 additions & 1 deletion llments/lm/base/hugging_face.py
@@ -2,7 +2,8 @@

import json
import os
from typing import Any
from typing import Any, Callable, List, Optional
import torch

from llments.lm.lm import LanguageModel

@@ -72,6 +73,8 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""Generate an output given the language model.
@@ -85,6 +88,8 @@
temperature: The value used to module the next token probabilities.
num_return_sequences: The number of independently computed returned
sequences for each element in the batch.
prefix_allowed_tokens_fn: if provided, this function constrains beam search to allowed tokens at each generation step.
It takes two arguments, the batch ID and input_ids (the tokens generated so far), and returns the list of token IDs allowed for the next step.
Returns:
str: A sampled output sequence from the language model.
@@ -102,6 +107,7 @@
num_return_sequences=num_return_sequences,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
return [
self.tokenizer.decode(output, skip_special_tokens=True)
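The docstring above pins down the callback contract: given the batch ID and the input_ids generated so far, return the token IDs permitted at the next step. A hedged sketch of such a callback (reusing the hypothetical lm instance from the earlier sketch; the prompt-length bookkeeping is an assumption about how HuggingFaceLM tokenizes the condition) that forces output to open with a fixed phrase:

# Sketch: force decoded output to begin with a fixed phrase, then lift the
# constraint. lm is the hypothetical HuggingFaceLM from the earlier sketch.
import torch

prompt = "Q: What is 2 + 2?\nA:"
prompt_len = len(lm.tokenizer(prompt)["input_ids"])  # approximate prompt length
forced = lm.tokenizer(" The answer is", add_special_tokens=False)["input_ids"]
full_vocab = list(range(len(lm.tokenizer)))  # tokenizer length == vocab size

def force_prefix(batch_id: int, input_ids: torch.Tensor) -> list[int]:
    generated = input_ids.shape[-1] - prompt_len  # tokens emitted so far
    if generated < len(forced):
        return [forced[generated]]  # exactly one legal next token
    return full_vocab  # unconstrained once the phrase is complete

outputs = lm.generate(
    condition=prompt,
    do_sample=False,
    max_new_tokens=10,
    prefix_allowed_tokens_fn=force_prefix,
)

In Hugging Face's implementation, tokens outside the returned list have their logits masked to -inf, so returning a singleton list deterministically forces that token.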
7 changes: 7 additions & 0 deletions llments/lm/lm.py
@@ -1,6 +1,9 @@
"""Base class for language models."""

import abc
from typing import Callable

import torch


class LanguageModel:
@@ -29,6 +32,8 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""Generate an output given the language model.
@@ -42,6 +47,8 @@
temperature: The value used to module the next token probabilities.
num_return_sequences: The number of independently computed returned
sequences for each element in the batch.
prefix_allowed_tokens_fn: if provided, this function constrains beam search to allowed tokens at each generation step.
It takes two arguments, the batch ID and input_ids (the tokens generated so far), and returns the list of token IDs allowed for the next step.
Returns:
str: Sampled output sequences from the language model.
66 changes: 38 additions & 28 deletions llments/lm/rag.py
@@ -2,9 +2,13 @@

import json
import os
from typing import Callable

import torch
from llments.datastore.datastore import Datastore
from llments.lm.lm import LanguageModel


class RAGLanguageModel(LanguageModel):
"""RAGLanguageModel class for performing Retrieval Augmented Generation."""

@@ -30,29 +34,27 @@ def __init__(
self.base = base
self.datastore = datastore
print("Loading the index...")
self.index = faiss.read_index(os.path.join(datastore.index_path, 'index'), faiss.IO_FLAG_MMAP)
self.docids = RAGLanguageModel.load_docids(os.path.join(datastore.index_path, 'docid'))
self.index = faiss.read_index(
os.path.join(datastore.index_path, "index"), faiss.IO_FLAG_MMAP
)
self.docids = RAGLanguageModel.load_docids(
os.path.join(datastore.index_path, "docid")
)
print("Index loaded successfully!")
print("Loading the document file...")
self.doc_dict = self.read_jsonl_to_dict(datastore.document_path)
print("Documents loaded successfully!")
self.max_results = max_results

def set_max_results(
self,
max_results: int
) -> None:
def set_max_results(self, max_results: int) -> None:
"""Set the max retrieval results for RAG.
Args:
max_results (int): The maximum retrieved results for RAG.
"""
"""
self.max_results = max_results

def read_jsonl_to_dict(
self,
file_path: str
) -> dict[str, str]:
def read_jsonl_to_dict(self, file_path: str) -> dict[str, str]:
"""Read JSONL file and convert it into a dictionary with document ID as keys and contents as values.
Args:
@@ -62,16 +64,16 @@ def read_jsonl_to_dict(
dict: Dictionary containing document contents with document ID as keys.
"""
data_dict = {}
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
json_data = json.loads(line)
data_dict[str(json_data[self.datastore.docid_field])] = json_data[self.datastore.fields[0]]
data_dict[str(json_data[self.datastore.docid_field])] = json_data[
self.datastore.fields[0]
]
return data_dict

@staticmethod
def load_docids(
file_path: str
) -> list[str]:
def load_docids(file_path: str) -> list[str]:
"""Read docids and convert it into a list.
Args:
@@ -80,10 +82,10 @@
Returns:
dict: List containing document IDs.
"""
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
docids = [line.rstrip() for line in file.readlines()]
return docids

def generate(
self,
condition: str | None,
@@ -92,6 +94,8 @@
max_new_tokens: int | None = None,
temperature: float = 1,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""Generate an output given the language model.
@@ -105,22 +109,34 @@
temperature: The value used to module the next token probabilities.
num_return_sequences: The number of independently computed returned
sequences for each element in the batch.
prefix_allowed_tokens_fn: This argument is not supported for RAGLanguageModel.
Returns:
str: output sequence from the language model.
"""
if prefix_allowed_tokens_fn is not None:
raise NotImplementedError(
"The 'prefix_allowed_tokens_fn' argument is not supported for RAGLanguageModel."
)

top_docs = self.datastore.retrieve(
condition,
index=self.index,
docids=self.docids,
max_results=self.max_results,
)

context = '\n'.join([self.doc_dict[str(key.docid)] for key in top_docs])
context = "\n".join([self.doc_dict[str(key.docid)] for key in top_docs])
prompt = None
if condition is not None:
prompt = "\nContext: " + context + "\nPlease answer the following question.\nQuestion: " + condition + "\nAnswer: "

prompt = (
"\nContext: "
+ context
+ "\nPlease answer the following question.\nQuestion: "
+ condition
+ "\nAnswer: "
)

lm_response = self.base.generate(
condition=prompt,
do_sample=do_sample,
Expand All @@ -131,9 +147,3 @@ def generate(
)

return lm_response






1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ classifiers = [
dependencies = [
"pandas",
"tqdm",
"litellm",
]
dynamic = ["version"]

