From b0af57d7c329e663dcd0046be94d73bd6e909660 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Fri, 19 Apr 2024 15:39:37 +0200 Subject: [PATCH] test: cover optional dependencies (#17) --- src/conformal_tights/_darts_forecaster.py | 4 +-- tests/test_optional_dependencies.py | 30 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 tests/test_optional_dependencies.py diff --git a/src/conformal_tights/_darts_forecaster.py b/src/conformal_tights/_darts_forecaster.py index a997bf7..afe2ac0 100644 --- a/src/conformal_tights/_darts_forecaster.py +++ b/src/conformal_tights/_darts_forecaster.py @@ -35,8 +35,8 @@ _LikelihoodMixin, ) except ImportError: - FUTURE_LAGS_TYPE = int - LAGS_TYPE = list[int] + FUTURE_LAGS_TYPE = tuple[int, int] | list[int] | dict[str, tuple[int, int] | list[int]] + LAGS_TYPE = int | list[int] | dict[str, int | list[int]] class TimeSeries: ... diff --git a/tests/test_optional_dependencies.py b/tests/test_optional_dependencies.py new file mode 100644 index 0000000..616a419 --- /dev/null +++ b/tests/test_optional_dependencies.py @@ -0,0 +1,30 @@ +"""Test this package's optional dependencies.""" + +import sys +from unittest.mock import patch + +import pytest + + +@pytest.mark.parametrize("optional_dependency", ["darts", "pandas"]) +def test_optional_dependencies(optional_dependency: str) -> None: + """Test that we get an expected error when an optional dependency are not available.""" + # Prevent the optional dependency from being loaded. + with patch.dict("sys.modules", {optional_dependency: None}): + # Unload Conformal Tights. + mods_to_unload = [mod for mod in sys.modules if mod.startswith("conformal_tights")] + for mod in mods_to_unload: + del sys.modules[mod] + + # Reload Conformal Tights now that the selected optional dependency is not available. + from conformal_tights import ConformalCoherentQuantileRegressor, DartsForecaster + + # Test that we raise the appropriate error. + conformal_predictor = ConformalCoherentQuantileRegressor() + with pytest.raises(ImportError, match=f".*install.*{optional_dependency}.*"): + _ = DartsForecaster(model=conformal_predictor) + + # Unload Conformal Tights again. + mods_to_unload = [mod for mod in sys.modules if mod.startswith("conformal_tights")] + for mod in mods_to_unload: + del sys.modules[mod]