-
Notifications
You must be signed in to change notification settings - Fork 160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Evaluation of posterior fit #1023
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1023 +/- ##
==========================================
- Coverage 76.37% 76.36% -0.02%
==========================================
Files 84 84
Lines 6507 6512 +5
==========================================
+ Hits 4970 4973 +3
- Misses 1537 1539 +2
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
sbi/analysis/sensitivity_analysis.py
Outdated
@@ -501,3 +502,53 @@ def project(self, theta: Tensor, num_dimensions: int) -> Tensor: | |||
projected_theta = torch.mm(theta, projection_mat) | |||
|
|||
return projected_theta | |||
|
|||
|
|||
def posterior_shrinkage(prior_std, post_std): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, the function does not do much work.
Maybe change inputs to prior and posterior samples and do the estimation of the standard deviation within?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Type hints should also be here and not in the docstring.
sbi/analysis/sensitivity_analysis.py
Outdated
|
||
Parameters | ||
---------- | ||
prior_std : float, array-like |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typing should also be valid for tensors (i.e. vectors)
I guess one can define this metric also in a multivariate setting, just with the "total stddev" i.e. the sum.
sbi/analysis/sensitivity_analysis.py
Outdated
---------- | ||
prior_std : float, array-like | ||
The standard deviation of the prior distribution. | ||
post_std : float, array-like |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
sbi/analysis/sensitivity_analysis.py
Outdated
return 1 - (post_std / prior_std) ** 2 | ||
|
||
|
||
def posterior_zscore(true_mean, post_mean, post_std): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above i.e. type hints here.
And the metric should be suitable for vector input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey. I left a few comments, let me know if something is unclear.
Not sure if it also is in the right folder, maybe move them to the metrics.py in utils?
@manuelgloeckler Thanks for the comments. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great, looks good.
Just a few signature type hints are missing.
And I would stick with torch, if we do not have to use numpy. Our posteriors return torch.Tensors so lets also stick with it here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great, thanks for implementing these changes.
Happy to merge it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this!
The test contain some repetitions and can be simplified with fixtures, but let's move this to an issue.
Add to function for evaluation of posterior fit.
Fixes issue #1017
Checklist
Put an
x
in the boxes that apply. You can also fill these out after creatingthe PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.
guidelines
with
pytest.mark.slow
.guidelines
main
(or there are no conflicts withmain
)