BNN broadcastable error #28

Open
sgpuzzle19 opened this issue Nov 27, 2023 · 4 comments

Comments

@sgpuzzle19

sgpuzzle19 commented Nov 27, 2023

Hi,
I would be really glad to get some help with this. Thank you!
I am using TyXe (Pyro-based) and have converted the fc layer of my model to a probabilistic layer. I am getting the following broadcasting error:

Value is not broadcastable with batch_shape+event_shape:

  net.fc.2.bias dist   | 4
               value   | 4
            log_prob   |
likelihood.data dist   | 8
               value 8 | 4

I am using the categorical likelihood!

Below I have added my model for reference:

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import tyxe

model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Sequential(
    nn.Linear(512, 50), 
    nn.Tanh(), 
    nn.Linear(50, 4)
)
prior = tyxe.priors.IIDPrior(dist.Normal(0, 1))
likelihood = tyxe.likelihoods.Categorical(len(training_Set),event_dim=1)
guide = tyxe.guides.AutoNormal
bnn = tyxe.VariationalBNN(model, prior, likelihood, guide)

lr = 1e-3
optimizer = pyro.optim.Adam({"lr": lr})
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
training_5epoch = []
pred_params_num_predict_1 = []

with tyxe.poutine.local_reparameterization():
    training_5epoch = bnn.fit(train_loader, optim=optimizer, num_epochs=5, callback=None, num_particles=1, closed_form_kl=True, device=device)
    pred_params_num_predict_1 = bnn.predict(test_images, num_predictions=1, aggregate=True, guide_traces=None)
@hpplyt
Collaborator

hpplyt commented Nov 27, 2023

Hi! I think the issue is that you're giving your Categorical likelihood an event dim. If I remember correctly, it automatically accounts for the classes on the trailing dimension. Try changing it to:

likelihood = tyxe.likelihoods.Categorical(len(training_Set))

Hope that solves it! Otherwise I will have to have another look at this in more detail.
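
To illustrate the point about the trailing dimension, here is a minimal sketch using plain `torch.distributions.Categorical` rather than TyXe's likelihood wrapper (the assumption being that the wrapper follows the same shape convention). The batch size of 8 and the 4 classes mirror the shapes in the error message; the tensors themselves are made up.

```python
import torch
from torch.distributions import Categorical

# Made-up logits for a batch of 8 examples and 4 classes.
logits = torch.zeros(8, 4)
likelihood_dist = Categorical(logits=logits)

# The trailing class dimension is consumed by the distribution itself:
# only the batch dimension remains, and the event shape is empty.
batch_shape = tuple(likelihood_dist.batch_shape)  # (8,)
event_shape = tuple(likelihood_dist.event_shape)  # ()

# Targets are therefore expected as integer class labels of shape (8,).
labels = torch.tensor([2, 0, 0, 0, 0, 2, 0, 0])
log_prob_shape = tuple(likelihood_dist.log_prob(labels).shape)  # (8,)
```

With this convention no extra event dimension is needed, and a target of shape (8, 4) does not match the distribution's shape of (8,).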

@sgpuzzle19
Author

Hi,
I tried your suggestion, i.e.
likelihood = tyxe.likelihoods.Categorical(len(training_Set))
but I am still getting the error below:

Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}.

                   net.fc.2.bias dist     |   4        
                                value     |   4        
                             log_prob     |            
                 likelihood.data dist   8 |            
                                value 8 4 |            

@hpplyt
Collaborator

hpplyt commented Nov 27, 2023

Are your target variables one-hot encoded? They need to be label encoded, i.e. 0, 1, 2, 3 for your four classes.
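
A minimal sketch of that conversion, assuming the targets are a float one-hot tensor with the classes on the last dimension. The tensor below is made up for illustration.

```python
import torch

# Made-up one-hot targets: 3 examples, 4 classes.
one_hot = torch.tensor([
    [0., 0., 1., 0.],
    [1., 0., 0., 0.],
    [0., 0., 0., 1.],
])

# argmax over the class dimension recovers the integer class index.
labels = one_hot.argmax(dim=1)

print(labels)        # tensor([2, 0, 3])
print(labels.shape)  # torch.Size([3])
```

The result has one integer per example, which is the label-encoded form described above.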

@sgpuzzle19
Author

Yes, my target variables are one-hot encoded. Below is an example of my target tensor; I set batch_size = 8 in the DataLoader.

tensor([[0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])
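
If changing the dataset itself is inconvenient, one possible workaround is to convert each batch on the fly. The sketch below is only an illustration: `LabelEncodedLoader` is a hypothetical helper, not part of TyXe or PyTorch, and it assumes the loader yields `(inputs, one_hot_targets)` pairs and that the fitting code accepts any re-iterable object yielding such pairs. A plain list stands in for the real DataLoader here.

```python
import torch


class LabelEncodedLoader:
    """Hypothetical wrapper: re-iterable view of a loader of (inputs, one_hot_targets)."""

    def __init__(self, loader):
        self.loader = loader

    def __iter__(self):
        # A fresh pass over the wrapped loader on every call, so the
        # wrapper can be iterated once per epoch.
        for inputs, one_hot_targets in self.loader:
            yield inputs, one_hot_targets.argmax(dim=1)

    def __len__(self):
        return len(self.loader)


# Stand-in for a real DataLoader: one batch of 8 made-up images with
# one-hot targets matching the example tensor above.
class_ids = torch.tensor([2, 0, 0, 0, 0, 2, 0, 0])
fake_loader = [(torch.randn(8, 1, 32, 32), torch.eye(4)[class_ids])]

wrapped = LabelEncodedLoader(fake_loader)
inputs, labels = next(iter(wrapped))
print(labels)  # tensor([2, 0, 0, 0, 0, 2, 0, 0])
```

The wrapped object would then be passed in place of `train_loader`. Converting the labels once inside the Dataset is an equally valid choice and avoids the per-batch `argmax`.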
