Skip to content

Commit

Permalink
Update dist_ckpt_load_strictness for sft and steerlm2 (#293)
Browse files Browse the repository at this point in the history
Signed-off-by: Dong Hyuk Chang <[email protected]>
Co-authored-by: Dong Hyuk Chang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored and terrykong committed Oct 25, 2024
1 parent 56534cc commit 77af2a8
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions examples/nlp/gpt/train_gpt_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
if cfg.model.get("seq_len_interpolation_factor", None) is not None:
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

if cfg.model.get("dist_ckpt_load_strictness", None) is not None:
gpt_cfg.dist_ckpt_load_strictness = cfg.model.dist_ckpt_load_strictness

gpt_cfg.inference = cfg.model.get("inference", {})

# This is needed when modifying a hparam file directly to load `.ckpt` files.
Expand Down
3 changes: 3 additions & 0 deletions examples/nlp/gpt/train_steerlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
if cfg.model.get("seq_len_interpolation_factor", None) is not None:
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

if cfg.model.get("dist_ckpt_load_strictness", None) is not None:
gpt_cfg.dist_ckpt_load_strictness = cfg.model.dist_ckpt_load_strictness

gpt_cfg.inference = cfg.model.get("inference", {})

# This is needed when modifying a hparam file directly to load `.ckpt` files.
Expand Down
2 changes: 1 addition & 1 deletion nemo_aligner/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def load_checkpoint_model_config(restore_path):

with tempfile.TemporaryDirectory() as tmpdir:
# Extracts only model config
members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: '.yaml' in name)
members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: ".yaml" in name)
NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, members=members)
cfg = OmegaConf.load(os.path.join(tmpdir, config_name_in_ckpt))

Expand Down

0 comments on commit 77af2a8

Please sign in to comment.