Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668520600
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Aug 28, 2024
1 parent 25a1a23 commit 07bf3b5
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions swirl_dynamics/projects/debiasing/rectified_flow/main_train_ens.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,23 @@ def main(argv):
ema_decay=config.ema_decay,
)

if "trained_state_dir" in config:
# Load the parameters from the checkpoint of an already trained model.
logging.info("Loading trained state from %s", config.trained_state_dir)
trained_state = trainers.TrainState.restore_from_orbax_ckpt(
f"{config.trained_state_dir}/checkpoints",
step=None,
ref_state=trainer.train_state,
)

# Modify train_state with params from checkpoint.
trainer.train_state = trainer.train_state.replace(
params=trained_state.params,
flax_mutables=trained_state.flax_mutables,
)
# Avoid having more than one checkpoint.
del trained_state

# Setting up checkpointing.
ckpt_options = checkpoint.CheckpointManagerOptions(
save_interval_steps=config.save_interval_steps,
Expand Down

0 comments on commit 07bf3b5

Please sign in to comment.