diff --git a/examples/sft/alpaca-llama-sft.yaml b/examples/sft/alpaca-llama-sft.yaml
index 58422c7ab..f3667c489 100644
--- a/examples/sft/alpaca-llama-sft.yaml
+++ b/examples/sft/alpaca-llama-sft.yaml
@@ -47,6 +47,6 @@ supervised_data:
 
 # Additional settings
 tokenizer: "allenai/OLMo-1B"
 max_tune_length: 2048
-epoch: 3
+epoch: 0
 initialize_from_hf: false
diff --git a/examples/sft/sft.py b/examples/sft/sft.py
index 594e1b41f..9813184b9 100644
--- a/examples/sft/sft.py
+++ b/examples/sft/sft.py
@@ -89,9 +89,9 @@ def train(config: SFTConfig):
     train_dataset = PermutationDataset(train_dataset, data_key)
 
     # Then wrap for epochs
-    if config.epoch > 0:
-        logger.info(f"Wrapping dataset for {config.epoch} epochs")
-        train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch)
+    # if config.epoch > 0:
+    #     logger.info(f"Wrapping dataset for {config.epoch} epochs")
+    #     train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch)
 
     logger.info("Creating optimizer")
     optimizer = config.optimizer.build(config.trainer.num_train_steps)
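
Note: setting `epoch: 0` already makes the `config.epoch > 0` guard skip the
EpochDataset wrapping, so commenting the block out as well is belt-and-braces.
For context, the wrapping being disabled replays the shuffled dataset for a
fixed number of epochs. Below is a minimal sketch of that behavior, assuming
an iterable dataset; `SimpleEpochDataset` is a hypothetical stand-in, not the
project's actual EpochDataset class.

    from typing import Iterable, Iterator, TypeVar

    T = TypeVar("T")

    class SimpleEpochDataset(Iterable[T]):
        """Replays `dataset` up to `max_epochs` times.

        The caller guards construction with `config.epoch > 0`, so this
        wrapper is only built when epoch-based replay is wanted.
        """

        def __init__(self, dataset: Iterable[T], max_epochs: int):
            self.dataset = dataset
            self.max_epochs = max_epochs

        def __iter__(self) -> Iterator[T]:
            # One full pass over the underlying dataset per epoch.
            for _ in range(self.max_epochs):
                yield from self.dataset

    # Usage: three epochs over a two-item dataset.
    assert list(SimpleEpochDataset([1, 2], max_epochs=3)) == [1, 2, 1, 2, 1, 2]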