From 86feb0ca3e4543bf313b63e2f5ac3664c0e95e8a Mon Sep 17 00:00:00 2001 From: Srini Iyer Date: Wed, 26 Jul 2023 19:03:59 +0000 Subject: [PATCH] Broadcast json index file rather than loading on all ranks --- metaseq/data/jsonl_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/metaseq/data/jsonl_dataset.py b/metaseq/data/jsonl_dataset.py index 36f3d1961..8c48d6571 100644 --- a/metaseq/data/jsonl_dataset.py +++ b/metaseq/data/jsonl_dataset.py @@ -56,12 +56,12 @@ def __init__( # TODO(susan): Fix this fairseq reference. _build_index fails otherwise. self.cache = Path(f"{resolved_path}.fairseq.idx.npy") # only build the cache in on the primary worker to prevent overloading nfs - if distributed_utils.get_global_rank() != 0: - distributed_utils.global_barrier() if self.cache.exists() and not recache: - distributed_utils.global_barrier() logger.info(f"Loading up cache: {self.cache}") - self.offsets = np.load(self.cache, allow_pickle=True) + # Load the index on rank 0 only and broadcast it to the other ranks; having every rank read the same file causes a significant slowdown. + if distributed_utils.get_global_rank() == 0: + self.offsets = torch.from_numpy(np.load(self.cache, allow_pickle=True)) + self.offsets = distributed_utils.broadcast_tensors([self.offsets] if distributed_utils.get_global_rank() == 0 else None, src_rank=0, group=distributed_utils.get_global_group())[0] elif distributed_utils.get_global_rank() == 0: self.offsets = self._build_index(path) np.save(self.cache, self.offsets, allow_pickle=False)