Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap 1D PyTorch distributions #1286

Merged
merged 1 commit into from
Sep 26, 2024
Merged

Wrap 1D PyTorch distributions #1286

merged 1 commit into from
Sep 26, 2024

Conversation

michaeldeistler
Copy link
Contributor

@michaeldeistler michaeldeistler commented Sep 24, 2024

Closes #1285 an #1284

1D pytorch distributions such as torch.distributions.Exponential, .Uniform, or .Normal do not, by default return any sample or batch dimension. E.g.:

dist = torch.distributions.Exponential(torch.tensor(3.0))
dist.sample((10,)).shape  # (10,)

sbi will raise an error that the sample dimension is missing. A simple solution is to add a batch dimension to dist as follows:

dist = torch.distributions.Exponential(torch.tensor([3.0]))
dist.sample((10,)).shape  # (10, 1)

Unfortunately, this dist will return the batch dimension also for `.log_prob():

dist = torch.distributions.Exponential(torch.tensor([3.0]))
samples = dist.sample((10,))
dist.log_prob(samples).shape  # (10, 1)

This will lead to unexpected errors in sbi. The point of this PR is to wrap those batched 1D distributions to get rid of their batch dimension in .log_prob().

Copy link

codecov bot commented Sep 24, 2024

Codecov Report

Attention: Patch coverage is 86.36364% with 3 lines in your changes missing coverage. Please review.

Project coverage is 78.36%. Comparing base (299854e) to head (34b081e).
Report is 8 commits behind head on main.

Files with missing lines Patch % Lines
sbi/utils/user_input_checks_utils.py 85.00% 3 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1286       +/-   ##
===========================================
- Coverage   89.24%   78.36%   -10.89%     
===========================================
  Files         119      119               
  Lines        8695     8727       +32     
===========================================
- Hits         7760     6839      -921     
- Misses        935     1888      +953     
Flag Coverage Δ
unittests 78.36% <86.36%> (-10.89%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
sbi/utils/user_input_checks.py 76.56% <100.00%> (+0.24%) ⬆️
sbi/utils/user_input_checks_utils.py 89.17% <85.00%> (-0.61%) ⬇️

... and 33 files with indirect coverage changes

Copy link
Contributor

@gmoss13 gmoss13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @michaeldeistler! I left a couple of minor comments, please have a look but I am happy to merge this when the pyright issue is fixed. By the way, I think you meant that this PR addresses #1283 instead of #1284 :)

sbi/utils/user_input_checks.py Outdated Show resolved Hide resolved
sbi/utils/user_input_checks_utils.py Show resolved Hide resolved
tests/user_input_checks_test.py Show resolved Hide resolved
@michaeldeistler michaeldeistler merged commit 1d4ee7a into main Sep 26, 2024
6 checks passed
@michaeldeistler michaeldeistler deleted the priorfix branch September 26, 2024 01:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Changing the density estimator for SNPE gives the same error but randomly
2 participants