Peft-compatible support for continued training
stefanik12 committed Mar 30, 2024
1 parent 7504fb0 commit c659dea
Showing 5 changed files with 195 additions and 45 deletions.
68 changes: 57 additions & 11 deletions adaptor/adapter.py
@@ -1,11 +1,14 @@
import copy
import logging
import os
from typing import List, Dict, Tuple, Union, Optional

from transformers import WEIGHTS_NAME
from peft import PeftModel
from transformers import WEIGHTS_NAME, TrainerState
import torch
from transformers import Trainer, BatchEncoding
from transformers.modeling_utils import unwrap_model
from transformers.trainer import TRAINER_STATE_NAME

from .lang_module import LangModule
from .schedules import Schedule
@@ -48,7 +51,7 @@ def __init__(self, lang_module: LangModule, schedule: Schedule, args: Adaptation
train_dataset=self.schedule.iterable_dataset(split="train"),
eval_dataset=self.schedule.iterable_dataset(split="eval"),
data_collator=self.flattened_collator,
compute_metrics=None, # would require a static prediction format among objectives
compute_metrics=None, # logged metrics are handled by Objectives
callbacks=orig_callbacks + [schedule.should_stop_check_callback()],
**kwargs)

@@ -95,10 +98,43 @@ def evaluate(self, *args, **kwargs) -> Dict[str, float]:

return out

def _save_module(self, module: torch.nn.Module, output_module_path: str) -> None:
# simple wrapper to save an arbitrary model to a directory in a standard format
# for each objective, we also persist a shared tokenizer to make each Objective independently loadable
self.model.tokenizer.save_pretrained(output_module_path)

if hasattr(module, "save_pretrained") or hasattr(unwrap_model(module), "save_pretrained"):
# if the head module has a "save_pretrained" method, it will be called for persistence
module.save_pretrained(output_module_path, use_diff=False, safe_serialization=False)
else:
# otherwise, we persist only a raw pytorch module
torch.save(module.state_dict(), os.path.join(output_module_path, WEIGHTS_NAME))

def save_model(self, output_dir: Optional[str] = None, **kwargs) -> None:
# HF native reload compatibility
objectives_counter = {str(obj): 0 for obj in self.schedule.objectives["train"].values()}

os.makedirs(output_dir, exist_ok=True)

# also save the base model, if any of our objectives are peft models
if (self.args.save_peft_base_model and
any(isinstance(o.compatible_head_model, PeftModel) for o in self.schedule.objectives["train"].values())):
# For simplicity, we assume that base models for all pefts are the same
# -- this might be violated only if the user passes custom model_head to Objective
# and additionally creates a peft module on it.
# Thus, we retrieve a base model from an arbitrary (i.e. the first) peft-model objective
peft_obj = next(o for o in self.schedule.objectives["train"].values()
if isinstance(o.compatible_head_model, PeftModel))

orig_model = copy.deepcopy(peft_obj.compatible_head_model)
while isinstance(orig_model, PeftModel):
# we have seen cases where unload() does not return the base model on the first call
orig_model = orig_model.unload()

base_model_path = os.path.join(output_dir, "base_model")
self._save_module(orig_model, base_model_path)
logger.info(f"Base model for PEFT objectives saved in {base_model_path}")

for objective_id in self.schedule.objectives["train"].keys():
module = self.model.trainable_models[str(objective_id)]
objective = self.schedule.objectives["train"][int(objective_id)]
@@ -109,15 +145,25 @@ def save_model(self, output_dir: Optional[str] = None, **kwargs) -> None:
output_module_path += "_{}".format(objectives_counter[str(objective)])
objectives_counter[str(objective)] += 1

# we persist a shared tokenizer and training args either way
self.model.tokenizer.save_pretrained(output_module_path)
# training args are shared and persisted in the output_dir root
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
if isinstance(module, PeftModel) and self.args.save_peft_base_model:
base_model_path = os.path.abspath(os.path.join(output_dir, "base_model"))
module.peft_config['default'].base_model_name_or_path = base_model_path
logger.warning("Base model for PEFT objective %s set to %s", objective, base_model_path)

if hasattr(module, "save_pretrained") or hasattr(unwrap_model(module), "save_pretrained"):
# if the head module has "save_pretrained" method, it will be called for persistence
module.save_pretrained(output_module_path, use_diff=True)
else:
# otherwise, we persist only a raw pytorch module
torch.save(module.state_dict(), os.path.join(output_module_path, WEIGHTS_NAME))

self._save_module(module, output_module_path)
logger.info(f"Model of objective {str(objective)} saved in {output_module_path}")

def _load_optimizer_and_scheduler(self, checkpoint: str) -> None:
# Customizations to support continued training

# If a TrainerState already exists, it overrides the newly-initialized TrainerState (created in HF Trainer.train())
possible_state_path = os.path.join(self.model.model_name_or_path, TRAINER_STATE_NAME)
if os.path.exists(possible_state_path):
self.state = TrainerState.load_from_json(possible_state_path)
logger.warning("Restoring training on global step %s", self.state.global_step)

# in case of continued training, the optimizer exists in model.model_name_or_path
# if optimizer.pt does not exist there, `super()._load_optimizer_and_scheduler` does not do anything
return super()._load_optimizer_and_scheduler(checkpoint=self.model.model_name_or_path)
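
Taken together with the checkpoint lookup in lang_module.py below, these overrides let a new run resume from a previous Adaptor output directory simply by passing that directory as model_name_or_path. A minimal continued-training sketch follows; the objective, schedule and argument classes named here (MaskedLanguageModeling, ParallelSchedule, AdaptationArguments, StoppingStrategy) and their parameter names are assumptions about the library's public API, not part of this commit:

# Continued-training sketch (class and parameter names outside this diff are assumed):
from adaptor.adapter import Adapter
from adaptor.lang_module import LangModule
from adaptor.objectives.MLM import MaskedLanguageModeling
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy

# Point LangModule at the output_dir of a previous run: the tokenizer is resolved from an
# objective subdirectory, and a trainer_state.json / optimizer.pt found in the same directory
# is restored by the overridden _load_optimizer_and_scheduler above.
lang_module = LangModule("prev_run_output_dir")

objective = MaskedLanguageModeling(lang_module, texts_or_path="train_texts.txt", batch_size=8)
args = AdaptationArguments(output_dir="continued_run",
                           stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                           save_peft_base_model=True)
schedule = ParallelSchedule(objectives=[objective], args=args)

Adapter(lang_module=lang_module, schedule=schedule, args=args).train()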
86 changes: 63 additions & 23 deletions adaptor/lang_module.py
@@ -1,13 +1,15 @@
import logging
import inspect
import os
from typing import List, Dict, Any, Optional

import torch
from peft import PeftConfig
from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForSequenceClassification, \
AutoModelForTokenClassification, AutoModelForSeq2SeqLM, AutoModelForCausalLM, \
AutoModelForMaskedLM, AutoModelForQuestionAnswering

from .utils import Head
from .utils import Head, HEAD_TO_MODEL_CLS

logger = logging.getLogger()

@@ -24,18 +26,44 @@ class LangModule(torch.nn.Module):
"""

tokenizer: PreTrainedTokenizer
model_name_or_path: str
trainable_models: torch.nn.ModuleDict
heads_output_sizes: Dict[str, int] = {}

def __init__(self, model_name_or_path: str) -> None:
super().__init__()
self.model_name_or_path = model_name_or_path
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.tokenizer = self._find_and_load_tokenizer(model_name_or_path)

# head_kwargs = head_kwargs if head_kwargs is not None else [{}] * len(head_types)
# self._load_pretrained_with_heads(model_name_or_path, head_types, head_kwargs)
self.trainable_models = torch.nn.ModuleDict()

@staticmethod
def _find_and_load_tokenizer(model_name_or_path) -> PreTrainedTokenizer:
try:
# New training
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
logger.info("Loaded tokenizer from %s", model_name_or_path)
except OSError:
# Continued training
# in Adaptor checkpoints, tokenizers are persisted in the respective objectives' subdirs
# Hence, we also look for the tokenizer in the subdirectories of model_name_or_path
root = model_name_or_path
subdirs = [path for path in os.listdir(root)
if os.path.isdir(os.path.join(root, path))]
subdirs_with_tokenizer = [os.path.join(root, subdir) for subdir in subdirs
if any(f.startswith("tokenizer") for f in os.listdir(os.path.join(root, subdir)))]
if not subdirs_with_tokenizer:
raise OSError("Could not find a tokenizer in any of the subdirectories %s "
"of given model_name_or_path='%s'", subdirs, root)
tokenizer_dir = subdirs_with_tokenizer[0]

tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
logger.info("Loaded tokenizer from %s", tokenizer_dir)
return tokenizer
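
For orientation, the layout this lookup expects is the one written by Adapter.save_model above; the sketch below is illustrative (the objective subdirectory name is a placeholder), not copied from the repository:

# Hypothetical output_dir of a previous run, as written by Adapter.save_model:
#
#   prev_run_output_dir/
#   ├── training_args.bin
#   ├── base_model/                   # present only when PEFT objectives were trained
#   └── <objective_name>/             # one subdirectory per objective
#       ├── tokenizer_config.json     # any file starting with "tokenizer" triggers the match
#       ├── special_tokens_map.json
#       └── ...model or adapter weights...
#
# _find_and_load_tokenizer then loads the tokenizer from the first matching subdirectory:
from adaptor.lang_module import LangModule
lang_module = LangModule("prev_run_output_dir")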

@staticmethod
def load_head(model_name_or_path: str,
head_type: Head,
@@ -45,48 +73,60 @@ def load_head(model_name_or_path: str,
:param model_name_or_path: base model identifier
:param head_type: type of the requested head
:param head_kwargs: additional initialization arguments, adjusting its default, or persisted config
:return: transformer with a head of requested type
:return: transformer with a head of requested type or a new pytorch model
"""

if head_type == Head.SEQ_CLASSIFICATION:
new_head = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, **head_kwargs)
elif head_type == Head.TOKEN_CLASSIFICATION:
new_head = AutoModelForTokenClassification.from_pretrained(model_name_or_path, **head_kwargs)
elif head_type == Head.SEQ2SEQ:
new_head = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, **head_kwargs)
elif head_type == Head.CLM:
new_head = AutoModelForCausalLM.from_pretrained(model_name_or_path, **head_kwargs)
elif head_type == Head.MLM:
new_head = AutoModelForMaskedLM.from_pretrained(model_name_or_path, **head_kwargs)
elif head_type == Head.QA:
new_head = AutoModelForQuestionAnswering.from_pretrained(model_name_or_path, **head_kwargs)
else:
try:
# trying to load first as a transformer model, and if it fails, as a peft model
BaseModelCls = HEAD_TO_MODEL_CLS[head_type]["full"]
try:
new_head = BaseModelCls.from_pretrained(model_name_or_path, **head_kwargs)
except OSError:
logger.warning("Loading model_name_or_path='%s' as full transformer failed. "
"Attempting to load it as peft model.", model_name_or_path)

peft_model_config = PeftConfig.from_pretrained(model_name_or_path)
base_model_path = peft_model_config.base_model_name_or_path
base_model = BaseModelCls.from_pretrained(base_model_path)

ModelCls = HEAD_TO_MODEL_CLS[head_type]["peft"]
new_head = ModelCls.from_pretrained(base_model, model_name_or_path, **head_kwargs)
except KeyError:
# requested head type is not in our map
logger.warning("Model in %s is not a transformers model. "
"Trying to load as a Pytorch model." % model_name_or_path)
new_head = torch.load(model_name_or_path, **head_kwargs)
except ValueError:
# model type is recognized, but could not be loaded
raise ValueError("Could not load model from %s as a transformer or peft model.", model_name_or_path)

return new_head
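
HEAD_TO_MODEL_CLS itself lives in adaptor/utils.py and is not part of the shown diff; a plausible shape for the mapping, assuming peft's task-specific model classes, would be roughly:

# Illustrative sketch only -- the real mapping is defined in adaptor/utils.py (not shown here).
from peft import (PeftModelForCausalLM, PeftModelForSeq2SeqLM,
                  PeftModelForSequenceClassification)
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                          AutoModelForSequenceClassification)

from adaptor.utils import Head

HEAD_TO_MODEL_CLS = {
    Head.SEQ_CLASSIFICATION: {"full": AutoModelForSequenceClassification,
                              "peft": PeftModelForSequenceClassification},
    Head.SEQ2SEQ: {"full": AutoModelForSeq2SeqLM, "peft": PeftModelForSeq2SeqLM},
    Head.CLM: {"full": AutoModelForCausalLM, "peft": PeftModelForCausalLM},
    # ...remaining Head members (TOKEN_CLASSIFICATION, MLM, QA) follow the same pattern
}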

def load_training_head(self,
head_type: Head,
objective_id: str,
checkpoint_dir: Optional[str] = None,
head_kwargs: Optional[Dict[str, Any]] = None,
new_head: Optional[torch.nn.Module] = None) -> torch.nn.Module:
"""
Registers a selected model to this LangModule, i.e. merges its weights with the first one of self.trainable_models,
and registers new model into self.trainable_models[objective_id].
:param head_type: if no `new_head` is given, a transformer of self.model_name_or_path
with a head of `head_type` will be registered.
:param objective_id: key of the new_head model.
:param objective_id: key of the new_head model used to route data samples
:param checkpoint_dir: directory to objective's checkpoints. Overrides model_name_or_path in continued training
:param head_kwargs: if transformer is automatically resolved, additional init args of the transformer,
passed to AutoModelForXY.from_pretrained()
:param new_head: if given, this would be a selected model to be registered.
:return:
:return: The module for a newly registered objective.
"""
# manually-initialized head chosen for this objective will also be merged with other objectives and registered
if head_kwargs is None:
head_kwargs = {}
if new_head is None:
new_head = self.load_head(self.model_name_or_path, head_type, head_kwargs)

new_head = self.load_head(self.model_name_or_path if checkpoint_dir is None else checkpoint_dir,
head_type,
head_kwargs)
# this applies to the 2nd+ -added models: they adopt the shared parameters of the first lang_module
if len(self.trainable_models) >= 1:
unmatched_modules = self._partially_merge_models(list(self.trainable_models.values())[0], new_head)
@@ -121,8 +161,8 @@ def _partially_merge_models(orig_model: torch.nn.Module,
# param present in the model to merge new_model into
new_model_param = getattr(new_model, new_param_key)
orig_model_param = getattr(orig_model, new_param_key)
if orig_model_param.shape == new_model_param.shape and torch.all(
orig_model_param == new_model_param):
if (orig_model_param.shape == new_model_param.shape
and torch.all(orig_model_param == new_model_param)):
setattr(new_model, new_param_key, orig_model_param)
assert id(getattr(orig_model, new_param_key)) == id(getattr(new_model, new_param_key))
else:
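
Spelled out standalone, the PEFT reload path that load_head falls back to looks roughly like this; the path and the Seq2Seq head are illustrative assumptions:

# Standalone view of the PEFT fallback in load_head (paths are placeholders):
from peft import PeftConfig, PeftModelForSeq2SeqLM
from transformers import AutoModelForSeq2SeqLM

adapter_dir = "prev_run_output_dir/<objective_name>"    # objective subdir written by save_model
peft_config = PeftConfig.from_pretrained(adapter_dir)   # reads the persisted adapter_config.json
# save_model rewrites base_model_name_or_path to point at <output_dir>/base_model
base_model = AutoModelForSeq2SeqLM.from_pretrained(peft_config.base_model_name_or_path)
reloaded = PeftModelForSeq2SeqLM.from_pretrained(base_model, adapter_dir)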
