
Commit

Merge pull request #198 from Modalities/rank_0_logging
feat: moved print_rank_0 function towards a better fitting place
mali-git authored Aug 2, 2024
2 parents 65bf611 + f2f7e4f commit c82cb5a
Showing 2 changed files with 10 additions and 8 deletions.
9 changes: 1 addition & 8 deletions src/modalities/dataloader/open_gptx_dataset/mmap_dataset.py
@@ -24,14 +24,7 @@
 from numpy._typing import NDArray
 from torch.utils.data import Dataset
 
-
-def print_rank_0(message: str):
-    """If distributed is initialized, print only on rank 0."""
-    if torch.distributed.is_initialized():
-        if torch.distributed.get_rank() == 0:
-            print(message, flush=True)
-    else:
-        print(message, flush=True)
+from modalities.util import print_rank_0
 
 
 def get_best_fitting_dtype(vocab_size: Optional[int] = None) -> np.dtype:

9 changes: 9 additions & 0 deletions src/modalities/util.py
@@ -17,6 +17,15 @@
 from modalities.running_env.fsdp.reducer import Reducer
 
 
+def print_rank_0(message: str):
+    """If distributed is initialized, print only on rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 0:
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
+
+
 def parse_enum_by_name(name: str, enum_type: Type[Enum]) -> Enum:
     try:
         return enum_type[name]
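
For context, a minimal usage sketch of the relocated helper; this is not part of the commit. It assumes the modalities package is importable, and the log message is purely illustrative:

    from modalities.util import print_rank_0  # import path after this commit

    # Before torch.distributed.init_process_group() has been called,
    # is_initialized() is False, so the message is printed unconditionally.
    print_rank_0("dataset index loaded")

    # In a torchrun/FSDP job where the process group is initialized,
    # the same call prints only on the process with rank 0, so each log
    # line appears once instead of once per GPU.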
