diff --git a/metaseq/checkpoint_utils.py b/metaseq/checkpoint_utils.py
index 472fedd8b..eeecf638e 100644
--- a/metaseq/checkpoint_utils.py
+++ b/metaseq/checkpoint_utils.py
@@ -1,3 +1,4 @@
+# fmt: off
 # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
 #
 # This source code is licensed under the MIT license found in the
@@ -33,6 +34,7 @@ def save_checkpoint(
     trainer,
     epoch_itr,
     val_loss,
+    log_training_trajectory=False,
     training_finished=False,
     async_callback_fn=None,
 ):
@@ -74,8 +76,7 @@ def is_better(a, b):
         )
         checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
             not end_of_epoch
-            and cfg.save_interval_updates > 0
-            and updates % cfg.save_interval_updates == 0
+            and ((cfg.save_interval_updates > 0 and updates % cfg.save_interval_updates == 0) or (log_training_trajectory and updates in [10, 20, 50, 100, 200, 500]))
         )
         checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = (
             val_loss is not None
diff --git a/metaseq/dataclass/configs.py b/metaseq/dataclass/configs.py
index c618d540e..41ff4ff4e 100644
--- a/metaseq/dataclass/configs.py
+++ b/metaseq/dataclass/configs.py
@@ -385,6 +385,12 @@ class DatasetConfig(MetaseqDataclass):
             "help": "if set, validate language model at the beginning of training or fine-tuning process"
         },
     )
+    log_training_trajectory: bool = field(
+        default=False,
+        metadata={
+            "help": "(InstructOPT specific) if set, evaluate and save checkpoints more frequently in early stage of training"
+        },
+    )
     fixed_validation_seed: Optional[int] = field(
         default=None, metadata={"help": "specified random seed for validation"}
     )
diff --git a/metaseq_cli/train.py b/metaseq_cli/train.py
index 3669c2b7e..da2946f04 100644
--- a/metaseq_cli/train.py
+++ b/metaseq_cli/train.py
@@ -37,7 +37,9 @@
 from metaseq.model_parallel.megatron_trainer import MegatronTrainer
 from metaseq.trainer import Trainer
 from metaseq.tasks.streaming_language_modeling import StreamingLanguageModelingTask
-from metaseq.tasks.streaming_finetune_language_modeling import StreamingFinetuneLanguageModelingTask
+from metaseq.tasks.streaming_finetune_language_modeling import (
+    StreamingFinetuneLanguageModelingTask,
+)


 logging.basicConfig(
@@ -418,6 +420,10 @@ def validate_and_save(
             and num_updates >= cfg.dataset.validate_after_updates
             and was_successful_step
         )
+        or (
+            cfg.dataset.log_training_trajectory
+            and num_updates in [10, 20, 50, 100, 200, 500]
+        )
     )
     do_validate = (
         (
@@ -432,6 +438,10 @@
             and was_successful_step
         )
         or (cfg.dataset.validate_at_beginning and num_updates == 0)
+        or (
+            cfg.dataset.log_training_trajectory
+            and num_updates in [0, 10, 20, 50, 100, 200, 500]
+        )
     ) and not cfg.dataset.disable_validation
     valid_losses = [None]
     if do_validate:
@@ -446,6 +456,7 @@
             trainer,
             epoch_itr,
             valid_losses[0],
+            log_training_trajectory=cfg.dataset.log_training_trajectory,
             training_finished=should_stop,
             async_callback_fn=functools.partial(post_checkpoint_callback, cfg)
             if cfg.checkpoint.cloud_upload_path
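
For reference, a minimal standalone sketch of the checkpoint-scheduling condition the patch adds to checkpoint_utils.py. This is hypothetical illustration only: EARLY_TRAJECTORY_UPDATES and should_save_update_checkpoint are invented names, and the hard-coded update list mirrors the one in the diff.

# Hypothetical sketch, not part of the patch: reproduces the new
# "checkpoint_{epoch}_{updates}{suffix}.pt" condition so the effect of
# log_training_trajectory can be checked in isolation.

EARLY_TRAJECTORY_UPDATES = [10, 20, 50, 100, 200, 500]


def should_save_update_checkpoint(
    updates: int,
    save_interval_updates: int,
    log_training_trajectory: bool,
    end_of_epoch: bool = False,
) -> bool:
    # Same boolean structure as the condition in save_checkpoint() above:
    # save on the regular interval, or at the hard-coded early updates
    # when trajectory logging is enabled.
    return not end_of_epoch and (
        (save_interval_updates > 0 and updates % save_interval_updates == 0)
        or (log_training_trajectory and updates in EARLY_TRAJECTORY_UPDATES)
    )


if __name__ == "__main__":
    # With a 1000-update save interval, the flag adds six early checkpoints.
    fired = [
        u
        for u in range(1, 1001)
        if should_save_update_checkpoint(u, 1000, log_training_trajectory=True)
    ]
    print(fired)  # [10, 20, 50, 100, 200, 500, 1000]

Assuming metaseq generates CLI flags from DatasetConfig field names the same way fairseq does (underscores become hyphens), the option would be enabled by passing --log-training-trajectory to metaseq_cli/train.py.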