Logging of model reload in continued training
stefanik12 committed Apr 10, 2024
1 parent 475bc30 commit dae2584
Showing 3 changed files with 22 additions and 12 deletions.
adaptor/lang_module.py: 2 changes (1 addition, 1 deletion)
@@ -97,7 +97,7 @@ def load_head(model_name_or_path: str,
     new_head = torch.load(model_name_or_path, **head_kwargs)
 except ValueError:
     # model type is recognized, but could not be loaded
-    raise ValueError("Could not load model from %s as a transformer or peft model.", model_name_or_path)
+    raise ValueError("Could not load model from %s as a transformer or peft model." % model_name_or_path)

 return new_head

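The one-character change above matters because exception constructors, unlike logging calls, do not interpolate %s arguments. A minimal standalone sketch of the difference (the path below is hypothetical, not adaptor code):

path = "some/model/path"  # hypothetical path, for illustration only

try:
    # Before the fix: the path is stored as a second argument and never interpolated.
    raise ValueError("Could not load model from %s as a transformer or peft model.", path)
except ValueError as e:
    print(str(e))  # ('Could not load model from %s as a transformer or peft model.', 'some/model/path')

try:
    # After the fix: the path is interpolated into the message itself.
    raise ValueError("Could not load model from %s as a transformer or peft model." % path)
except ValueError as e:
    print(str(e))  # Could not load model from some/model/path as a transformer or peft model.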
adaptor/objectives/objective_base.py: 14 changes (8 additions, 6 deletions)
@@ -469,7 +469,11 @@ def register_compatible_head_model(self,
 # Support for continued training:
 checkpoint_dir = None
 possible_checkpoint_path = os.path.join(lang_module.model_name_or_path, str(self))
-if os.path.exists(possible_checkpoint_path):
+if other_objective is not None:
+    logger.warning("Objective %s will use %s head of %s objective",
+                   self, self.compatible_head.name, other_objective)
+    preloaded_module = other_objective.compatible_head_model
+elif os.path.exists(possible_checkpoint_path):
     logger.warning("Reloading objective %s's module from checkpoint %s", str(self), possible_checkpoint_path)
     checkpoint_dir = possible_checkpoint_path

@@ -478,11 +482,9 @@
     trainer_state = TrainerState.load_from_json(os.path.join(lang_module.model_name_or_path,
                                                               "trainer_state.json"))
     self.data_iteration_offset = trainer_state.global_step
-
-elif other_objective is not None:
-    logger.warning("Objective %s will use %s head of %s objective",
-                   self, self.compatible_head.name, other_objective)
-    preloaded_module = other_objective.compatible_head_model
+else:
+    logger.warning("No checkpoint found on %s. Attempting to load a model from '%s'.",
+                   possible_checkpoint_path, lang_module.model_name_or_path)

 return lang_module.load_training_head(self.compatible_head,
                                        str(id(self)),
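The rewritten branch reverses the previous precedence: a head shared from another objective now wins over a reload of this objective's own checkpoint, which in turn wins over loading a fresh model from model_name_or_path. A simplified sketch of that selection order, with illustrative names rather than the actual adaptor API:

import os

def pick_head_source(model_name_or_path: str, objective_name: str, other_objective=None):
    """Simplified sketch of the branch order above (illustrative, not the adaptor API):
    1. reuse the compatible head of another objective when one is given,
    2. otherwise reload this objective's own checkpointed head if it exists,
    3. otherwise fall back to loading a fresh model from model_name_or_path."""
    checkpoint_path = os.path.join(model_name_or_path, objective_name)
    if other_objective is not None:
        return "shared", other_objective
    elif os.path.exists(checkpoint_path):
        return "checkpoint", checkpoint_path
    else:
        return "fresh", model_name_or_path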
adaptor/utils.py: 18 changes (13 additions, 5 deletions)
@@ -1,6 +1,8 @@
 import abc
+import logging
 from enum import Enum
 from typing import Dict, Iterable, Iterator, Optional
+import os

 import torch
 import peft
@@ -9,6 +11,9 @@
 from transformers import BatchEncoding, TrainingArguments


+logger = logging.getLogger()
+
+
 class Head(Enum):
     SEQ_CLASSIFICATION = 1
     TOKEN_CLASSIFICATION = 2
@@ -57,9 +62,12 @@ class AdaptationDataset(IterableDataset, abc.ABC):
"""

def __init__(self, length: Optional[int] = None):
worker_info = torch.utils.data.get_worker_info()

self.length = length // worker_info.num_workers
self.world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
if self.world_size > 1:
logger.warning("World size for data sampling: %s" % self.world_size)
self.length = length // self.world_size
else:
self.length = length

def __getitem__(self, index: int) -> BatchEncoding:
raise ValueError("We shouldn't ever get here?")
@@ -112,9 +120,9 @@ def __iter__(self) -> Iterator[Dict[str, torch.LongTensor]]:
                 if i < self.offset:
                     continue

-                if worker_info is not None:
+                if self.world_size > 1 and worker_info is not None:
                     # multi-gpu DataParallel
-                    if (i - worker_info.id) % worker_info.num_workers == 0:
+                    if (i - worker_info.id) % self.world_size == 0:
                         # keep samples whose index modulo the number of workers matches this worker's rank
                         yield encoded_sample
                 else:
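The dataset now shards samples by LOCAL_WORLD_SIZE rather than only by DataLoader worker count: the declared length is divided by the world size, and __iter__ keeps just the samples whose index matches the worker's rank modulo the world size. A standalone sketch of that modulo scheme, assuming an explicit rank argument in place of worker_info.id:

import os

def shard_samples(samples, rank: int):
    """Standalone sketch of modulo sharding over LOCAL_WORLD_SIZE workers:
    worker `rank` keeps every sample whose index is congruent to its rank
    modulo the world size; with a world size of 1 it keeps everything."""
    world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
    for i, sample in enumerate(samples):
        if world_size > 1:
            # keep indices whose offset from this rank is a multiple of world_size
            if (i - rank) % world_size == 0:
                yield sample
        else:
            yield sample

# With LOCAL_WORLD_SIZE=2, rank 0 yields samples 0, 2, 4, ... and rank 1 yields 1, 3, 5, ...,
# so every sample is consumed by exactly one worker.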
