diff --git a/src/depiction/calibration/models/linear_model.py b/src/depiction/calibration/models/linear_model.py index 14c4c46..dc2f819 100644 --- a/src/depiction/calibration/models/linear_model.py +++ b/src/depiction/calibration/models/linear_model.py @@ -46,14 +46,25 @@ def zero(cls) -> LinearModel: return cls([0, 0]) @classmethod - def fit_lsq(cls, x_arr: NDArray[float], y_arr: NDArray[float]) -> LinearModel: + def fit_lsq(cls, x_arr: NDArray[float], y_arr: NDArray[float], min_points: int | None = None) -> LinearModel: """Fits a linear model to the given data using least squares regression.""" + if min_points is not None and x_arr.size < min_points: + return cls.fit_constant(y_arr=y_arr) model = sklearn.linear_model.LinearRegression() model.fit(x_arr[:, np.newaxis], y_arr[:, np.newaxis]) return LinearModel(coef=[model.intercept_[0], model.coef_[0, 0]]) @classmethod - def fit_siegelslopes(cls, x_arr: NDArray[float], y_arr: NDArray[float]) -> LinearModel: + def fit_siegelslopes( + cls, x_arr: NDArray[float], y_arr: NDArray[float], min_points: int | None = None + ) -> LinearModel: """Fits a linear model to the given data using robust Siegel-Slopes regression.""" + if min_points is not None and x_arr.size < min_points: + return cls.fit_constant(y_arr) slope, intercept = scipy.stats.siegelslopes(y=y_arr, x=x_arr) return LinearModel(coef=[intercept, slope]) + + @classmethod + def fit_constant(cls, y_arr: NDArray[float]) -> LinearModel: + """Fits a constant model to the given data.""" + return LinearModel(coef=[np.mean(y_arr), 0]) diff --git a/tests/unit/calibration/models/test_linear_model.py b/tests/unit/calibration/models/test_linear_model.py index 83298fe..3809925 100644 --- a/tests/unit/calibration/models/test_linear_model.py +++ b/tests/unit/calibration/models/test_linear_model.py @@ -55,8 +55,20 @@ def test_zero() -> None: def test_fit_lsq() -> None: model = LinearModel.fit_lsq(np.array([1, 2, 3]), np.array([4, 5, 6])) - assert model.intercept == pytest.approx(3, abs=1e-7) assert model.slope == pytest.approx(1, abs=1e-7) + assert model.intercept == pytest.approx(3, abs=1e-7) + + +def test_fit_lsq_when_one_point() -> None: + model = LinearModel.fit_lsq(np.array([1]), np.array([2])) + assert model.slope == pytest.approx(0, 1e-10) + assert model.intercept == pytest.approx(2, 1e-10) + + +def test_fit_lsq_when_few_points() -> None: + result = LinearModel.fit_lsq(np.array([1, 2, 3]), np.array([2, 3, 4]), min_points=5) + assert result.slope == 0 + assert result.intercept == pytest.approx(3, abs=1e-10) def test_fit_linear_siegelslopes() -> None: @@ -66,5 +78,24 @@ def test_fit_linear_siegelslopes() -> None: np.testing.assert_array_almost_equal(np.array([0, 0.01]), model.coef, decimal=7) +def test_fit_linear_siegelslopes_when_one_point() -> None: + model = LinearModel.fit_siegelslopes(np.array([1]), np.array([2])) + assert model.slope == pytest.approx(0, abs=1e-10) + assert model.intercept == pytest.approx(2, abs=1e-10) + + +def test_fit_linear_siegelslopes_when_few_points(mocker) -> None: + mocker.patch("scipy.stats.siegelslopes", return_value=(0, 3)) + result = LinearModel.fit_siegelslopes(np.array([1, 2, 3]), np.array([2, 3, 4]), min_points=5) + assert result.slope == 0 + assert result.intercept == pytest.approx(3, abs=1e-10) + + +def test_fit_constant() -> None: + model = LinearModel.fit_constant(np.array([1, 2, 3])) + assert model.slope == 0 + assert model.intercept == pytest.approx(2, abs=1e-10) + + if __name__ == "__main__": pytest.main()