Skip to content

Commit

Permalink
Add todo for improving checkpoint continuation
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelburger committed Aug 8, 2024
1 parent ab7ac98 commit b7ddf39
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/nanotron/data/petagraph_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,17 @@ def __init__(self,
self.num_files = len(url_list)
self.current_epoch = 0

self.consumed_files_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt"

# TODO: Take the list of already consumed files and remove them from the
# url list, to continue training from the last checkpoint properly
# TODO: REWORK THIS
# - Check if the consumed_files exist
# - If they exist, load them and assume we are restarting from a checkpoint
# - Find the largest epoch number in the consumed files
# - Filter the files that have been consumed/started in the latest epoch
# - Remove them from the url_list then append them to the end of the url_list
# - Set the current epoch to the latest epoch
if restart_consumed_files is not None:

# All files in restart_consumed_files should be present in the url_list
Expand Down Expand Up @@ -273,8 +282,7 @@ def generate(self):

if self.log_directory is not None:
if source_path not in self.consumed_files:
out_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt"
with open(out_path, "a") as f:
with open(self.consumed_files_path, "a") as f:
f.write(f"{self.current_epoch}_{source_path}\n")
self.consumed_files.add(source_path)
if len(self.consumed_files) == self.num_files:
Expand Down

0 comments on commit b7ddf39

Please sign in to comment.