Sampling from posterior with 3-D tensor #1288
-
Good evening. I'm working on a model whose simulated data are stored in a 3-D tensor with a last dimension of 3. I have successfully set up SNPE inference and trained an amortized posterior with an embedding net. Next, I want to sample from this posterior using `posterior.sample((sampleSize,), x=tensor)`, where `tensor` is simulated data with a last dimension of 3. However, I get a `ValueError` with a very long message. Attached is a screenshot of the whole error message. What, if anything, can I do to this tensor to fix the issue?
-
Hi there, thanks for reporting this! I tried to reproduce the issue, but everything works fine for me. Could you test whether the following runs for you:

```python
import torch
from torch import nn, ones, randn_like

from sbi.inference import NPE
from sbi.neural_nets import posterior_nn
from sbi.utils import BoxUniform

prior = BoxUniform(-ones(3), ones(3))
theta = prior.sample((100,))

# Build simulations of shape (100, 3, 3, 3): one noisy 3x3x3 tensor per parameter set.
x = torch.stack([theta, theta, theta], dim=1)
x = torch.stack([x, x, x], dim=1)
x += randn_like(x) * 0.1

print(f"theta.shape {theta.shape}")  # torch.Size([100, 3])
print(f"x.shape {x.shape}")  # torch.Size([100, 3, 3, 3])

x_o = x[0]  # x[:1] also works.
print(f"x_o.shape {x_o.shape}")  # torch.Size([3, 3, 3])


class Embedding3D(nn.Module):
    """Flattens each 3x3x3 simulation and maps it to a 20-dimensional feature vector."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3**3, 20)

    def forward(self, x):
        # Override forward (not __call__) so nn.Module's hooks keep working.
        return self.net(torch.reshape(x, (-1, 3**3)))


embedding_net = Embedding3D()
density_estimator = posterior_nn("maf", embedding_net=embedding_net)
inference = NPE(prior, density_estimator=density_estimator)
_ = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()

samples = posterior.sample((100,), x=x_o)
print(f"Posterior samples.shape: {samples.shape}")  # torch.Size([100, 3])
```

If this works, could you share a minimal example that reproduces your issue? Thanks!
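If the error turns out to be shape-related, one quick check before calling `posterior.sample` is to compare the observation's shape with the per-simulation shape used during training. The helper below, `check_observation_shape`, is hypothetical and not part of sbi; it is a minimal sketch assuming the training simulations are stacked along the first dimension, as in the reproduction script.

```python
import torch


def check_observation_shape(x_train: torch.Tensor, x_o: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not an sbi API): return x_o with a leading batch
    dimension, after checking that its per-simulation shape matches training."""
    event_shape = x_train.shape[1:]  # drop the batch (simulation) dimension
    if x_o.shape == event_shape:
        # A single observation given without a batch dimension, e.g. x[0].
        x_o = x_o.unsqueeze(0)
    if x_o.shape[1:] != event_shape:
        raise ValueError(
            f"observation shape {tuple(x_o.shape[1:])} does not match "
            f"training simulation shape {tuple(event_shape)}"
        )
    return x_o
```

For example, `x_o = check_observation_shape(x, x_o)` before `posterior.sample((100,), x=x_o)` raises a readable error if the observation was built with a different layout than the training data.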