Errors when trying to load a VariationalBNN #15
Comments
Hi, apologies for the slow response! I think this is just due to the variational parameter attributes being initialized lazily by Pyro. If your (deterministic) network doesn't have any buffers, i.e. it only has parameters, you shouldn't need to save/load the state dict at all: the param store should contain everything you need. Otherwise, if you do need to load the state dict, just run a forward pass through your BNN by calling
Let me know if neither option resolves the error; in that case I'd need to take a closer look at what's going on :)
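To illustrate the lazy-initialization explanation, here is a toy analogy in plain PyTorch, not TyXe's actual implementation. The hypothetical `ToyBNN` wrapper has a `net_guide` submodule that only exists after the first forward pass. A state dict saved from a trained instance therefore contains `net_guide.*` keys that a freshly constructed instance has nowhere to put, which reproduces the same "Unexpected key(s)" error. Running one forward pass first creates the submodule, after which loading succeeds. In the toy the forward pass is just `fresh(x)`; the corresponding TyXe call may differ.

```python
import torch
import torch.nn as nn


class ToyBNN(nn.Module):
    """Hypothetical stand-in for a BNN wrapper: `net` exists from construction,
    while `net_guide` (playing the variational parameters) is only created on
    the first forward pass, mimicking lazy initialization."""

    def __init__(self, in_features: int = 3, out_features: int = 2):
        super().__init__()
        self.net = nn.Linear(in_features, out_features)
        self.net_guide = None  # plain attribute until forward() replaces it

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.net_guide is None:
            # First call: create the "variational" parameters with matching shapes.
            self.net_guide = nn.Linear(self.net.in_features, self.net.out_features)
        return self.net(x) + self.net_guide(x)


torch.manual_seed(0)
x = torch.randn(4, 3)

trained = ToyBNN()
trained(x)                      # the forward pass creates net_guide
saved = trained.state_dict()
print(sorted(saved))            # includes 'net_guide.bias', 'net_guide.weight'

fresh = ToyBNN()
try:
    fresh.load_state_dict(saved)  # net_guide does not exist yet
    failed = False
except RuntimeError as err:
    failed = "Unexpected key(s)" in str(err)
print("strict load failed before forward pass:", failed)

fresh(x)                        # one forward pass materializes net_guide
fresh.load_state_dict(saved)    # every saved key now has a destination
```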
Hi, I have encountered the same error and I might have found a solution: you need to load the state dict using the .net attribute of your model.
Hope this helps!
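A minimal sketch of loading through `.net`, again using a hypothetical toy wrapper in plain PyTorch rather than TyXe itself. Saving and loading only `model.net.state_dict()` keeps the lazily created `net_guide.*` entries out of the checkpoint, so a freshly constructed model can load it without any forward pass. In the TyXe setting the variational parameters would be restored separately from the Pyro param store (`pyro.get_param_store().load(...)`, as in the code quoted in this thread). An in-memory `io.BytesIO` buffer stands in for a checkpoint file here.

```python
import io

import torch
import torch.nn as nn


class ToyBNN(nn.Module):
    """Hypothetical wrapper: `net` is the deterministic network; `net_guide`
    holds "variational" parameters and is created lazily on the first forward pass."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3, 2)
        self.net_guide = None

    def forward(self, x):
        if self.net_guide is None:
            self.net_guide = nn.Linear(3, 2)
        return self.net(x) + self.net_guide(x)


torch.manual_seed(0)
trained = ToyBNN()
trained(torch.randn(4, 3))  # net_guide now exists on the trained model

# Save only the deterministic network's state dict (to a buffer instead of a file).
buffer = io.BytesIO()
torch.save(trained.net.state_dict(), buffer)
buffer.seek(0)

# A freshly constructed model loads it directly: `net` already exists and the
# checkpoint contains no "net_guide." keys.
fresh = ToyBNN()
fresh.net.load_state_dict(torch.load(buffer))
print(torch.equal(fresh.net.weight, trained.net.weight))  # True
```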
Thanks Camille,
I will try your solution 😉
Cheers,
Francesco
On 19 Oct 2022, 16:30 +0200, Camille Besombes ***@***.***> wrote:
Hi,
I have encountered the same error and I might have found a solution. You need to load the state dict using the .net attribute of your model:
pyro.clear_param_store()
model.net.load_state_dict(torch.load(os.path.join(save_model_path, "best_model.pt")))
pyro.get_param_store().load(os.path.join(save_model_path, "param_store.pt"))
Hope this helps!
I forgot to mention in my answer that you also need to save it the same way, as follows:
Hi, I had the same issue and @Cam-B04's solution worked for me, thank you.
Hi all,
I'm new to TyXe, and I'm experiencing an issue when trying to load a previously trained model from disk.
More precisely, I get the following error:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VariationalBNN: Unexpected key(s) in state_dict: net_guide.rnn.weight_ih_l0.loc_unconstrained etc.
To save the model, I use code like this:
pyro.get_param_store().save(os.path.join(output_dir, "param_store.pt"))
torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pt"))
To load the model (defined as tyxe.VariationalBNN(net, prior, likelihood, guide)), I instead use:
pyro.clear_param_store()
model.load_state_dict(torch.load(os.path.join(save_model_path, "best_model.pt")))
pyro.get_param_store().load(os.path.join(save_model_path, "param_store.pt"))
Where is the error?
Thank you so much.
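To see exactly which keys break a strict load, one option is to compare key sets before loading anything. The sketch below uses two hypothetical helpers (`split_key_mismatch`, `sub_state_dict`) and toy `nn.Module` stand-ins, not TyXe objects. It also shows one way to reuse a checkpoint already saved with `model.state_dict()`: keep only the entries under the `net.` prefix and load them through `.net`. That prefix is an assumption based on the `net_guide.` keys in the error message and the `.net` attribute mentioned in this thread, so check the actual keys of your checkpoint first.

```python
import torch
import torch.nn as nn


def split_key_mismatch(module: nn.Module, state_dict: dict):
    """Hypothetical helper: return (missing, unexpected) keys that a strict
    load_state_dict call would complain about, without loading anything."""
    expected = set(module.state_dict().keys())
    provided = set(state_dict.keys())
    return sorted(expected - provided), sorted(provided - expected)


def sub_state_dict(state_dict: dict, prefix: str) -> dict:
    """Hypothetical helper: keep entries under `prefix` (e.g. "net.") and strip it."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Toy stand-ins: `full` plays the saved model (net plus an already-created guide),
# `fresh` plays a newly constructed model whose guide does not exist yet.
full = nn.Module()
full.net = nn.Linear(3, 2)
full.net_guide = nn.Linear(3, 2)
saved = full.state_dict()

fresh = nn.Module()
fresh.net = nn.Linear(3, 2)

missing, unexpected = split_key_mismatch(fresh, saved)
print("missing:", missing)        # missing: []
print("unexpected:", unexpected)  # unexpected: ['net_guide.bias', 'net_guide.weight']

# Recover from a full-model checkpoint by loading only the "net." entries.
fresh.net.load_state_dict(sub_state_dict(saved, "net."))
print(torch.equal(fresh.net.weight, full.net.weight))  # True
```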