diff --git a/src/nanotron/data/petagraph_dataset.py b/src/nanotron/data/petagraph_dataset.py index 96af2197..000deaef 100644 --- a/src/nanotron/data/petagraph_dataset.py +++ b/src/nanotron/data/petagraph_dataset.py @@ -82,8 +82,17 @@ def __init__(self, self.num_files = len(url_list) self.current_epoch = 0 + self.consumed_files_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt" + # TODO: Take list of already consumed lists and remove them from the # url list, to continue training from the last checkpoint properly + # TODO: REWORK THIS + # - Check if the consumed_files exist + # - If they exist, load them and assume we are restarting from a checkpoint + # - Find the largest epoch number in the consumed files + # - Filter the files that have been consumed/started in the latest epoch + # - Remove them from the url_list then append them to the end of the url_list + # - Set the current epoch to the latest epoch if restart_consumed_files is not None: # All files in restart_consumed_files should be present in the url_list @@ -273,8 +282,7 @@ def generate(self): if self.log_directory is not None: if source_path not in self.consumed_files: - out_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt" - with open(out_path, "a") as f: + with open(self.consumed_files_path, "a") as f: f.write(f"{self.current_epoch}_{source_path}\n") self.consumed_files.add(source_path) if len(self.consumed_files) == self.num_files: