turning UNet into Bayesian UNet #25

Open
taborzbislaw opened this issue Nov 30, 2022 · 5 comments

Comments

@taborzbislaw

Hi,

I am trying to use your library to turn a UNet into a Bayesian UNet. The code is pasted below; in this implementation the UNet works as a pixel-to-pixel translator for 3D data. The code follows your regression example, since I am also doing regression, only for higher-dimensional data.

When I run the code I get a runtime error:
ValueError: Expected parameter scale (Tensor of shape (4, 1, 32, 32, 16)) of distribution Normal(loc: torch.Size([4, 1, 32, 32, 16]), scale: torch.Size([4, 1, 32, 32, 16])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values...

I suspect the problem is a wrong choice of prior and/or guide. I would appreciate any suggestion that would help the model learn.

Regards,
Zbisław

The code:

from functools import partial

import torch
import torch.nn as nn
import torch.utils.data as data

import pyro
import pyro.distributions as dist

import tyxe

def double_convolution(in_channels, out_channels):
    """
    In the original paper implementation, the convolution operations were
    not padded, but we are padding them here because we need the output
    size to be the same as the input size.
    """
    conv_op = nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True)
    )
    return conv_op

class UNet(nn.Module):
    def __init__(self, num_classes):
        super(UNet, self).__init__()

        self.max_pool3d = nn.MaxPool3d(kernel_size=2, stride=2)

        # contracting path
        # each convolution is applied twice
        self.down_convolution_1 = double_convolution(3, 64)
        self.down_convolution_2 = double_convolution(64, 128)
        self.down_convolution_3 = double_convolution(128, 256)
        self.down_convolution_4 = double_convolution(256, 512)
        self.down_convolution_5 = double_convolution(512, 1024)

        # expanding path
        self.up_transpose_1 = nn.ConvTranspose3d(
            in_channels=1024, out_channels=512,
            kernel_size=2,
            stride=2)
        # below, `in_channels` again becomes 1024 as we are concatenating
        self.up_convolution_1 = double_convolution(1024, 512)
        self.up_transpose_2 = nn.ConvTranspose3d(
            in_channels=512, out_channels=256,
            kernel_size=2,
            stride=2)
        self.up_convolution_2 = double_convolution(512, 256)
        self.up_transpose_3 = nn.ConvTranspose3d(
            in_channels=256, out_channels=128,
            kernel_size=2,
            stride=2)
        self.up_convolution_3 = double_convolution(256, 128)
        self.up_transpose_4 = nn.ConvTranspose3d(
            in_channels=128, out_channels=64,
            kernel_size=2,
            stride=2)
        self.up_convolution_4 = double_convolution(128, 64)

        # output => increase the `out_channels` as per the number of classes
        self.out = nn.Conv3d(
            in_channels=64, out_channels=num_classes,
            kernel_size=1
        )

    def forward(self, x):
        down_1 = self.down_convolution_1(x)
        down_2 = self.max_pool3d(down_1)
        down_3 = self.down_convolution_2(down_2)
        down_4 = self.max_pool3d(down_3)
        down_5 = self.down_convolution_3(down_4)
        down_6 = self.max_pool3d(down_5)
        down_7 = self.down_convolution_4(down_6)
        #down_8 = self.max_pool3d(down_7)
        #down_9 = self.down_convolution_5(down_8)

        #up_1 = self.up_transpose_1(down_9)
        #x = self.up_convolution_1(torch.cat([down_7, up_1], 1))

        #up_2 = self.up_transpose_2(x)
        up_2 = self.up_transpose_2(down_7)
        x = self.up_convolution_2(torch.cat([down_5, up_2], 1))

        up_3 = self.up_transpose_3(x)
        x = self.up_convolution_3(torch.cat([down_3, up_3], 1))

        up_4 = self.up_transpose_4(x)
        x = self.up_convolution_4(torch.cat([down_1, up_4], 1))

        out = self.out(x)
        return out

################################################################################

if __name__ == '__main__':

    pyro.set_rng_seed(42)

    # synthetic data: the target is the channel-wise mean of the input plus small noise
    x = torch.randn((4,3,32,32,32)) * 0.5 + 0.5
    y = torch.mean(x,1,keepdim=True) + torch.randn((4,1,32,32,32))*0.01

    for _ in range(10):
        x2 = torch.randn((4,3,32,32,32)) * 0.5 + 0.5
        x = torch.cat([x, x2])

        y2 = torch.mean(x2,1,keepdim=True) + torch.randn((4,1,32,32,32))*0.01
        y = torch.cat([y, y2])

    batchSize = 4
    dataset = data.TensorDataset(x, y)
    loader = data.DataLoader(dataset, batch_size=batchSize)

    net = UNet(1)
    prior = tyxe.priors.IIDPrior(dist.Normal(0, 1))
    obs_model = tyxe.likelihoods.HeteroskedasticGaussian((4,1,32,32,32))
    guide = partial(tyxe.guides.AutoNormal, init_scale=0.01)
    bnn = tyxe.VariationalBNN(net, prior, obs_model, guide)

    pyro.clear_param_store()
    optim = pyro.optim.Adam({"lr": 1e-4})
    elbos = []
    def callback(bnn, i, e):
        elbos.append(e)

    with tyxe.poutine.local_reparameterization():
        bnn.fit(loader, optim, 10000, callback)
@LarsBentsen

LarsBentsen commented Jan 2, 2023

Hi,
I'm also using the HeteroskedasticGaussian likelihood. I'm not entirely sure this will solve your issue, but the HeteroskedasticGaussian class in "likelihoods.py" says that it expects predictions to be 2d, which I read as twice the output dimension d, so that the network predicts both the mean and the std (i.e. the aleatoric uncertainty?). For your net, it might help to change the output layer to:

self.out = nn.Conv3d(
    in_channels=64, out_channels=num_classes * 2,
    kernel_size=1
)
However, I'm uncertain if this will solve your issues entirely and I apologise if this is a very trivial proposal...
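One detail worth noting: the shapes in the reported error end in 16, while the targets end in 32. That would be consistent with the likelihood splitting the network output in half along its last axis rather than along the channel axis, although nothing in this thread confirms it. Under that assumption, a minimal PyTorch sketch of lining up a two-channel output with a last-axis split might look as follows. `split_mean_and_scale` is a hypothetical helper, not part of TyXe, and softplus is only one common way to obtain a positive scale.

```python
import torch
import torch.nn.functional as F


def split_mean_and_scale(out: torch.Tensor, eps: float = 1e-6):
    """Hypothetical helper: turn a (N, 2, D, H, W) output into mean and scale.

    Assumes channel 0 holds the mean and channel 1 an unconstrained scale.
    """
    # Move the channel axis last: (N, 2, D, H, W) -> (N, D, H, W, 2).
    out = out.movedim(1, -1)
    # Split the last axis in half: two tensors of shape (N, D, H, W, 1).
    mean, raw_scale = out.chunk(2, dim=-1)
    # softplus maps any real number to a non-negative one; eps guards
    # against underflow to exactly zero.
    scale = F.softplus(raw_scale) + eps
    return mean, scale


# Splitting a single-channel output along the last axis halves that axis,
# which matches the (4, 1, 32, 32, 16) shape in the error message.
halves = torch.randn(4, 1, 32, 32, 32).chunk(2, dim=-1)
print(halves[0].shape)

out = torch.randn(4, 2, 32, 32, 32)  # stand-in for a two-channel UNet output
mean, scale = split_mean_and_scale(out)
print(mean.shape, scale.shape)
```

Whether the library expects this layout, and whether it applies its own positivity transform to the scale, would need to be checked against "likelihoods.py".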

@taborzbislaw
Author

Hi,

Thank you for the suggestion, but it does not work. I get the same error, with the tensor values blowing up.

@hpplyt
Collaborator

hpplyt commented Jan 11, 2023

Sorry for the delayed response, I had completely missed the initial issue over NeurIPS.

Does the error occur right away or after a few training iterations? And have you by any chance checked if you are getting NaNs or negative values for the scales of the Normal distribution? Either would fail the constraint check iirc.
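A check of this kind can be sketched in a few lines of plain PyTorch. `describe_scale` is a hypothetical helper for illustration; it would be called on whatever tensor ends up as the `scale` of the Normal distribution.

```python
import torch


def describe_scale(scale: torch.Tensor) -> dict:
    """Count suspicious entries in a tensor meant to be a Normal scale.

    NaN, zero and negative entries violate a strict `scale > 0` constraint;
    infinite entries pass it but usually point to an overflow upstream.
    """
    return {
        "nan": int(torch.isnan(scale).sum()),
        "inf": int(torch.isinf(scale).sum()),
        "zero": int((scale == 0).sum()),
        "negative": int((scale < 0).sum()),
    }


# Small hand-made example covering each case once.
scale = torch.tensor([0.5, 0.0, -1.0, float("nan"), float("inf")])
report = describe_scale(scale)
print(report)
```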

@taborzbislaw
Author

Hi.

The error occurs right at the start of training. I get something like this:

ValueError: Expected parameter scale (Tensor of shape (4, 2, 32, 32, 16)) of distribution Normal(loc: torch.Size([4, 2, 32, 32, 16]), scale: torch.Size([4, 2, 32, 32, 16])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([[[[[8.8535e+16, 0.0000e+00, 3.9392e+16, ..., 1.2329e+16,
5.3384e+15, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 4.3650e+16, ..., 5.4064e+16,
8.3643e+16, 0.0000e+00],
[5.4920e+16, 8.3632e+16, 2.0063e+17, ..., 0.0000e+00,
3.9833e+16, 0.0000e+00],
...,

I suspect that something is wrong with my model. I simply copied your regression example and replaced your model with a UNet, which is a little more complicated than the regression network. The complete code is in the attached archive; it generates its own data, so it can be run as is.
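The printed tensor contains exact zeros next to values around 1e16. The zeros alone are enough to fail a strict `GreaterThan(lower_bound=0.0)` check, which the following sketch reproduces with `torch.distributions.Normal` outside of TyXe. The three scale values are copied from the traceback above; the clamp at the end is only an illustration of the constraint, not a suggested fix for the model.

```python
import torch
from torch.distributions import Normal

loc = torch.zeros(3)
# Values copied from the traceback: large positives and an exact zero.
bad_scale = torch.tensor([8.8535e16, 0.0, 3.9392e16])

try:
    Normal(loc, bad_scale, validate_args=True)
    failed = False
except ValueError as err:
    # The zero entry violates scale > 0, so validation raises.
    failed = True
    print(type(err).__name__)

# A small positive floor satisfies the constraint, but it hides the
# symptom rather than explaining why the scales are extreme.
safe_scale = bad_scale.clamp_min(1e-6)
ok = Normal(loc, safe_scale, validate_args=True)
```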

UNet.py.zip

taborzbislaw reopened this Jan 11, 2023
@hpplyt
Collaborator

hpplyt commented Jan 16, 2023

Ok interesting. My guess would be that this is an issue with initialisation since you are getting fairly extreme values for the standard deviations. If you replace the init_scale argument for the guide with something around the standard deviation used for initialising the weights of the deterministic version of your network that might fix things (or just go with 1e-5 or so). You might also need to initialise the means of your guide to the deterministic weights as in the ResNet example.
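As a rough illustration of the first suggestion, the standard deviation of the deterministic weights can be measured directly in PyTorch. This is a sketch under assumptions: `pooled_weight_std` is a hypothetical helper, the two-layer block is a small stand-in for the full UNet, and the weights use PyTorch's default initialisation.

```python
import torch
import torch.nn as nn


def pooled_weight_std(module: nn.Module) -> float:
    """Standard deviation over all weight tensors of a module (biases excluded)."""
    weights = [
        p.detach().flatten()
        for name, p in module.named_parameters()
        if name.endswith("weight")
    ]
    return torch.cat(weights).std().item()


torch.manual_seed(0)
# Small stand-in for one UNet block with default PyTorch initialisation.
block = nn.Sequential(
    nn.Conv3d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv3d(64, 64, kernel_size=3, padding=1),
)
init_scale = pooled_weight_std(block)
print(f"candidate init_scale: {init_scale:.4f}")
```

The resulting number could then replace `0.01` in the `partial(tyxe.guides.AutoNormal, init_scale=...)` call from the original post. Computing it per layer instead of pooled over the whole network may be more appropriate, since the weight spread differs between layers of different width.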
