Flipout implementation for recurrent networks (lstm, gru) #6
Comments
Thanks for your kind words about the project! We currently don't have plans for adding RNN support for flipout ourselves. But it would of course be great to have in the library. I had a quick look at what the pytorch RNN classes do under the hood and they all seem to be calling I think it would just require a couple of small changes to the Reparameterization and FlipoutMessenger, specifically not assuming that the functions live in
Not ready for a PR yet, but have a look at: https://github.com/lbasora/TyXe/blob/56ae7d32f6bf877d142bba99e2c38d84b0c25999/tyxe/poutine/reparameterization_messengers.py
It's just an initial attempt to test the monkey-patching and the working principle. Let me know whether the solution is more or less what you had in mind. The code has not been properly tested and is still incomplete: the dropout, bidirectional and multilayer options are not implemented yet. Here is a very basic code snippet to play with it:
from functools import partial
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import pyro
import pyro.distributions as dist
import tyxe
from tqdm.auto import trange
class Lstm(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(3, 2, batch_first=True)
        self.out = nn.Linear(2, 1)

    def forward(self, x):
        _, (ht, _) = self.lstm(x)
        # ht has shape (num_layers, batch, hidden); drop only the layer dim
        # so that a batch of size 1 keeps its batch dimension.
        return self.out(ht.squeeze(0))


class Gru(nn.Module):
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(3, 2, batch_first=True)
        self.out = nn.Linear(2, 1)

    def forward(self, x):
        _, ht = self.gru(x)
        # Same shape convention as in Lstm.forward.
        return self.out(ht.squeeze(0))


net = Lstm()
x = torch.rand(5, 4, 3)
y = torch.rand(5, 1)
pyro.set_rng_seed(42)
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")
prior = tyxe.priors.IIDPrior(
    dist.Normal(torch.tensor(0.0, device=DEVICE), torch.tensor(1.0, device=DEVICE))
)
likelihood = tyxe.likelihoods.HomoskedasticGaussian(len(x), scale=0.1)
guide = partial(tyxe.guides.AutoNormal, init_scale=0.1)
bnn = tyxe.VariationalBNN(net, prior, likelihood, guide)
ds = TensorDataset(x, y)
dl = DataLoader(ds, batch_size=len(x), pin_memory=USE_CUDA)
pyro.clear_param_store()
optim = pyro.optim.Adam({"lr": 1e-3})
num_epochs = 1
pbar = trange(num_epochs)
elbos = []
def callback(_i, _ii, e):
    elbos.append(e / len(dl.sampler))
    pbar.update()


with tyxe.poutine.flipout():
    bnn.fit(dl, optim, num_epochs, device=DEVICE, callback=callback)
Thanks, I'll try to have a closer look at this in the next couple of days. In terms of the interface it is what I have in mind. Regarding the implementation, I was hoping it would be possible to avoid re-implementing the forward passes by hand. Since what you're doing ultimately relies on calling
I'm also a little bit worried that, with the way this is implemented, a new flipout mask would be sampled at every time step, corresponding to a different weight sample. However, the sample from the weight posterior should be tied across all time steps, so I think we'd need some extra logic for caching the sign multipliers in the original reparameterize function.
Regarding testing, just as a basic sanity check, have you tried passing an input that is repeated along the batch dimension through the network inside and outside of a flipout context? Inside the context you should get different outputs (if the implementation works), whereas outside they should all be identical, since the same sample for the weights should be used across the batch.
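The sanity check described above can be sketched on a single linear map in plain PyTorch, without TyXe. This is only an illustration of the principle: `shared_sample_linear`, `flipout_linear` and `random_signs` are hypothetical helpers written for this example, and the mean/perturbation split (`w_mean`, `w_delta`) stands in for one sample from a Gaussian weight posterior.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def random_signs(*shape):
    """Sample a tensor of independent +1/-1 entries."""
    return torch.randint(0, 2, shape).float() * 2 - 1


def shared_sample_linear(x, w_mean, w_delta):
    """One weight sample (mean + perturbation) shared by the whole batch."""
    return F.linear(x, w_mean + w_delta)


def flipout_linear(x, w_mean, w_delta):
    """Flipout: per-example sign flips decorrelate the shared perturbation."""
    sign_in = random_signs(*x.shape)
    sign_out = random_signs(x.shape[0], w_mean.shape[0])
    return F.linear(x, w_mean) + F.linear(x * sign_in, w_delta) * sign_out


w_mean = torch.randn(2, 3)
w_delta = 0.1 * torch.randn(2, 3)  # stands in for sigma * epsilon

# One input repeated along the batch dimension.
x = torch.rand(1, 3).repeat(8, 1)

out_shared = shared_sample_linear(x, w_mean, w_delta)
out_flipout = flipout_linear(x, w_mean, w_delta)

# Without flipout every row should be the same; with flipout they should differ.
rows_identical_shared = bool(
    torch.allclose(out_shared, out_shared[0].expand_as(out_shared))
)
rows_identical_flipout = bool(
    torch.allclose(out_flipout, out_flipout[0].expand_as(out_flipout))
)
print(rows_identical_shared, rows_identical_flipout)
```

The same comparison could be run on a full network by calling it once inside and once outside the flipout context on a repeated input.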
To avoid re-implementing the forward pass, are you suggesting TyXe should provide the user with LSTM/GRU Flipout layers, as in bayesian-torch or in Edward2? But then what about the nice TyXe feature of turning a regular pytorch LSTM/GRU model into a BNN model? I agree with you that the current implementation isn't correct because, as you say, a new flipout mask is applied at each time step. But I don't know yet how we can cache the sign multipliers. Let me know if you see a practical way to do that. Good suggestion for the testing.
That is definitely the priority :) I was just thinking that if we add code to the library that is an "explicit" pytorch implementation of a gru/lstm forward pass we might as well expose it rather than hiding it in the reparameterization messengers. This seems to be something that some people are interested in. That would of course require having an additional helper function for conversion, but I think this all could be as simple as putting your code for the forward pass in a separate function, set up This also has the advantage that it's a bit more explicit that the forward pass is being changed. I'm a bit concerned that such a pure pytorch implementation would be a lot slower than the cuda kernel that pytorch calls.
I had another look at the implementation and realized that I was already caching the masks by setting them as attributes on the sampled weights. So the samples should actually be consistent across time steps. I'll need to figure out how to test this though to be absolutely sure; let me know if you have any thoughts on this.
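One possible way to test the consistency across time steps is sketched below, under assumptions: it uses a plain tanh RNN rather than an LSTM/GRU, and `flipout_linear_cached`, `effective_weight` and the `sign_in`/`sign_out` attributes are hypothetical names for this example, not TyXe's actual attributes. If the masks are sampled once and reused, example `b` effectively sees the fixed weight `w_mean + w_delta * outer(sign_out[b], sign_in[b])` at every time step, so unrolling an ordinary RNN with that weight should reproduce the flipout result for that example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def random_signs(*shape):
    """Sample a tensor of independent +1/-1 entries."""
    return torch.randint(0, 2, shape).float() * 2 - 1


def flipout_linear_cached(x, w_mean, w_delta):
    """Flipout linear map whose sign masks are sampled once and then reused.

    The masks are cached as attributes on the perturbation tensor
    (hypothetical attribute names, chosen for this sketch).
    """
    if not hasattr(w_delta, "sign_in"):
        w_delta.sign_in = random_signs(*x.shape)
        w_delta.sign_out = random_signs(x.shape[0], w_mean.shape[0])
    perturbation = F.linear(x * w_delta.sign_in, w_delta) * w_delta.sign_out
    return F.linear(x, w_mean) + perturbation


def effective_weight(w_mean, w_delta, b):
    """Weight matrix that example b sees when the masks are tied over time."""
    return w_mean + w_delta * torch.outer(w_delta.sign_out[b], w_delta.sign_in[b])


batch, steps, n_in, n_hid = 4, 6, 3, 2
x = torch.rand(batch, steps, n_in)
w_ih_mean, w_ih_delta = torch.randn(n_hid, n_in), 0.1 * torch.randn(n_hid, n_in)
w_hh_mean, w_hh_delta = torch.randn(n_hid, n_hid), 0.1 * torch.randn(n_hid, n_hid)

# Batched flipout RNN: the masks are created at t = 0 and reused afterwards.
h = torch.zeros(batch, n_hid)
for t in range(steps):
    h = torch.tanh(
        flipout_linear_cached(x[:, t], w_ih_mean, w_ih_delta)
        + flipout_linear_cached(h, w_hh_mean, w_hh_delta)
    )

# Reference: ordinary RNN for example 0 with its fixed effective weights.
w_ih_0 = effective_weight(w_ih_mean, w_ih_delta, 0)
w_hh_0 = effective_weight(w_hh_mean, w_hh_delta, 0)
h0 = torch.zeros(n_hid)
for t in range(steps):
    h0 = torch.tanh(F.linear(x[0, t], w_ih_0) + F.linear(h0, w_hh_0))

tied_across_time = bool(torch.allclose(h[0], h0, atol=1e-5))
print(tied_across_time)
```

If a fresh mask were drawn at every step instead, the two hidden states would generally disagree, so a check of this kind could distinguish the two behaviours.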
I've committed some changes (8e1e5b3) to try to take your last comments into account. Please let me know whether the interface is convenient for you or whether you want some changes. I've tested it, and we obtain the same results as with the pytorch GRU/LSTM implementation when not using flipout. I also did the basic sanity check you suggested: when the input is repeated along the batch dimension, inside the flipout context we get different outputs, whereas outside they are identical. I haven't tested the consistency regarding the cached sign multipliers, though. The current LSTM/GRU is not complete yet. Features like multilayer, bidirectionality and dropout can be added later though.
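For reference, an equivalence check of this kind might look like the sketch below. It is not the code from the commit: `manual_lstm_forward` is a hypothetical helper that handles only a single-layer, unidirectional, `batch_first` LSTM without dropout. It relies on PyTorch's documented parameter layout, where `weight_ih_l0` and `weight_hh_l0` stack the input, forget, cell and output gates in that order.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def manual_lstm_forward(lstm, x):
    """Explicit forward pass of a single-layer, batch_first nn.LSTM.

    Every matrix product goes through F.linear, so a context that patches
    F.linear would see the calls.
    """
    w_ih, w_hh = lstm.weight_ih_l0, lstm.weight_hh_l0
    b_ih, b_hh = lstm.bias_ih_l0, lstm.bias_hh_l0
    batch, steps, _ = x.shape
    h = x.new_zeros(batch, lstm.hidden_size)
    c = x.new_zeros(batch, lstm.hidden_size)
    outputs = []
    for t in range(steps):
        gates = F.linear(x[:, t], w_ih, b_ih) + F.linear(h, w_hh, b_hh)
        # PyTorch stacks the gates as (input, forget, cell, output).
        i, f, g, o = gates.chunk(4, dim=1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        outputs.append(h)
    # Match nn.LSTM's return format: (output, (h_n, c_n)).
    return torch.stack(outputs, dim=1), (h.unsqueeze(0), c.unsqueeze(0))


lstm = nn.LSTM(3, 2, batch_first=True)
x = torch.rand(5, 4, 3)

with torch.no_grad():
    ref_out, (ref_h, ref_c) = lstm(x)
    out, (h_n, c_n) = manual_lstm_forward(lstm, x)

outputs_match = bool(torch.allclose(out, ref_out, atol=1e-5))
states_match = bool(
    torch.allclose(h_n, ref_h, atol=1e-5) and torch.allclose(c_n, ref_c, atol=1e-5)
)
print(outputs_match, states_match)
```

A Python loop like this is expected to be slower than the fused kernel, which is the trade-off raised earlier in the thread.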
Thanks for the excellent TyXe initiative.
Currently, flipout is implemented in TyXe for linear and convolutional layers. Are you considering supporting RNNs as well in the near future?
If I understood correctly, for linear and conv layers you monkey-patched F.linear and F.conv, but I didn't see any equivalent functions for RNNs in torch.nn.functional that would allow for a similar solution. Do you have any idea how to implement this?
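To make the difficulty concrete, here is a small sketch (not TyXe's messenger code) that temporarily replaces `F.linear` with a counting wrapper. `count_linear_calls` is a hypothetical helper for this example. `nn.Linear` looks up `F.linear` at call time, so it goes through the wrapper, while `nn.LSTM` dispatches to a fused backend routine and should leave the counter unchanged.

```python
import contextlib

import torch
import torch.nn as nn
import torch.nn.functional as F


@contextlib.contextmanager
def count_linear_calls(counter):
    """Temporarily replace F.linear with a wrapper that counts its calls."""
    original = F.linear

    def wrapped(*args, **kwargs):
        counter["linear"] += 1
        return original(*args, **kwargs)

    F.linear = wrapped
    try:
        yield
    finally:
        F.linear = original  # always restore the real function


counter = {"linear": 0}
x = torch.rand(5, 4, 3)  # (batch, time, features)

with count_linear_calls(counter):
    nn.Linear(3, 2)(x)
    calls_after_linear = counter["linear"]
    nn.LSTM(3, 2, batch_first=True)(x)
    calls_after_lstm = counter["linear"]

print(calls_after_linear, calls_after_lstm)
```

If the second number equals the first, the LSTM never reached the patched function, which is why patching `F.linear` alone does not cover the built-in recurrent layers.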