This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

add log_training_trajectory option #549

Open · wants to merge 9 commits into base: opt_instruct
5 changes: 3 additions & 2 deletions metaseq/checkpoint_utils.py
@@ -1,3 +1,4 @@
# fmt: off
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
@@ -33,6 +34,7 @@ def save_checkpoint(
trainer,
epoch_itr,
val_loss,
log_training_trajectory=False,
training_finished=False,
async_callback_fn=None,
):
@@ -74,8 +76,7 @@ def is_better(a, b):
)
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
not end_of_epoch
and cfg.save_interval_updates > 0
and updates % cfg.save_interval_updates == 0
and ((cfg.save_interval_updates > 0 and updates % cfg.save_interval_updates == 0) or (log_training_trajectory and updates in [10, 20, 50, 100, 200, 500]))
)
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = (
val_loss is not None
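A minimal sketch of the mid-epoch save condition introduced above; the function name and the example interval are illustrative only, not the actual metaseq code:

TRAJECTORY_UPDATES = [10, 20, 50, 100, 200, 500]

def should_save_mid_epoch(updates, end_of_epoch, save_interval_updates, log_training_trajectory):
    # Mirrors the modified checkpoint_conds entry: save on the usual interval,
    # or at a handful of fixed early steps when the new flag is on.
    interval_hit = save_interval_updates > 0 and updates % save_interval_updates == 0
    trajectory_hit = log_training_trajectory and updates in TRAJECTORY_UPDATES
    return (not end_of_epoch) and (interval_hit or trajectory_hit)

# With log_training_trajectory=True and, say, save_interval_updates=2000, checkpoints
# are written at updates 10, 20, 50, 100, 200, 500 and then every 2000 updates as before.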
6 changes: 6 additions & 0 deletions metaseq/dataclass/configs.py
@@ -385,6 +385,12 @@ class DatasetConfig(MetaseqDataclass):
"help": "if set, validate language model at the beginning of training or fine-tuning process"
},
)
log_training_trajectory: bool = field(
default=False,
metadata={
"help": "(InstructOPT specific) if set, evaluate and save checkpoints more frequently in early stage of training"
},
)
fixed_validation_seed: Optional[int] = field(
default=None, metadata={"help": "specified random seed for validation"}
)
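For context, a sketch of how the new field sits in the dataset config and how a run might enable it when the config is built programmatically; the dataclass here is heavily simplified and only shows fields that appear elsewhere in this diff:

from dataclasses import dataclass
from typing import Optional

@dataclass
class DatasetConfigSketch:
    validate_at_beginning: bool = False
    log_training_trajectory: bool = False  # field added by this PR
    fixed_validation_seed: Optional[int] = None
    disable_validation: bool = False

cfg_dataset = DatasetConfigSketch(log_training_trajectory=True)
# cfg_dataset.log_training_trajectory then gates the extra early-training
# validation and checkpoint triggers added to metaseq_cli/train.py below.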
13 changes: 12 additions & 1 deletion metaseq_cli/train.py
@@ -37,7 +37,9 @@
from metaseq.model_parallel.megatron_trainer import MegatronTrainer
from metaseq.trainer import Trainer
from metaseq.tasks.streaming_language_modeling import StreamingLanguageModelingTask
from metaseq.tasks.streaming_finetune_language_modeling import StreamingFinetuneLanguageModelingTask
from metaseq.tasks.streaming_finetune_language_modeling import (
StreamingFinetuneLanguageModelingTask,
)


logging.basicConfig(
@@ -418,6 +420,10 @@ def validate_and_save(
and num_updates >= cfg.dataset.validate_after_updates
and was_successful_step
)
or (
cfg.dataset.log_training_trajectory

[Review thread on the line above]
Contributor: why do we need to save? Just validate?
@todpole3 (Contributor, Author), Dec 12, 2022: To plot the trajectory for generation tasks as well (ROUGE); eval will be fast.

and num_updates in [10, 20, 50, 100, 200, 500]
)
)
do_validate = (
(
@@ -432,6 +438,10 @@
and was_successful_step
)
or (cfg.dataset.validate_at_beginning and num_updates == 0)
or (
cfg.dataset.log_training_trajectory
and num_updates in [0, 10, 20, 50, 100, 200, 500]
)
) and not cfg.dataset.disable_validation
valid_losses = [None]
if do_validate:
@@ -446,6 +456,7 @@
trainer,
epoch_itr,
valid_losses[0],
log_training_trajectory=cfg.dataset.log_training_trajectory,
training_finished=should_stop,
async_callback_fn=functools.partial(post_checkpoint_callback, cfg)
if cfg.checkpoint.cloud_upload_path
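Taken together, the new conditions in validate_and_save amount to the following; a rough sketch that ignores the pre-existing triggers shown above (intervals, end of epoch, validate_at_beginning) and only captures what the flag adds:

TRAJECTORY_UPDATES = [10, 20, 50, 100, 200, 500]

def extra_triggers(log_training_trajectory, num_updates):
    # Extra save trigger: fixed early steps only.
    do_save = log_training_trajectory and num_updates in TRAJECTORY_UPDATES
    # Extra validate trigger: the same steps plus update 0.
    do_validate = log_training_trajectory and num_updates in [0] + TRAJECTORY_UPDATES
    return do_save, do_validate

assert extra_triggers(True, 0) == (False, True)     # validate only at the very start
assert extra_triggers(True, 50) == (True, True)     # checkpoint and validate at update 50
assert extra_triggers(True, 51) == (False, False)   # nothing added in between
assert extra_triggers(False, 50) == (False, False)  # flag off: behavior unchanged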