Skip to content

Commit

Permalink
changing import statments to exclude src
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Feb 1, 2024
1 parent 3703b57 commit de99cc0
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
[tool.poetry]
name = "deeptemplate-science"
name = "DeepDiagnostics"
packages = [{include = "*", from="src"}]
version = "0.1.0"
description = "a template for a science motivated project, focus on ease of reproducablity"
authors = ["voetberg <[email protected]>"]
description = "a package for diagnosing posterior quality from inference methods"
authors = ["Becky Nevin <[email protected]>"]
license = "MIT"

[tool.poetry.dependencies]
Expand Down
4 changes: 2 additions & 2 deletions src/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Includes utilities for posterior diagnostics as well as some
inference functions.
"""
from src.scripts.io import ModelLoader
from scripts.io import ModelLoader

import argparse
from sbi.analysis import run_sbc, sbc_rank_plot, check_sbc, pairplot
Expand Down Expand Up @@ -62,7 +62,7 @@ def improved_corner_plot(self, posterior):
"""


class Diagnose_on_the_fly_data_generation:
class Diagnose_generative:
def posterior_predictive(self,
theta_true,
x_true,
Expand Down
26 changes: 16 additions & 10 deletions tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

# flake8: noqa
#sys.path.append("..")
from src.scripts.evaluate import Diagnose, InferenceModel
print(sys.path)
from scripts.evaluate import Diagnose_static, Diagnose_generative
from scripts.io import ModelLoader
#from src.scripts import evaluate


Expand All @@ -21,16 +23,20 @@


@pytest.fixture
def diagnose_instance():
return Diagnose()
def diagnose_static_instance():
return Diagnose_static()

@pytest.fixture
def diagnose_generative_instance():
return Diagnose_generative()


@pytest.fixture
def inference_instance():
inference_model = InferenceModel()
modelloader = ModelLoader()
path = "savedmodels/sbi/"
model_name = "sbi_linear"
posterior = inference_model.load_model_pkl(path, model_name)
model_name = "sbi_linear_from_data"
posterior = modelloader.load_model_pkl(path, model_name)
return posterior


Expand Down Expand Up @@ -68,7 +74,7 @@ def simulator(thetas): # , percent_errors):
return torch.Tensor(y.T)


def test_generate_sbc_samples(diagnose_instance, inference_instance):
def test_generate_sbc_samples(diagnose_generative_instance, inference_instance):
# Mock data
low_bounds = torch.tensor([0, -10])
high_bounds = torch.tensor([10, 10])
Expand All @@ -80,14 +86,14 @@ def test_generate_sbc_samples(diagnose_instance, inference_instance):
num_posterior_samples = 1000

# Generate SBC samples
thetas, ys, ranks, dap_samples = diagnose_instance.generate_sbc_samples(
thetas, ys, ranks, dap_samples = diagnose_generative_instance.generate_sbc_samples(
prior, posterior, simulator_test, num_sbc_runs, num_posterior_samples
)

# Add assertions based on the expected behavior of the method


def test_run_all_sbc(diagnose_instance, inference_instance):
def test_run_all_sbc(diagnose_generative_instance, inference_instance):
labels_list = ["$m$", "$b$"]
colorlist = ["#9C92A3", "#0F5257"]
low_bounds = torch.tensor([0, -10])
Expand All @@ -99,7 +105,7 @@ def test_run_all_sbc(diagnose_instance, inference_instance):

save_path = "plots/"

diagnose_instance.run_all_sbc(
diagnose_generative_instance.run_all_sbc(
prior,
posterior,
simulator_test,
Expand Down

0 comments on commit de99cc0

Please sign in to comment.