Commit abe2b92
Fixes for resuming from checkpoint
manuelburger committed Sep 24, 2024
1 parent b7ddf39 commit abe2b92
Showing 2 changed files with 36 additions and 17 deletions.
4 changes: 0 additions & 4 deletions petagraph/run_train.py
@@ -177,10 +177,6 @@ def get_dataloader_from_data_stage(
     log_rank(f"Found {num_unitigs} unitigs and {num_contigs} contigs", logger=logger, level=logging.INFO, rank=0)
     log_rank(f"Found {num_missed} missed accessions", logger=logger, level=logging.INFO, rank=0)
 
-    # TODO: if resuming from a checkpoint, we need to skip the already consumed files
-    if consumed_train_samples > 0:
-        raise NotImplementedError("Resuming from a checkpoint is not yet supported")
-
     # Compute size and rank of dataloader workers
     dp_ranks_size = trainer.parallel_context.dp_pg.size()
    dp_rank = trainer.parallel_context.dp_pg.rank()
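The deleted guard made any resume attempt fail fast. A toy sketch (the helper name is invented here, not from the commit) of the decision that replaces it: resuming is now data-driven, gated on whether a per-rank consumed-files log from a previous run exists, which the dataset probes for itself in its constructor using the path format added below in petagraph_dataset.py.

from pathlib import Path

def can_resume(log_directory: Path, rank: int) -> bool:
    # Mirrors the path the dataset now probes on construction:
    # log_directory / f"consumed_files/consumed_files_rank_{rank}.txt"
    return (log_directory / f"consumed_files/consumed_files_rank_{rank}.txt").exists()

print(can_resume(Path("logs"), rank=0))  # False unless a previous run wrote its log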
49 changes: 36 additions & 13 deletions src/nanotron/data/petagraph_dataset.py
@@ -55,9 +55,7 @@ def __init__(self,
                  prefetch_fasta_parsing: int = 10,
                  log_directory: Path = None,
                  rank: int = 0,
-                 packed: bool = False,
-                 restart_consumed_files: list[str] = None,
-                 restart_epoch: int = 0,
+                 packed: bool = False
                  ):
 
         self.samples_per_epoch = samples_per_epoch
@@ -82,35 +80,47 @@
         self.num_files = len(url_list)
         self.current_epoch = 0
 
-        # TODO: Take list of already consumed lists and remove them from the
+        self.rank = rank
+        self.log_directory = log_directory
+        self.consumed_files_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt"
+
+        # Take list of already consumed lists and remove them from the
         # url list, to continue training from the last checkpoint properly
         # TODO: REWORK THIS
         # - Check if the consumed_files exist
         # - If they exist, load them and assume we are restarting from a checkpoint
         # - Find the largest epoch number in the consumed files
         # - Filter the files that have been consumed/started in the latest epoch
         # - Remove them from the url_list then append them to the end of the url_list
         # - Set the current epoch to the latest epoch
 
-        if restart_consumed_files is not None:
+        if self.consumed_files_path.exists():
+            log_msg = f"[PetaGraphStreamDataset:{self.rank}] Consumed files found at {self.consumed_files_path} loading..."
+            log_rank(log_msg, logger=logger, level=logging.INFO, rank=self.rank)
+
+            restart_epoch, restart_consumed_files = self.load_restart_consumed_files(self.consumed_files_path)
+            log_msg = f"[PetaGraphStreamDataset:{self.rank}] Found {restart_epoch} epoch with {len(restart_consumed_files)} files"
+            log_rank(log_msg, logger=logger, level=logging.INFO, rank=self.rank)
+
             # All files in restart_consumed_files should be present in the url_list
             for f in restart_consumed_files:
                 assert f in url_list, f"File {f} from restart not found in the url_list"
 
             # Remove those files from the url list and append them to the end
             # of the url list
-            for f in restart_consumed_files:
+            restart_consumed_files_set = set(restart_consumed_files)
+            for f in restart_consumed_files_set:
                 url_list.remove(f)
-                url_list.append(f)
+            url_list.extend(restart_consumed_files)
 
             # Add the consumed files to the consumed files set
             self.consumed_files = set(restart_consumed_files)
 
             # Set the current epoch to the restart epoch
             self.current_epoch = restart_epoch
 
+            log_msg = f"[PetaGraphStreamDataset:{self.rank}] Restarting from epoch {self.current_epoch} with {len(self.consumed_files)} files"
+            log_rank(log_msg, logger=logger, level=logging.INFO, rank=self.rank)
+        else:
+            self.consumed_files = set()
 
         if from_cloud:
             # In order to make sure data are shuffled and sharded in the
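A minimal, self-contained sketch of the reordering implemented in this hunk, with invented file names. Consumed files are deduplicated into a set before removal (list.remove only drops the first occurrence, so iterating over a set removes each file exactly once), then re-appended so the resumed epoch finishes on the files that were already seen:

consumed = ["fileB.fa", "fileC.fa"]                          # from the consumed-files log
url_list = ["fileA.fa", "fileB.fa", "fileC.fa", "fileD.fa"]

for f in set(consumed):
    url_list.remove(f)       # drop each consumed file from its original slot
url_list.extend(consumed)    # revisit them only after the untouched files

print(url_list)              # ['fileA.fa', 'fileD.fa', 'fileB.fa', 'fileC.fa']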
@@ -170,9 +180,6 @@ def __init__(self,
             for _ in range(warmup_sample_size):
                 _ = next(self.iterable_dataset)
 
-        self.rank = rank
-        self.log_directory = log_directory
-        self.consumed_files = set()
         self.consumed_seq_len_queue = deque(maxlen=1000)
         if self.log_directory is not None:
             self.logging_func(f"[PetaGraphStreamDataset] Logging to {self.log_directory} on rank {self.rank}")
@@ -186,7 +193,23 @@ def __init__(self,

     @staticmethod
     def load_restart_consumed_files(restart_file: Path):
-        raise NotImplementedError("Loading restart files not implemented yet")
+        """Load the consumed files from the restart file
+        Returns the latest epoch and the files consumed in the latest epoch
+
+        Parameters:
+        ----------
+        restart_file (Path): The path to the restart file
+        """
+        with open(restart_file, "r") as f:
+            consumed_files = f.readlines()
+        consumed_files = [f.strip() for f in consumed_files]
+        consumed_files_tuples = [(int(f.split("_")[0]), f.split("_")[1]) for f in consumed_files]
+
+        latest_epoch = max([f[0] for f in consumed_files_tuples])
+        latest_files = [f[1] for f in consumed_files_tuples if f[0] == latest_epoch]
+
+        return latest_epoch, latest_files
 
     def decompression_func(self, input_data):
         path, data = input_data
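A hedged usage sketch for the new helper. The log format is inferred from the parser itself: one "<epoch>_<filename>" entry per line (note that a filename containing "_" would be truncated by split("_")[1]). The log path is invented for the example, and the import path is an assumption based on the file's location under src/nanotron:

from pathlib import Path

from nanotron.data.petagraph_dataset import PetaGraphStreamDataset  # assumed module path

log = Path("consumed_files_rank_0.txt")
log.write_text("0_fileA.fasta\n0_fileB.fasta\n1_fileA.fasta\n")

epoch, files = PetaGraphStreamDataset.load_restart_consumed_files(log)
print(epoch, files)  # 1 ['fileA.fasta']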
