Skip to content

Commit

Permalink
Merge pull request #60 from voetberg/tarp_plot
Browse files Browse the repository at this point in the history
Tarp plot
  • Loading branch information
bnord authored May 6, 2024
2 parents 3b5406f + 8d0a7ac commit 7345dcc
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 11 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ getdist = "^1.4.7"
h5py = "^3.10.0"
numpy = "^1.26.4"
matplotlib = "^3.8.3"
tarp = "^0.1.1"
deprecation = "^2.1.0"


[tool.poetry.group.dev.dependencies]
Expand Down
4 changes: 3 additions & 1 deletion src/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from plots.cdf_ranks import CDFRanks
from plots.coverage_fraction import CoverageFraction
from plots.ranks import Ranks
from plots.tarp import TARP

Plots = {
CDFRanks.__name__: CDFRanks,
CoverageFraction.__name__: CoverageFraction,
Ranks.__name__: Ranks
Ranks.__name__: Ranks,
TARP.__name__: TARP
}
103 changes: 103 additions & 0 deletions src/plots/tarp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Optional, Union
from torch import tensor
import numpy as np
import tarp

import matplotlib.pyplot as plt
import matplotlib.colors as plt_colors

from plots.plot import Display
from utils.config import get_item

class TARP(Display):
def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None):
super().__init__(model, data, save, show, out_dir)

def _plot_name(self):
return "tarp.png"

def _data_setup(self):
self.rng = np.random.default_rng(get_item("common", "random_seed", raise_exception=False))
samples_per_inference = get_item(
"metrics_common", "samples_per_inference", raise_exception=False
)
num_simulations = get_item("metrics_common", "number_simulations", raise_exception=False)

n_dims = self.data.theta_true().shape[1]
self.posterior_samples = np.zeros((num_simulations, samples_per_inference, n_dims))
self.thetas = np.zeros((num_simulations, n_dims))
for n in range(num_simulations):
sample_index = self.rng.integers(0, len(self.data.theta_true()))

theta = self.data.theta_true()[sample_index,:]
x = self.data.x_true()[sample_index,:]
self.posterior_samples[n] = self.model.sample_posterior(samples_per_inference, x)
self.thetas[n] = theta

self.posterior_samples = np.swapaxes(self.posterior_samples, 0,1)
def _plot_settings(self):
self.line_style = get_item("plots_common", "line_style_cycle", raise_exception=False)


def _get_hex_sigma_colors(self, n_colors, colorway=None):

if colorway is None:
colorway = get_item("plots_common", "default_colorway", raise_exception=False)

cmap = plt.cm.get_cmap(colorway)
hex_colors = []
arr=np.linspace(0,1, n_colors)
for hit in arr:
hex_colors.append(plt_colors.rgb2hex(cmap(hit)))

return hex_colors

def _plot(
self,
coverage_sigma:int = 3,
reference_point:Union[str, np.ndarray]='random',
metric:bool="euclidean",
normalize:bool=True,
bootstrap_calculation:bool=True,
coverage_colorway:Optional[str]=None,
coverage_alpha:float=0.2,
y_label:str="Expected Coverage",
x_label:str="Expected Coverage",
title:str='Test of Accuracy with Random Points'
):

coverage_probability, credibility = tarp.get_tarp_coverage(
self.posterior_samples,
self.thetas,
references=reference_point,
metric = metric,
norm = normalize,
bootstrap=bootstrap_calculation
)
figure_size = get_item("plots_common", "figure_size", raise_exception=False)
k_sigma = range(1,coverage_sigma+1)
_, ax = plt.subplots(1, 1, figsize=figure_size)

ax.plot([0, 1], [0, 1], ls=self.line_style[0], color='k', label="Ideal")
ax.plot(
credibility,
coverage_probability.mean(axis=0),
ls=self.line_style[-1],
label='TARP')

k_sigma = range(1,coverage_sigma+1)
colors = self._get_hex_sigma_colors(coverage_sigma, colorway=coverage_colorway)
for sigma, color in zip(k_sigma, colors):
ax.fill_between(
credibility,
coverage_probability.mean(axis=0) - sigma * coverage_probability.std(axis=0),
coverage_probability.mean(axis=0) + sigma * coverage_probability.std(axis=0),
alpha = coverage_alpha,
color=color
)

ax.legend()
ax.set_ylabel(y_label)
ax.set_xlabel(x_label)
ax.set_title(title)

11 changes: 8 additions & 3 deletions src/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"common":{
"out_dir":"./DeepDiagnosticsResources/results/",
"temp_config": "./DeepDiagnosticsResources/temp/temp_config.yml",
"sim_location": "DeepDiagnosticsResources_Simulators"
"sim_location": "DeepDiagnosticsResources_Simulators",
"random_seed":42
},
"model": {
"model_engine": "SBIModel"
Expand All @@ -23,12 +24,16 @@
"plots":{
"CDFRanks":{},
"Ranks":{"num_bins":None},
"CoverageFraction":{}
"CoverageFraction":{},
"TARP":{
"coverage_sigma":3 # How many sigma to show coverage over
}
},
"metrics_common": {
"use_progress_bar": False,
"samples_per_inference":1000,
"percentiles":[75, 85, 95]
"percentiles":[75, 85, 95],
"number_simulations": 50
},
"metrics":{
"AllSBC":{},
Expand Down
16 changes: 9 additions & 7 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
Plots,
CDFRanks,
Ranks,
CoverageFraction
CoverageFraction,
TARP
)

@pytest.fixture
def plot_config(config_factory):
out_dir = "./temp_results/"
metrics_settings={"use_progress_bar":False, "samples_per_inference":10, "percentiles":[95]}
config = config_factory(out_dir=out_dir, metrics_settings=metrics_settings)
return config
Config(config)


def test_all_plot_catalogued():
'''Each metrics gets its own file, and each metric is included in the Metrics dictionary
Expand All @@ -32,25 +34,25 @@ def test_all_defaults(plot_config, mock_model, mock_data):
Ensures each metric has a default set of parameters and is included in the defaults list
Ensures each test can initialize, regardless of the veracity of the output
"""
Config(plot_config)
for plot_name, plot_obj in Plots.items():
assert plot_name in Defaults['plots']
plot_obj(mock_model, mock_data, save=True, show=False)

def test_plot_cdf(plot_config, mock_model, mock_data):
Config(plot_config)
plot = CDFRanks(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "CDFRanks", raise_exception=False))
assert os.path.exists(f"{plot.out_path}/{plot.plot_name}")

def test_plot_ranks(plot_config, mock_model, mock_data):
Config(plot_config)
plot = Ranks(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "Ranks", raise_exception=False))
assert os.path.exists(f"{plot.out_path}/{plot.plot_name}")

def test_plot_coverage(plot_config, mock_model, mock_data):
Config(plot_config)
plot = CoverageFraction(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "CoverageFraction", raise_exception=False))
assert os.path.exists(f"{plot.out_path}/{plot.plot_name}")
assert os.path.exists(f"{plot.out_path}/{plot.plot_name}")

def test_plot_tarp(plot_config, mock_model, mock_data):
plot = TARP(mock_model, mock_data, save=True, show=False)
plot(**get_item("plots", "TARP", raise_exception=False))

0 comments on commit 7345dcc

Please sign in to comment.