Hi, I'm trying to reload my pre-trained network to continue training in the CIFAR-10 experiments (`cifar.py`), but the loss does not converge after reloading (it still converges if the model is initialized without reloading). I suspect the issue is that `__base_optimizer` is stored as part of the optimizer state, so when I run `optimizer.load_state_dict(ckpt['optimizer_state_dict'])` the live base optimizer is replaced by the copy deserialized from `optimizer_state_dict`.

I have now solved the problem by making `base_optimizer` a class member of each BNN optimizer, i.e. `self.__base_optimizer = base_optimizer` instead of `self.state["__base_optimizer"] = base_optimizer`. With this change I load the state dicts of the optimizer and its base optimizer separately, and the training loss converges after reloading.

The following is my code snippet for reloading:
```python
import glob
import os

import torch


def load_model(model_idx, model, scaler, optimizer, out_path, config, log):
    ckpt = None
    start_epoch = 0
    # Load checkpoint and scaler if available
    if config.get("use_checkpoint", None):
        try:
            ckpt_paths = glob.glob(out_path + f"{config['model']}_chkpt_{model_idx}_*.pth")
            ckpt_paths.sort(key=os.path.getmtime)
            ckpt = torch.load(ckpt_paths[-1])
            model.load_state_dict(ckpt['model_state_dict'])
            start_epoch = ckpt["epoch"] + 1
            scaler.load_state_dict(ckpt["scaler_state_dict"])
            log.info(f"Loaded checkpoint for model {model_idx} at epoch {start_epoch}")
        except Exception:
            log.info(f"Failed to load checkpoint for model {model_idx}")
    optimizer.init_grad_scaler(scaler)
    # Load optimizer state if available
    # Base optimizer state is loaded separately if available
    if ckpt is not None:
        try:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            if ckpt.get("base_optimizer") is not None:
                optimizer.get_base_optimizer().load_state_dict(ckpt["base_optimizer"])
                log.info(f"Loaded base optimizer state for model {model_idx}")
        except Exception:
            log.info(f"Failed to load optimizer state for model {model_idx}")
    # Load scheduler state if available
    if config["lr_schedule"]:
        scheduler = wilson_scheduler(optimizer.get_base_optimizer(), config["epochs"], config["lr"], None)
        if ckpt is not None:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            log.info(f"Loaded scheduler state for model {model_idx}")
    else:
        scheduler = None
    return start_epoch, model, optimizer, scaler, scheduler
```
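And roughly how I save the model during training; this is a condensed sketch with a hypothetical `save_checkpoint` helper, where the keys simply mirror what `load_model` reads back and all arguments come from the training loop:

```python
import torch


def save_checkpoint(model_idx, epoch, model, scaler, optimizer, scheduler, out_path, config):
    """Condensed sketch: write one checkpoint per epoch with the keys used by load_model."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        # Store the base optimizer's state under its own key so it can be
        # restored separately after the class-member change.
        "base_optimizer": optimizer.get_base_optimizer().state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
    }, out_path + f"{config['model']}_chkpt_{model_idx}_{epoch}.pth")
```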
An example of the optimizer change is:

```python
class MAPOptimizer(BayesianOptimizer):
    '''
    Maximum A Posteriori.
    This simply optimizes a point estimate of the parameters with the given base_optimizer.
    '''
    def __init__(self, params, base_optimizer):
        super().__init__(params, {})
        # Previously: self.state["__base_optimizer"] = base_optimizer
        self.__base_optimizer = base_optimizer
```
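With the attribute stored directly on the instance, the `get_base_optimizer()` accessor that `load_model` above relies on would simply return it. A sketch (note that the double underscore triggers Python name mangling, so this must be defined in the same class that assigns `self.__base_optimizer`; alternatively a single underscore avoids the issue):

```python
    # (inside MAPOptimizer)
    def get_base_optimizer(self):
        # Name mangling turns self.__base_optimizer into
        # self._MAPOptimizer__base_optimizer, so this lookup only works in the
        # class that set the attribute in __init__.
        return self.__base_optimizer
```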
Since I'm still looking into the other optimizers, it would be a great help if you could let me know of any potential problems with this approach. Thank you very much!
Thanks for raising that issue! I think your change is good and makes sense, and it shouldn't introduce any problems down the line. If I find time, I will fix it across the entire repository. If you want, you can also open a pull request with the change.
Thanks for your interest in the repository!
Florian