implemented load_from_spec() function, added private methods for fit() (#23)

* implemented load_from_spec() function, added private methods for fit()

* refactor: moved fit() to HuggingFaceLMFitter class

* fix: code modified

* feature:LMFitter and HFLMFitter added

* resolved the comments

* resolved the comments

* fix: format issue

* fix: format issue

---------

Co-authored-by: Graham Neubig <[email protected]>
wanxinran and neubig authored Mar 28, 2024
1 parent 0b8ae64 commit 9530e40
Showing 2 changed files with 193 additions and 23 deletions.
38 changes: 21 additions & 17 deletions llments/lm/base/hugging_face.py
@@ -1,6 +1,7 @@
"""Module for HuggingFace language models."""

from llments.lm.lm import LanguageModel
import json


class HuggingFaceLM(LanguageModel):
@@ -27,21 +28,6 @@ def __init__(
"text-generation", model=model, device=device
)

def fit(
self, target: LanguageModel, task_description: str | None = None
) -> LanguageModel:
"""Fit the language model to a target language model's distribution.
Args:
target: The language model that should be fitted to.
task_description: A task description that explains more about
what the language model that should be fit is doing (a prompt).
Returns:
The fitted language model.
"""
raise NotImplementedError("This is not implemented yet.")

def generate(
self,
condition: str | None,
@@ -90,14 +76,32 @@ def set_seed(self, seed: int) -> None:
)
set_seed(seed)

def calculate_probability(self, output: str) -> float:
"""Calculate the probability of an output given the language model.
Args:
output: The output sequence for which the probability is calculated.
Returns:
float: The probability of output x given the language model.
"""
raise NotImplementedError


def load_from_spec(spec_file: str) -> HuggingFaceLM:
"""Load a language model from a specification file.
Args:
spec_file: The path to the specification file.
The file should specifies the model identifier "model" and any other relevant parameters such as "device".
Returns:
A language model.
A HuggingFaceLM instance.
"""
raise NotImplementedError("This is not implemented yet.")
with open(spec_file, "r") as file:
spec = json.load(file)

model_name = spec.get("model")
device = spec.get("device", None)

return HuggingFaceLM(model=model_name, device=device)
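
For illustration, a spec file for this loader might look like the following (the filename and values are hypothetical; the function only reads the "model" and "device" keys):

{"model": "gpt2", "device": "cpu"}

A minimal usage sketch, assuming the file above is saved as gpt2_cpu.json:

from llments.lm.base.hugging_face import load_from_spec

lm = load_from_spec("gpt2_cpu.json")  # equivalent to HuggingFaceLM(model="gpt2", device="cpu")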
178 changes: 172 additions & 6 deletions llments/lm/fit.py
@@ -1,18 +1,184 @@
"""Module for fitting language models to other language models."""
"""This module provides the necessary interfaces and functionality for working with different types of language models, integrating both custom implementations and models from the Hugging Face Transformers library.
Classes:
LanguageModel: An abstract base class for language models.
HuggingFaceLM: A wrapper class for language models from the Hugging Face library.
The module is designed to be flexible and extendable, allowing for easy integration
of additional language model types and functionalities in the future.
"""

from llments.lm.lm import LanguageModel
from llments.lm.base.hugging_face import HuggingFaceLM
from typing import Union


class LMFitter:
    """A class responsible for fitting one language model to match another.

    This class provides the interface for adapting a base language model to more
    closely resemble the target language model.
    """

    @classmethod
    def fit(
        cls, base: Union[LanguageModel, HuggingFaceLM], target: LanguageModel, **kwargs
    ):
        """Fit a language model to match another language model.

        Args:
            base: The language model to be modified.
            target: The target language model to fit to.
            **kwargs: Arguments such as batch_size, training_steps, output_dir, log_dir.

        Returns:
            LanguageModel: The fitted language model.
        """
        raise NotImplementedError


class HuggingFaceLMFitter(LMFitter):
    """A class responsible for fitting one Hugging Face language model to match another.

    This class provides the interface for adapting a base language model to more
    closely resemble the target language model.
    """

@classmethod
    def fit(cls, base: LanguageModel, target: LanguageModel, **kwargs) -> LanguageModel:
        """Fit the language model to a target language model's distribution.

        Args:
            base: The Hugging Face language model to fine-tune (typed as
                LanguageModel to satisfy the mypy type checker; must be a
                HuggingFaceLM at runtime).
            target: The language model that should be fitted to.
            **kwargs: Arguments such as batch_size, training_steps, output_dir, log_dir.

        Returns:
            The fitted language model.
        """
try:
from transformers import TrainingArguments, Trainer
except ImportError:
raise ImportError(
"You need to install 'transformers' package to use this function."
)

if not isinstance(base, HuggingFaceLM):
raise NotImplementedError(
f"Cannot fit language models of type {type(base)}"
)

batch_size = kwargs.get("batch_size", 32)
training_steps = kwargs.get("training_steps", 200)

# Generate data and prepare training dataset
inputs, labels = cls._prepare_training_data(
base, target, batch_size, training_steps
)
dataset = cls._prepare_training_dataset(inputs, labels)

num_train_epochs = training_steps / (len(dataset) / batch_size)

training_args = TrainingArguments(
output_dir=kwargs.get("output_dir", "./training_results"),
num_train_epochs=num_train_epochs,
per_device_train_batch_size=batch_size,
logging_dir=kwargs.get("log_dir", "./logs"),
logging_steps=10,
)

trainer = Trainer(
model=base.text_generator.model,
args=training_args,
train_dataset=dataset,
)

trainer.train()

return base
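
For example, with the defaults batch_size=32 and training_steps=200, _prepare_training_data draws 32 * 200 = 6,400 samples, so num_train_epochs = 200 / (6400 / 32) = 1 and the generated data is seen exactly once.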

@classmethod
def _prepare_training_data(
cls,
base: HuggingFaceLM,
target: LanguageModel,
batch_size: int,
training_steps: int,
):
"""Generate data from the target language model, using generate() function.
Helper function of fit().
Args:
base: model to fit.
target: target language model.
batch_size: Number of examples processed in one step.
training_steps: Number of steps to train.
Returns:
inputs: Generated data (type: HF BatchEncoding): result from calling HF tokenizer.
labels: "Up shift" each token to create the labels.
"""
try:
import torch
except ImportError:
raise ImportError(
"You need to install/import 'torch' package to use this function."
)

samples = target.generate(
condition=None,
do_sample=True,
temperature=1.0,
num_return_sequences=batch_size * training_steps,
)

tokenizer = base.text_generator.tokenizer
inputs = tokenizer(
samples, padding=True, truncation=True, return_tensors="pt"
) # return pytorch tensor

        # Shift token IDs left by one so each position is labeled with the next
        # token, then pad with -100 (the loss ignore index) on the right so the
        # labels keep the same length as input_ids and attention_mask.
        labels = inputs.input_ids[:, 1:].clone()
        labels = torch.nn.functional.pad(labels, (0, 1), value=-100)

return inputs, labels
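
A toy illustration of the label construction above, assuming a single four-token sample:

import torch

input_ids = torch.tensor([[5, 6, 7, 8]])                      # one tokenized sample
labels = input_ids[:, 1:].clone()                             # tensor([[6, 7, 8]])
labels = torch.nn.functional.pad(labels, (0, 1), value=-100)  # tensor([[6, 7, 8, -100]])
# Position i of input_ids is labeled with the token at position i + 1;
# the trailing -100 is ignored by the cross-entropy loss.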

    @classmethod
    def _prepare_training_dataset(cls, inputs, labels):
        """Return a custom Dataset object to be used by the HF Trainer class.

        Helper function of fit().

        Args:
            inputs: Tokenized inputs produced by _prepare_training_data.
            labels: The corresponding label IDs.

        Returns:
            A Dataset object pairing each input with its labels.
        """
try:
import torch
from torch.utils.data import Dataset
except ImportError:
raise ImportError(
"You need both 'torch' and 'torch.utils.data' packages to use this function."
)

        class TrainingDataset(Dataset):
            def __init__(self, encodings, labels):
                self.encodings = encodings  # tensor of input IDs
                self.labels = labels

            def __getitem__(self, idx):
                # encodings is a plain tensor of input IDs (not a BatchEncoding),
                # so index it directly rather than iterating over .items().
                return {
                    "input_ids": self.encodings[idx],
                    "labels": self.labels[idx],
                }

            def __len__(self):
                return len(self.labels)

        return TrainingDataset(inputs["input_ids"], labels)
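
A minimal end-to-end usage sketch (model names and hyperparameter values here are illustrative, not part of the commit):

from llments.lm.base.hugging_face import HuggingFaceLM
from llments.lm.fit import HuggingFaceLMFitter

base = HuggingFaceLM(model="gpt2", device="cpu")          # model to fine-tune
target = HuggingFaceLM(model="distilgpt2", device="cpu")  # distribution to imitate

fitted = HuggingFaceLMFitter.fit(
    base,
    target,
    batch_size=8,
    training_steps=50,
    output_dir="./training_results",
    log_dir="./logs",
)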
