Multinode data sampling: fix
stefanik12 committed Apr 29, 2024
1 parent 1778386 commit d5f8690
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion adaptor/adapter.py
@@ -184,5 +184,5 @@ def _load_optimizer_and_scheduler(self, checkpoint: str) -> None:
logger.warning("Restoring training on global step %s", self.state.global_step)

# in case of continued training, optimizer exists on model.model_name_or_path
# if the optmizer.pt does not exist, the `super()._load_optimizer_and_scheduler` does not do anything
# if the optimizer.pt does not exist, the `super()._load_optimizer_and_scheduler` does not do anything
return super()._load_optimizer_and_scheduler(checkpoint=self.model.model_name_or_path)
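For context on the resume path this hunk documents: the override points the parent loader at the model's original directory, and, per the comment, the parent loader is a no-op when no optimizer.pt exists there. Below is a minimal sketch of that control flow; the function name, filename check, and directory layout are illustrative assumptions drawn from the comment, not the verified internals of any particular transformers release.

import os

def load_optimizer_if_present(checkpoint_dir: str) -> bool:
    # Illustrative sketch (not the library API): restore optimizer state only
    # when `optimizer.pt` exists under `checkpoint_dir`, otherwise do nothing.
    optimizer_path = os.path.join(checkpoint_dir, "optimizer.pt")
    if not os.path.isfile(optimizer_path):
        # mirrors the behaviour the comment describes: nothing to restore,
        # training continues with a freshly initialized optimizer
        return False
    # ... a real loader would deserialize optimizer_path and restore the state here ...
    return True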
2 changes: 1 addition & 1 deletion adaptor/utils.py
@@ -64,7 +64,7 @@ class AdaptationDataset(IterableDataset, abc.ABC):
"""

def __init__(self, length: Optional[int] = None):
self.world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
if self.world_size > 1:
logger.warning("World size for data sampling: %s" % self.world_size)
self.length = length // self.world_size
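This one-line change is the fix named in the commit title: under torchrun, LOCAL_WORLD_SIZE counts only the processes on the current node, while WORLD_SIZE counts processes across all nodes, and the per-rank shard length must be divided by the latter. A minimal sketch of the difference, using a hypothetical 2-node x 4-GPU job and dataset size (the numbers are illustrative, not taken from the repository):

import os

# Hypothetical environment as torchrun would set it for 2 nodes x 4 processes each.
os.environ.setdefault("LOCAL_WORLD_SIZE", "4")  # processes on this node only
os.environ.setdefault("WORLD_SIZE", "8")        # processes across all nodes

length = 80_000  # hypothetical total dataset size

# Dividing by LOCAL_WORLD_SIZE over-counts in multinode runs: every one of the
# 8 ranks would claim a quarter of the data (8 x 20_000 = 160_000 samples).
per_rank_old = length // int(os.environ.get("LOCAL_WORLD_SIZE", 1))  # 20_000

# Dividing by WORLD_SIZE gives the intended shard: each rank gets an eighth.
per_rank_new = length // int(os.environ.get("WORLD_SIZE", 1))        # 10_000

print(per_rank_old, per_rank_new)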
