diff --git a/pyproject.toml b/pyproject.toml index 2114fd6..eea7716 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,8 @@ getdist = "^1.4.7" h5py = "^3.10.0" numpy = "^1.26.4" matplotlib = "^3.8.3" +tarp = "^0.1.1" +deprecation = "^2.1.0" [tool.poetry.group.dev.dependencies] diff --git a/src/plots/__init__.py b/src/plots/__init__.py index 7ad7227..1d79d45 100644 --- a/src/plots/__init__.py +++ b/src/plots/__init__.py @@ -1,9 +1,11 @@ from plots.cdf_ranks import CDFRanks from plots.coverage_fraction import CoverageFraction from plots.ranks import Ranks +from plots.tarp import TARP Plots = { CDFRanks.__name__: CDFRanks, CoverageFraction.__name__: CoverageFraction, - Ranks.__name__: Ranks + Ranks.__name__: Ranks, + TARP.__name__: TARP } \ No newline at end of file diff --git a/src/plots/tarp.py b/src/plots/tarp.py new file mode 100644 index 0000000..e64ee41 --- /dev/null +++ b/src/plots/tarp.py @@ -0,0 +1,103 @@ +from typing import Optional, Union +from torch import tensor +import numpy as np +import tarp + +import matplotlib.pyplot as plt +import matplotlib.colors as plt_colors + +from plots.plot import Display +from utils.config import get_item + +class TARP(Display): + def __init__(self, model, data, save: bool, show: bool, out_dir: str | None = None): + super().__init__(model, data, save, show, out_dir) + + def _plot_name(self): + return "tarp.png" + + def _data_setup(self): + self.rng = np.random.default_rng(get_item("common", "random_seed", raise_exception=False)) + samples_per_inference = get_item( + "metrics_common", "samples_per_inference", raise_exception=False + ) + num_simulations = get_item("metrics_common", "number_simulations", raise_exception=False) + + n_dims = self.data.theta_true().shape[1] + self.posterior_samples = np.zeros((num_simulations, samples_per_inference, n_dims)) + self.thetas = np.zeros((num_simulations, n_dims)) + for n in range(num_simulations): + sample_index = self.rng.integers(0, len(self.data.theta_true())) + + theta = self.data.theta_true()[sample_index,:] + x = self.data.x_true()[sample_index,:] + self.posterior_samples[n] = self.model.sample_posterior(samples_per_inference, x) + self.thetas[n] = theta + + self.posterior_samples = np.swapaxes(self.posterior_samples, 0,1) + def _plot_settings(self): + self.line_style = get_item("plots_common", "line_style_cycle", raise_exception=False) + + + def _get_hex_sigma_colors(self, n_colors, colorway=None): + + if colorway is None: + colorway = get_item("plots_common", "default_colorway", raise_exception=False) + + cmap = plt.cm.get_cmap(colorway) + hex_colors = [] + arr=np.linspace(0,1, n_colors) + for hit in arr: + hex_colors.append(plt_colors.rgb2hex(cmap(hit))) + + return hex_colors + + def _plot( + self, + coverage_sigma:int = 3, + reference_point:Union[str, np.ndarray]='random', + metric:bool="euclidean", + normalize:bool=True, + bootstrap_calculation:bool=True, + coverage_colorway:Optional[str]=None, + coverage_alpha:float=0.2, + y_label:str="Expected Coverage", + x_label:str="Expected Coverage", + title:str='Test of Accuracy with Random Points' + ): + + coverage_probability, credibility = tarp.get_tarp_coverage( + self.posterior_samples, + self.thetas, + references=reference_point, + metric = metric, + norm = normalize, + bootstrap=bootstrap_calculation + ) + figure_size = get_item("plots_common", "figure_size", raise_exception=False) + k_sigma = range(1,coverage_sigma+1) + _, ax = plt.subplots(1, 1, figsize=figure_size) + + ax.plot([0, 1], [0, 1], ls=self.line_style[0], color='k', label="Ideal") + ax.plot( + credibility, + coverage_probability.mean(axis=0), + ls=self.line_style[-1], + label='TARP') + + k_sigma = range(1,coverage_sigma+1) + colors = self._get_hex_sigma_colors(coverage_sigma, colorway=coverage_colorway) + for sigma, color in zip(k_sigma, colors): + ax.fill_between( + credibility, + coverage_probability.mean(axis=0) - sigma * coverage_probability.std(axis=0), + coverage_probability.mean(axis=0) + sigma * coverage_probability.std(axis=0), + alpha = coverage_alpha, + color=color + ) + + ax.legend() + ax.set_ylabel(y_label) + ax.set_xlabel(x_label) + ax.set_title(title) + \ No newline at end of file diff --git a/src/utils/defaults.py b/src/utils/defaults.py index 0f511ac..6722a03 100644 --- a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -2,7 +2,8 @@ "common":{ "out_dir":"./DeepDiagnosticsResources/results/", "temp_config": "./DeepDiagnosticsResources/temp/temp_config.yml", - "sim_location": "DeepDiagnosticsResources_Simulators" + "sim_location": "DeepDiagnosticsResources_Simulators", + "random_seed":42 }, "model": { "model_engine": "SBIModel" @@ -23,12 +24,16 @@ "plots":{ "CDFRanks":{}, "Ranks":{"num_bins":None}, - "CoverageFraction":{} + "CoverageFraction":{}, + "TARP":{ + "coverage_sigma":3 # How many sigma to show coverage over + } }, "metrics_common": { "use_progress_bar": False, "samples_per_inference":1000, - "percentiles":[75, 85, 95] + "percentiles":[75, 85, 95], + "number_simulations": 50 }, "metrics":{ "AllSBC":{}, diff --git a/tests/test_plots.py b/tests/test_plots.py index e19bdc5..253343b 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -7,7 +7,8 @@ Plots, CDFRanks, Ranks, - CoverageFraction + CoverageFraction, + TARP ) @pytest.fixture @@ -15,7 +16,8 @@ def plot_config(config_factory): out_dir = "./temp_results/" metrics_settings={"use_progress_bar":False, "samples_per_inference":10, "percentiles":[95]} config = config_factory(out_dir=out_dir, metrics_settings=metrics_settings) - return config + Config(config) + def test_all_plot_catalogued(): '''Each metrics gets its own file, and each metric is included in the Metrics dictionary @@ -32,25 +34,25 @@ def test_all_defaults(plot_config, mock_model, mock_data): Ensures each metric has a default set of parameters and is included in the defaults list Ensures each test can initialize, regardless of the veracity of the output """ - Config(plot_config) for plot_name, plot_obj in Plots.items(): assert plot_name in Defaults['plots'] plot_obj(mock_model, mock_data, save=True, show=False) def test_plot_cdf(plot_config, mock_model, mock_data): - Config(plot_config) plot = CDFRanks(mock_model, mock_data, save=True, show=False) plot(**get_item("plots", "CDFRanks", raise_exception=False)) assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") def test_plot_ranks(plot_config, mock_model, mock_data): - Config(plot_config) plot = Ranks(mock_model, mock_data, save=True, show=False) plot(**get_item("plots", "Ranks", raise_exception=False)) assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") def test_plot_coverage(plot_config, mock_model, mock_data): - Config(plot_config) plot = CoverageFraction(mock_model, mock_data, save=True, show=False) plot(**get_item("plots", "CoverageFraction", raise_exception=False)) - assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") \ No newline at end of file + assert os.path.exists(f"{plot.out_path}/{plot.plot_name}") + +def test_plot_tarp(plot_config, mock_model, mock_data): + plot = TARP(mock_model, mock_data, save=True, show=False) + plot(**get_item("plots", "TARP", raise_exception=False)) \ No newline at end of file