diff --git a/tests/neural_nets_factory.py b/tests/neural_nets_factory.py deleted file mode 100644 index af8c4d4a1..000000000 --- a/tests/neural_nets_factory.py +++ /dev/null @@ -1,44 +0,0 @@ -import pytest - -from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn - -models_to_test = [ - "mdn", - "made", - "maf", - "maf_rqs", - "nsf", - "mnle", - "zuko_bpf", - "zuko_gf", - "zuko_maf", - "zuko_naf", - "zuko_ncsf", - "zuko_nice", - "zuko_nsf", - "zuko_sospf", - "zuko_unaf", -] - - -@pytest.mark.parametrize( - "model", ["linear", "mlp", "resnet"], ids=["linear", "mlp", "resnet"] -) -def test_deprecated_import_classifier_nn(model: str): - with pytest.warns(DeprecationWarning): - build_fcn = classifier_nn(model) - assert callable(build_fcn) - - -@pytest.mark.parametrize("model", models_to_test, ids=models_to_test) -def test_deprecated_import_likelihood_nn(model: str): - with pytest.warns(DeprecationWarning): - build_fcn = likelihood_nn(model) - assert callable(build_fcn) - - -@pytest.mark.parametrize("model", models_to_test, ids=models_to_test) -def test_deprecated_import_posterior_nn(model: str): - with pytest.warns(DeprecationWarning): - build_fcn = posterior_nn(model) - assert callable(build_fcn)