Skip to content

Commit

Permalink
Working version of coverage and all sbc metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
voetberg committed Apr 4, 2024
1 parent d9ed3b0 commit da2af98
Show file tree
Hide file tree
Showing 16 changed files with 262 additions and 23 deletions.
1 change: 0 additions & 1 deletion src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

from data.h5_data import H5Data
from data.pickle_data import PickleData
from data.simulator import Simulator

DataModules = {
"H5Data": H5Data,
Expand Down
22 changes: 21 additions & 1 deletion src/data/h5_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,24 @@ def save(self, data:dict[str, Any], path: str): # Todo typing for data dict
with h5py.File(path, "w") as file:
# Save each array as a dataset in the HDF5 file
for key, value in data_arrays.items():
file.create_dataset(key, data=value)
file.create_dataset(key, data=value)

def x_true(self):
    """Return the ``'xs'`` entry of the loaded data."""
    # Taken directly from the stored data; nothing is computed here.
    xs = self.data['xs']
    return xs

def y_true(self):
    """Return the simulator output for the true thetas and the stored xs."""
    simulate = self.simulator
    thetas = self.theta_true()
    xs = self.x_true()
    return simulate(thetas, xs)

def prior(self):
    """Prior accessor; not implemented here, so every call raises.

    Raises:
        NotImplementedError: always.
    """
    # The original note says this is meant to come "from data".
    raise NotImplementedError()

def theta_true(self):
    """Return the ``'thetas'`` entry of the loaded data."""
    thetas = self.data['thetas']
    return thetas

def sigma_true(self):
    """Return the parent class's ``sigma_true()``, or 1 if it is unavailable.

    The fallback is used when the parent implementation raises
    ``AssertionError`` or ``KeyError``; any other exception propagates.
    """
    fallback_sigma = 1
    try:
        sigma = super().sigma_true()
    except (AssertionError, KeyError):
        return fallback_sigma
    return sigma
5 changes: 5 additions & 0 deletions src/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from metrics.all_sbc import AllSBC
from metrics.coverage_fraction import CoverageFraction

# Registry of the available metrics, keyed by each metric's class name.
Metrics = {
    metric.__name__: metric
    for metric in (CoverageFraction, AllSBC)
}
35 changes: 35 additions & 0 deletions src/metrics/all_sbc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any
from torch import tensor
from sbi.analysis import run_sbc, check_sbc

from metrics.metric import Metric
from utils.config import get_item

class AllSBC(Metric):
    """Metric that runs sbi's simulation-based calibration (SBC) checks.

    ``run_sbc`` is applied to the true thetas and the data's ``x_true()``
    values using the model posterior; the resulting ranks are then
    summarised with ``check_sbc``.
    """

    def __init__(self, model: Any, data: Any, out_dir: str | None = None) -> None:
        super().__init__(model, data, out_dir)

    def _collect_data_params(self):
        """Gather the tensors and the sampling setting used by ``calculate``."""
        source = self.data
        self.thetas = tensor(source.theta_true())
        # Stored under the name ``y_true`` although it is read from ``x_true()``.
        self.y_true = tensor(source.x_true())

        self.samples_per_inference = get_item(
            "metrics_common", "samples_per_inference", raise_exception=False
        )

    def calculate(self):
        """Run SBC, store the summary statistics in ``self.output``, return them."""
        sbc_result = run_sbc(
            self.thetas,
            self.y_true,
            self.model.posterior,
            num_posterior_samples=self.samples_per_inference,
        )
        ranks, dap_samples = sbc_result

        sbc_stats = check_sbc(
            ranks,
            self.thetas,
            dap_samples,
            num_posterior_samples=self.samples_per_inference,
        )
        self.output = sbc_stats
        return sbc_stats

    def __call__(self, **kwds: Any) -> Any:
        """Collect inputs, compute the statistics, then run ``_finish``."""
        self._collect_data_params()
        self.calculate()
        self._finish()
92 changes: 92 additions & 0 deletions src/metrics/coverage_fraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import numpy as np
from typing import Any
from tqdm import tqdm

from metrics.metric import Metric
from utils.config import get_item

class CoverageFraction(Metric):
    """Fraction of true parameters that fall inside central posterior intervals.

    For every observation, posterior samples are drawn from the model. For
    each configured percentile width, the central interval of those samples is
    compared, parameter by parameter, against the true parameter values. The
    resulting coverage is stored in ``self.output``.
    """

    def __init__(self, model: Any, data: Any, out_dir: str | None = None) -> None:
        super().__init__(model, data, out_dir)
        self._collect_data_params()

    def _collect_data_params(self):
        """Read the true values from the data and the settings from the config."""
        self.thetas = self.data.theta_true()
        self.y_true = self.data.x_true()

        self.samples_per_inference = get_item(
            "metrics_common", "samples_per_inference", raise_exception=False
        )
        self.percentiles = get_item(
            "metrics_common", "percentiles", raise_exception=False
        )

    def _run_model_inference(self, samples_per_inference, y_inference):
        """Draw ``samples_per_inference`` posterior samples for one observation."""
        samples = self.model.sample_posterior(samples_per_inference, y_inference)
        return samples

    def calculate(self):
        """Compute the coverage fraction for every percentile and parameter.

        Returns:
            A tuple ``(all_samples, coverage)``. ``all_samples`` holds the
            posterior samples of every observation; ``coverage`` is the
            per-percentile, per-parameter count of covered true values
            divided by the number of observations.
        """
        n_observations = len(self.y_true)
        all_samples = np.empty(
            (n_observations, self.samples_per_inference, np.shape(self.thetas)[1])
        )
        count_array = []
        iterator = enumerate(self.y_true)
        if get_item("metrics_common", "use_progress_bar", raise_exception=False):
            iterator = tqdm(
                iterator,
                desc="Sampling from the posterior for each observation",
                unit="observation",
                # enumerate() has no len(), so without an explicit total the
                # bar cannot show a percentage or a time estimate.
                total=n_observations,
            )
        for y_sample_index, y_sample in iterator:
            samples = self._run_model_inference(self.samples_per_inference, y_sample)
            all_samples[y_sample_index] = samples

            # Convert to a host-side array once per observation instead of
            # once per percentile bound; plain arrays are accepted as well.
            host_samples = samples.cpu() if hasattr(samples, "cpu") else samples
            samples_array = np.asarray(host_samples)
            theta = self.thetas[y_sample_index, :]

            count_vector = []
            # step through the percentile list
            for cov in self.percentiles:
                # Bounds of the central interval, per parameter and in
                # parameter units; both come from a single percentile call.
                confidence_lower, confidence_upper = np.percentile(
                    samples_array,
                    [50.0 - cov / 2, 50.0 + cov / 2],
                    axis=0,
                )

                # True where the true parameter value lies strictly between
                # the lower and upper bounds (each side checked separately).
                count = np.logical_and(
                    confidence_upper - theta > 0,
                    theta - confidence_lower > 0,
                )
                count_vector.append(count)
            count_array.append(count_vector)

        # Sum the hits over observations, then normalise by their number.
        count_sum_array = np.sum(count_array, axis=0)
        coverage = count_sum_array / n_observations

        self.output = coverage

        return all_samples, coverage

    def __call__(self, **kwds: Any) -> Any:
        """Compute the coverage, then run ``_finish``."""
        self.calculate()
        self._finish()


4 changes: 2 additions & 2 deletions src/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def _finish(self):
json.dump(data, f, ensure_ascii=True)
f.close()

def __call__(self, *args: Any, **kwds: Any) -> Any:
def __call__(self,**kwds: Any) -> Any:
self._collect_data_params()
self._run_model_inference()
self.calculate()
self.calculate(kwds)
self._finish()


8 changes: 6 additions & 2 deletions src/models/sbi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ def _load(self, path:str) -> None:
posterior = pickle.load(file)
self.posterior = posterior

def sample_posterior(self, n_samples:int, data): # TODO typing
return self.posterior.sample((n_samples,), x=data.y_true)
def sample_posterior(self, n_samples:int, y_true): # TODO typing
return self.posterior.sample(
(n_samples,),
x=y_true,
show_progress_bars=False
).cpu()

def predict_posterior(self, data):
posterior_samples = self.sample_posterior(data.y_true)
Expand Down
4 changes: 2 additions & 2 deletions src/plots/plot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
from typing import Any, Optional
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib import rcParams

from utils.config import get_item, get_section
from utils.config import get_section

class Display:
def __init__(self, model, data, save:bool, show:bool, out_path:Optional[str]):
Expand Down
2 changes: 1 addition & 1 deletion src/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def sbc_statistics(self,
num_posterior_samples=num_posterior_samples
)
return check_stats

def plot_1d_ranks(
self,
ranks,
Expand Down
2 changes: 1 addition & 1 deletion src/scripts/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
matplotlib.rcParams["axes.spines.right"] = False
matplotlib.rcParams["axes.spines.top"] = False


class Display:
def mackelab_corner_plot(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def get_section(section, raise_exception=True):

class Config:
ENV_VAR_PATH = "DeepDiagnostics_Config"
def __init__(self, config_path:Optional[str]) -> None:
def __init__(self, config_path:Optional[str]=None) -> None:
if config_path is not None:
# Add it to the env vars in case we need to get it later.
os.environ[self.ENV_VAR_PATH] = config_path
Expand Down
10 changes: 7 additions & 3 deletions src/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"data":{
"data_engine": "H5Data"
},
"plot_common": {
"plots_common": {
"axis_spines": False,
"tight_layout": True,
"colorway": "virdids",
Expand All @@ -19,10 +19,14 @@
"plots":{
"type_of_plot":{"specific_kwargs"}
},
"metric_common": {
"metrics_common": {
"use_progress_bar": False,
"samples_per_inference":1000,
"percentiles":[75, 85, 95]

},
"metrics":{
"type_of_metrics":{"specific_kwargs"}
"AllSBC":{},
"CoverageFraction": {},
}
}
38 changes: 33 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,44 @@
import pytest
import yaml
import numpy as np

from data import H5Data, Simulator
from data import H5Data
from data.simulator import Simulator
from models import SBIModel
from utils.register import register_simulator


class MockSimulator(Simulator):
    """Test simulator producing noisy straight lines ``y = m * x + b + noise``."""

    def __init__(self):
        # No-op constructor: the parent ``Simulator.__init__`` is not called.
        pass

    def __call__(self, thetas, samples):
        """Simulate one noisy line per parameter set.

        Args:
            thetas: One ``(m, b)`` pair or an ``(n, 2)`` collection of pairs.
            samples: Number of evenly spaced x values generated on [0, 100].

        Returns:
            Array of shape ``(n, samples)``; row ``i`` holds
            ``m_i * x + b_i`` plus unit-variance Gaussian noise.

        Raises:
            ValueError: If ``thetas`` does not have exactly two columns.
        """
        thetas = np.atleast_2d(thetas)
        # Check if the input has the correct shape
        if thetas.shape[1] != 2:
            raise ValueError("Input tensor must have shape (n, 2) where n is the number of parameter sets.")

        n_sets = thetas.shape[0]
        slopes, intercepts = thetas[:, 0], thetas[:, 1]
        x = np.linspace(0, 100, samples)
        # Unseeded on purpose: every call draws fresh noise.
        rs = np.random.RandomState()
        # sigma could become a function of x if this needs to get fancier.
        sigma = 1
        epsilon = rs.normal(loc=0, scale=sigma, size=(len(x), n_sets))

        # Column i is m_i * x + b_i + noise; broadcasting replaces the
        # per-parameter-set Python loop.
        y = np.outer(x, slopes) + intercepts + epsilon
        return y.T


@pytest.fixture
Expand Down Expand Up @@ -67,9 +95,9 @@ def factory(

# Dict settings
if plot_settings is not None:
config['plot_common'] = plot_settings
config['plots_common'] = plot_settings
if metrics_settings is not None:
config['metric_common'] = metrics_settings
config['metrics_common'] = metrics_settings

if metrics is not None:
if isinstance(metrics, dict):
Expand Down
7 changes: 3 additions & 4 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import subprocess
import os

Expand All @@ -14,7 +13,7 @@ def test_parser_args(model_path, data_path, simulator_name):

def test_parser_config(config_factory, model_path, data_path, simulator_name):
config_path = config_factory(model_path=model_path, data_path=data_path, simulator=simulator_name)
command = ["diagnose", f"--config", config_path]
command = ["diagnose", "--config", config_path]
process = subprocess.run(command)
exit_code = process.returncode
assert exit_code == 0
Expand All @@ -23,7 +22,7 @@ def test_parser_config(config_factory, model_path, data_path, simulator_name):
def test_main_no_methods(config_factory, model_path, data_path, simulator_name):
out_dir = "./test_out_dir/"
config_path = config_factory(model_path=model_path, data_path=data_path, simulator=simulator_name, plots=[], metrics=[], out_dir=out_dir)
command = ["diagnose", f"--config", config_path]
command = ["diagnose", "--config", config_path]
process = subprocess.run(command)
exit_code = process.returncode
assert exit_code == 0
Expand All @@ -34,7 +33,7 @@ def test_main_no_methods(config_factory, model_path, data_path, simulator_name):

def test_main_missing_config():
config_path = "there_is_no_config_at_this_path.yml"
command = ["diagnose", f"--config", config_path]
command = ["diagnose", "--config", config_path]
process = subprocess.run(command)
exit_code = process.returncode
assert exit_code == 1
Expand Down
Loading

0 comments on commit da2af98

Please sign in to comment.