PEFT modules resolution, propagation of peft_config on Objective init
stefanik12 committed Apr 17, 2024
Commit 53979f2 (1 parent: 37d3e8c)
Showing 2 changed files with 50 additions and 17 deletions.
adaptor/lang_module.py (59 changes: 43 additions & 16 deletions)
@@ -5,7 +5,7 @@
 from typing import List, Dict, Any, Optional
 
 import torch
-from peft import PeftConfig
+from peft import PeftConfig, get_peft_model
 from transformers import PreTrainedTokenizer, AutoTokenizer
 
 from .utils import Head, HEAD_TO_MODEL_CLS, PEFT_BASE_MODEL_CHECKPOINT_SUBDIR
@@ -68,6 +68,7 @@ def _find_and_load_tokenizer(model_name_or_path) -> PreTrainedTokenizer:
     def load_head(self,
                   model_name_or_path: str,
                   head_type: Head,
+                  load_as_peft: bool,
                   head_kwargs: Dict[str, Any],
                   continued_training: bool) -> torch.nn.Module:
         """
@@ -81,28 +82,52 @@ def load_head(self,
         try:
             # trying to load first as a transformer model, and if it fails, as a peft model
             BaseModelCls = HEAD_TO_MODEL_CLS[head_type]["full"]
-            try:
-                # first trying to load as a full model
+            if not load_as_peft:
                 new_head = BaseModelCls.from_pretrained(model_name_or_path, **head_kwargs)
-            except OSError:
-                # if that fails, tring to load as a PEFT model
+            else:
+                PeftModelCls = HEAD_TO_MODEL_CLS[head_type]["peft"]
+                # if that fails, trying to load as a PEFT model
                 logger.warning("Loading model_name_or_path='%s' as full transformer failed. "
                                "Attempting to load it as peft model.", model_name_or_path)
                 # base model resolution
-                # we want to avoid reloading the base model separately for each lora module
-                if self.peft_base_model is None:
-                    if not continued_training:
-                        # in a fresh training, PEFT models define their base model in their config
+                # we avoid reloading the base model separately for each lora module
+
+                if continued_training:
+                    # In PEFT training with adaptor, the base model is checkpointed in a pre-defined directory
+                    base_model_path = os.path.join(model_name_or_path, "..", PEFT_BASE_MODEL_CHECKPOINT_SUBDIR)
+                    if self.peft_base_model is None:
+                        self.peft_base_model = BaseModelCls.from_pretrained(base_model_path)
+
+                    new_head = PeftModelCls.from_pretrained(deepcopy(self.peft_base_model), model_name_or_path,
+                                                            **head_kwargs)
+                    logger.warning("Reloaded existing PEFT module from %s with base model %s.",
+                                   model_name_or_path, base_model_path)
+                else:
+                    try:
+                        # first try loading as an already-trained PEFT model (=> it already has its PeftConfig)
+                        # if it does not, fall back to loading a brand new PEFT model
                         peft_model_config = PeftConfig.from_pretrained(model_name_or_path)
                         base_model_path = peft_model_config.base_model_name_or_path
-                    else:
-                        # In PEFT training with adaptor, the base model is checkpointed in a pre-defined directory
-                        base_model_path = os.path.join(model_name_or_path, "..", PEFT_BASE_MODEL_CHECKPOINT_SUBDIR)
-                    logger.warning("Attempting to reload base model for peft objectives from %s", base_model_path)
-                    self.peft_base_model = BaseModelCls.from_pretrained(base_model_path)
+                        if self.peft_base_model is None:
+                            self.peft_base_model = BaseModelCls.from_pretrained(base_model_path)
 
-                ModelCls = HEAD_TO_MODEL_CLS[head_type]["peft"]
-                new_head = ModelCls.from_pretrained(deepcopy(self.peft_base_model), model_name_or_path, **head_kwargs)
+                        new_head = PeftModelCls.from_pretrained(deepcopy(self.peft_base_model), model_name_or_path,
+                                                                **head_kwargs)
+                        logger.warning("Reloaded existing PEFT module from %s with base model %s.",
+                                       model_name_or_path, base_model_path)
+                    except ValueError:
+                        logger.warning("Initializing a new PEFT module.")
+                        # ValueError: Can't find 'adapter_config.json' at {model_name_or_path}
+                        # -> we initialize a new PEFT model from a full pre-trained transformer (simplest case)
+                        assert 'peft_config' in head_kwargs, \
+                            ("Initializing an objective with PEFT model requires to pass a 'peft_config' "
+                             "witin `objective_args_for_head_config`, e.g: "
+                             "`objective = Objective(objective_args_for_head_config={'peft_config': LoraConfig()}`."
+                             " See the docs on https://huggingface.co/docs/peft/main/en/package_reference/config")
+                        if self.peft_base_model is None:
+                            self.peft_base_model = BaseModelCls.from_pretrained(model_name_or_path)
+                        head_kwargs['peft_config'].base_model_name_or_path = model_name_or_path
+                        new_head = get_peft_model(deepcopy(self.peft_base_model), **head_kwargs)
         except KeyError:
             # requested head type is not in our map
             logger.warning("Model in %s is not a transformers model. "
@@ -116,6 +141,7 @@
 
     def load_training_head(self,
                            head_type: Head,
+                           load_as_peft: bool,
                            objective_id: str,
                            checkpoint_dir: Optional[str] = None,
                            head_kwargs: Optional[Dict[str, Any]] = None,
@@ -139,6 +165,7 @@ def load_training_head(self,
         if new_head is None:
             new_head = self.load_head(self.model_name_or_path if checkpoint_dir is None else checkpoint_dir,
                                       head_type,
+                                      load_as_peft,
                                       head_kwargs,
                                       continued_training=checkpoint_dir is not None)
         # this applies to the 2nd+ -added models: they adopt the shared parameters of the first lang_module
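With `load_as_peft` set, the head is resolved in three ways: during continued training the base model is reloaded from the `PEFT_BASE_MODEL_CHECKPOINT_SUBDIR` next to the checkpoint; otherwise an already-trained adapter found at `model_name_or_path` is attached to its declared base model, and only when no `adapter_config.json` is present is a fresh adapter created with `get_peft_model`. Below is a minimal standalone sketch of the latter two cases written against the public `peft`/`transformers` API rather than adaptor's internals; the model class and paths are placeholders, not values used by adaptor itself.

```python
# Sketch of the adapter-resolution pattern used by load_head(load_as_peft=True).
# AutoModelForSequenceClassification stands in for whatever class HEAD_TO_MODEL_CLS resolves to.
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from transformers import AutoModelForSequenceClassification


def resolve_peft_head(model_name_or_path: str):
    try:
        # Case 1: the path already contains a trained adapter (adapter_config.json)
        # -> read its config to find the base model, then attach the adapter to it.
        peft_config = PeftConfig.from_pretrained(model_name_or_path)
        base_model = AutoModelForSequenceClassification.from_pretrained(peft_config.base_model_name_or_path)
        return PeftModel.from_pretrained(base_model, model_name_or_path)
    except ValueError:
        # Case 2: the path is a plain pre-trained transformer
        # -> initialize a brand new adapter on top of it from a fresh LoraConfig.
        base_model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
        return get_peft_model(base_model, LoraConfig())
```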
adaptor/objectives/objective_base.py (8 changes: 7 additions & 1 deletion)
@@ -43,6 +43,8 @@ class Objective(abc.ABC):
     num_samples_per_log: Dict[str, int]
     num_samples_to_prefetch: int = 10
 
+    peft_objective: bool
+
     def __init__(self,
                  lang_module: LangModule,
                  batch_size: int,
@@ -59,7 +61,8 @@ def __init__(self,
                 max_samples_per_eval_log: int = 10000,
                 data_iteration_offset: int = 0,
                 prefetch_in_parallel_thread: bool = False,
-                 remember_last_input: Optional[bool] = False):
+                 remember_last_input: Optional[bool] = False,
+                 peft_objective: Optional[bool] = False):
         """
         Shared initialisation logic of every Objective.
         Registers a compatible model for this objective to given `lang_module`,
@@ -89,7 +92,9 @@ def __init__(self,
         self.batch_size = batch_size
         self.tokenizer = lang_module.tokenizer
         self.given_id = objective_id
+        self.peft_objective = peft_objective
         self.loss_weight = loss_weight
+
         self.num_steps = 0
         self.remember_last_input = remember_last_input
         self.last_input = None
@@ -495,6 +500,7 @@ def register_compatible_head_model(self,
                            possible_checkpoint_path, lang_module.model_name_or_path)
 
         return lang_module.load_training_head(self.compatible_head,
+                                              self.peft_objective,
                                               str(id(self)),
                                               checkpoint_dir,
                                               head_config,
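On the objective side, `peft_objective` is stored at init and forwarded to `lang_module.load_training_head`, so a PEFT head is now requested explicitly instead of being inferred from a failed full-model load. The sketch below shows how an objective might be declared after this change; `MaskedLanguageModeling` with its `texts_or_path`/`batch_size` arguments is assumed from adaptor's usual Objective interface, while `peft_objective` and `objective_args_for_head_config={'peft_config': ...}` come directly from this commit.

```python
# Hedged usage sketch: declaring a PEFT objective after this commit.
from peft import LoraConfig

from adaptor.lang_module import LangModule
from adaptor.objectives.MLM import MaskedLanguageModeling  # assumed concrete Objective subclass

lang_module = LangModule("bert-base-cased")

mlm_objective = MaskedLanguageModeling(
    lang_module,
    texts_or_path=["Training texts ..."],
    batch_size=8,
    peft_objective=True,  # new in this commit: resolve/initialize the head as a PEFT model
    objective_args_for_head_config={"peft_config": LoraConfig()},  # propagated to get_peft_model
)
```

Since `peft_objective` defaults to `False`, existing objectives keep loading full transformer heads exactly as before.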
