feat: add constrained decoding to HuggingFaceLM #78

Merged: 2 commits, Oct 25, 2024
9 changes: 9 additions & 0 deletions llments/lm/base/api.py
@@ -2,9 +2,11 @@

import abc
import os
from typing import Callable
import warnings

from litellm import ModelResponse, batch_completion, completion
import torch

from llments.lm.lm import LanguageModel

@@ -58,6 +60,8 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""Generate a response based on the given prompt.

@@ -73,6 +77,7 @@
max_new_tokens (int): The maximum number of tokens to generate in the chat completion.
temperature (float): The sampling temperature to be used, between 0 and 2.
num_return_sequences (int): The number of chat completion choices to generate for each input message.
prefix_allowed_tokens_fn: This argument is not supported for API-based language models.

Returns:
str: Sampled output sequences from the language model.
@@ -89,6 +94,10 @@
warnings.warn(
"A non-default value for 'max_length' was provided.", UserWarning
)
if prefix_allowed_tokens_fn is not None:
raise NotImplementedError(
"The 'prefix_allowed_tokens_fn' argument is not supported for API-based language models."
)

responses = []
response = completion(
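Usage note (not part of the diff): with this guard in place, passing the new argument to an API-backed model fails fast instead of being silently ignored. A minimal sketch, assuming an APIBasedLM-style class and constructor; these names are illustrative, not the library's confirmed API:

# Hedged sketch: API-backed models reject constrained decoding outright.
# `APIBasedLM` and `model_name` are assumed names for illustration only.
lm = APIBasedLM(model_name="gpt-4o-mini")
try:
    lm.generate(
        condition="Is the sky blue?",
        prefix_allowed_tokens_fn=lambda batch_id, input_ids: [0],  # any callable triggers the guard
    )
except NotImplementedError as err:
    print(err)  # The 'prefix_allowed_tokens_fn' argument is not supported for API-based language models.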
9 changes: 9 additions & 0 deletions llments/lm/base/dataset_lm.py
@@ -2,8 +2,10 @@

import json
import random
from typing import Callable

import pandas as pd
import torch

from llments.lm.lm import LanguageModel

@@ -31,8 +33,15 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""See base class."""
if prefix_allowed_tokens_fn is not None:
raise NotImplementedError(
"The 'prefix_allowed_tokens_fn' argument is not supported for DatasetLM."
)

filtered_df = self.data
# Adjust distribution
if condition:
8 changes: 7 additions & 1 deletion llments/lm/base/hugging_face.py
@@ -2,7 +2,8 @@

import json
import os
from typing import Any
from typing import Any, Callable, List, Optional
import torch

from llments.lm.lm import LanguageModel

@@ -72,6 +73,8 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""Generate an output given the language model.

@@ -85,6 +88,8 @@
temperature: The value used to modulate the next token probabilities.
num_return_sequences: The number of independently computed returned
sequences for each element in the batch.
prefix_allowed_tokens_fn: If provided, this function constrains beam search to allowed tokens at each step.
It takes two arguments, the batch ID and input_ids, and returns a list of allowed token IDs for the next generation step.

Returns:
str: A sampled output sequence from the language model.
@@ -102,6 +107,7 @@
num_return_sequences=num_return_sequences,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
return [
self.tokenizer.decode(output, skip_special_tokens=True)
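Usage note (not part of the diff): the forwarded argument enables constrained decoding with the underlying Hugging Face generate call. Below is a minimal sketch of the technique; the model name, prompt, and yes/no constraint are illustrative assumptions rather than part of this PR.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Token IDs the model is allowed to emit at every generation step.
allowed_ids = [tokenizer.encode(word)[0] for word in [" yes", " no"]]

def allow_yes_no(batch_id: int, input_ids: torch.Tensor) -> list[int]:
    # Same contract as prefix_allowed_tokens_fn: (batch_id, input_ids) -> allowed token IDs.
    return allowed_ids

inputs = tokenizer("Is the sky blue? Answer:", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=1,
    prefix_allowed_tokens_fn=allow_yes_no,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The same callable can be passed to HuggingFaceLM.generate, which forwards it to self.model.generate as shown in the diff above.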
7 changes: 7 additions & 0 deletions llments/lm/lm.py
@@ -1,6 +1,9 @@
"""Base class for language models."""

import abc
from typing import Callable

import torch


class LanguageModel:
@@ -29,6 +32,8 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""Generate an output given the language model.

@@ -42,6 +47,8 @@
temperature: The value used to modulate the next token probabilities.
num_return_sequences: The number of independently computed returned
sequences for each element in the batch.
prefix_allowed_tokens_fn: If provided, this function constrains beam search to allowed tokens at each step.
It takes two arguments, the batch ID and input_ids, and returns a list of allowed token IDs for the next generation step.

Returns:
str: Sampled output sequences from the language model.
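Usage note (not part of the diff): the contract declared here also allows stateful constraints that inspect the tokens generated so far. A hedged sketch that forces the output to start with a fixed phrase; the phrase and the assumed prompt length are illustrative:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
forced = tokenizer.encode(" The answer is")  # tokens the output must start with
prompt_len = 8  # assumed number of tokens in the prompt passed to generate()
all_ids = list(range(len(tokenizer)))  # fall back to the full vocabulary

def force_prefix(batch_id: int, input_ids: torch.Tensor) -> list[int]:
    step = input_ids.shape[-1] - prompt_len  # tokens generated so far
    if step < len(forced):
        return [forced[step]]  # exactly one legal token while spelling the phrase
    return all_ids  # no constraint afterwards

Because the base class only declares the argument, each subclass decides whether to honor it: HuggingFaceLM forwards it, while the API-based, dataset, and RAG backends raise NotImplementedError.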
66 changes: 38 additions & 28 deletions llments/lm/rag.py
@@ -2,9 +2,13 @@

import json
import os
from typing import Callable

import torch
from llments.datastore.datastore import Datastore
from llments.lm.lm import LanguageModel


class RAGLanguageModel(LanguageModel):
"""RAGLanguageModel class for performing Retrieval Augmented Generation."""

@@ -30,29 +34,27 @@ def __init__(
self.base = base
self.datastore = datastore
print("Loading the index...")
self.index = faiss.read_index(os.path.join(datastore.index_path, 'index'), faiss.IO_FLAG_MMAP)
self.docids = RAGLanguageModel.load_docids(os.path.join(datastore.index_path, 'docid'))
self.index = faiss.read_index(
os.path.join(datastore.index_path, "index"), faiss.IO_FLAG_MMAP
)
self.docids = RAGLanguageModel.load_docids(
os.path.join(datastore.index_path, "docid")
)
print("Index loaded successfully!")
print("Loading the document file...")
self.doc_dict = self.read_jsonl_to_dict(datastore.document_path)
print("Documents loaded successfully!")
self.max_results = max_results

def set_max_results(
self,
max_results: int
) -> None:
def set_max_results(self, max_results: int) -> None:
"""Set the max retrieval results for RAG.

Args:
max_results (int): The maximum number of retrieved results for RAG.
"""
"""
self.max_results = max_results

def read_jsonl_to_dict(
self,
file_path: str
) -> dict[str, str]:
def read_jsonl_to_dict(self, file_path: str) -> dict[str, str]:
"""Read JSONL file and convert it into a dictionary with document ID as keys and contents as values.

Args:
Expand All @@ -62,16 +64,16 @@ def read_jsonl_to_dict(
dict: Dictionary containing document contents with document IDs as keys.
"""
data_dict = {}
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
json_data = json.loads(line)
data_dict[str(json_data[self.datastore.docid_field])] = json_data[self.datastore.fields[0]]
data_dict[str(json_data[self.datastore.docid_field])] = json_data[
self.datastore.fields[0]
]
return data_dict

@staticmethod
def load_docids(
file_path: str
) -> list[str]:
def load_docids(file_path: str) -> list[str]:
"""Read docids and convert it into a list.

Args:
Expand All @@ -80,10 +82,10 @@ def load_docids(
Returns:
list: List containing document IDs.
"""
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
docids = [line.rstrip() for line in file.readlines()]
return docids

def generate(
self,
condition: str | None,
Expand All @@ -92,6 +94,8 @@ def generate(
max_new_tokens: int | None = None,
temperature: float = 1,
num_return_sequences: int = 1,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]
| None = None,
) -> list[str]:
"""Generate an output given the language model.

Expand All @@ -105,22 +109,34 @@ def generate(
temperature: The value used to modulate the next token probabilities.
num_return_sequences: The number of independently computed returned
sequences for each element in the batch.
prefix_allowed_tokens_fn: This argument is not supported for RAGLanguageModel.

Returns:
str: output sequence from the language model.
"""
if prefix_allowed_tokens_fn is not None:
raise NotImplementedError(
"The 'prefix_allowed_tokens_fn' argument is not supported for RAGLanguageModel."
)

top_docs = self.datastore.retrieve(
condition,
index=self.index,
docids=self.docids,
max_results=self.max_results,
)

context = '\n'.join([self.doc_dict[str(key.docid)] for key in top_docs])
context = "\n".join([self.doc_dict[str(key.docid)] for key in top_docs])
prompt = None
if condition is not None:
prompt = "\nContext: " + context + "\nPlease answer the following question.\nQuestion: " + condition + "\nAnswer: "

prompt = (
"\nContext: "
+ context
+ "\nPlease answer the following question.\nQuestion: "
+ condition
+ "\nAnswer: "
)

lm_response = self.base.generate(
condition=prompt,
do_sample=do_sample,
Expand All @@ -131,9 +147,3 @@ def generate(
)

return lm_response






1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ classifiers = [
dependencies = [
"pandas",
"tqdm",
"litellm",
]
dynamic = ["version"]
