-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Working version of coverage and all sbc metrics
- Loading branch information
Showing
16 changed files
with
262 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
from metrics.all_sbc import AllSBC | ||
from metrics.coverage_fraction import CoverageFraction | ||
|
||
Metrics = { | ||
CoverageFraction.__name__: CoverageFraction, | ||
AllSBC.__name__: AllSBC | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Any | ||
from torch import tensor | ||
from sbi.analysis import run_sbc, check_sbc | ||
|
||
from metrics.metric import Metric | ||
from utils.config import get_item | ||
|
||
class AllSBC(Metric): | ||
def __init__(self, model: Any, data: Any, out_dir: str | None = None) -> None: | ||
super().__init__(model, data, out_dir) | ||
|
||
def _collect_data_params(self): | ||
self.thetas = tensor(self.data.theta_true()) | ||
self.y_true = tensor(self.data.x_true()) | ||
|
||
self.samples_per_inference = get_item( | ||
"metrics_common", "samples_per_inference", raise_exception=False | ||
) | ||
|
||
def calculate(self): | ||
ranks, dap_samples = run_sbc( | ||
self.thetas, self.y_true, self.model.posterior, num_posterior_samples=self.samples_per_inference | ||
) | ||
|
||
sbc_stats = check_sbc( | ||
ranks, self.thetas, dap_samples, | ||
num_posterior_samples=self.samples_per_inference | ||
) | ||
self.output = sbc_stats | ||
return sbc_stats | ||
|
||
def __call__(self, **kwds: Any) -> Any: | ||
self._collect_data_params() | ||
self.calculate() | ||
self._finish() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import numpy as np | ||
from typing import Any | ||
from tqdm import tqdm | ||
|
||
from metrics.metric import Metric | ||
from utils.config import get_item | ||
|
||
class CoverageFraction(Metric): | ||
""" | ||
""" | ||
def __init__(self, model: Any, data: Any, out_dir: str | None = None) -> None: | ||
super().__init__(model, data, out_dir) | ||
self._collect_data_params() | ||
|
||
def _collect_data_params(self): | ||
self.thetas = self.data.theta_true() | ||
self.y_true = self.data.x_true() | ||
|
||
self.samples_per_inference = get_item( | ||
"metrics_common", "samples_per_inference", raise_exception=False | ||
) | ||
self.percentiles = get_item( | ||
"metrics_common", "percentiles", raise_exception=False | ||
) | ||
|
||
def _run_model_inference(self, samples_per_inference, y_inference): | ||
samples = self.model.sample_posterior(samples_per_inference, y_inference) | ||
return samples | ||
|
||
def calculate(self): | ||
all_samples = np.empty( | ||
(len(self.y_true), self.samples_per_inference, np.shape(self.thetas)[1]) | ||
) | ||
count_array = [] | ||
iterator = enumerate(self.y_true) | ||
if get_item("metrics_common", "use_progress_bar", raise_exception=False): | ||
iterator = tqdm( | ||
iterator, | ||
desc="Sampling from the posterior for each observation", | ||
unit="observation" | ||
) | ||
for y_sample_index, y_sample in iterator: | ||
samples = self._run_model_inference(self.samples_per_inference, y_sample) | ||
all_samples[y_sample_index] = samples | ||
|
||
count_vector = [] | ||
# step through the percentile list | ||
for cov in self.percentiles: | ||
percentile_lower = 50.0 - cov / 2 | ||
percentile_upper = 50.0 + cov / 2 | ||
|
||
# find the percentile for the posterior for this observation | ||
# this is n_params dimensional | ||
# the units are in parameter space | ||
confidence_lower = np.percentile( | ||
samples.cpu(), | ||
percentile_lower, | ||
axis=0 | ||
) | ||
confidence_upper = np.percentile( | ||
samples.cpu(), | ||
percentile_upper, | ||
axis=0 | ||
) | ||
|
||
# this is asking if the true parameter value | ||
# is contained between the | ||
# upper and lower confidence intervals | ||
# checks separately for each side of the 50th percentile | ||
|
||
count = np.logical_and( | ||
confidence_upper - self.thetas[y_sample_index,:] > 0, | ||
self.thetas[y_sample_index,:] - confidence_lower > 0 | ||
) | ||
count_vector.append(count) | ||
# each time the above is > 0, adds a count | ||
count_array.append(count_vector) | ||
|
||
count_sum_array = np.sum(count_array, axis=0) | ||
frac_lens_within_vol = np.array(count_sum_array) | ||
coverage = frac_lens_within_vol / len(self.y_true) | ||
|
||
self.output = coverage | ||
|
||
return all_samples, coverage | ||
|
||
def __call__(self, **kwds: Any) -> Any: | ||
self.calculate() | ||
self._finish() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.