Skip to content

Commit

Permalink
functionality to fall back to a constant model (useful for calibration)
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jun 28, 2024
1 parent 3777b76 commit b33dce1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
15 changes: 13 additions & 2 deletions src/depiction/calibration/models/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,25 @@ def zero(cls) -> LinearModel:
return cls([0, 0])

@classmethod
def fit_lsq(cls, x_arr: NDArray[float], y_arr: NDArray[float]) -> LinearModel:
def fit_lsq(cls, x_arr: NDArray[float], y_arr: NDArray[float], min_points: int | None = None) -> LinearModel:
"""Fits a linear model to the given data using least squares regression."""
if min_points is not None and x_arr.size < min_points:
return cls.fit_constant(y_arr=y_arr)
model = sklearn.linear_model.LinearRegression()
model.fit(x_arr[:, np.newaxis], y_arr[:, np.newaxis])
return LinearModel(coef=[model.intercept_[0], model.coef_[0, 0]])

@classmethod
def fit_siegelslopes(cls, x_arr: NDArray[float], y_arr: NDArray[float]) -> LinearModel:
def fit_siegelslopes(
cls, x_arr: NDArray[float], y_arr: NDArray[float], min_points: int | None = None
) -> LinearModel:
"""Fits a linear model to the given data using robust Siegel-Slopes regression."""
if min_points is not None and x_arr.size < min_points:
return cls.fit_constant(y_arr)
slope, intercept = scipy.stats.siegelslopes(y=y_arr, x=x_arr)
return LinearModel(coef=[intercept, slope])

@classmethod
def fit_constant(cls, y_arr: NDArray[float]) -> LinearModel:
"""Fits a constant model to the given data."""
return LinearModel(coef=[np.mean(y_arr), 0])
33 changes: 32 additions & 1 deletion tests/unit/calibration/models/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,20 @@ def test_zero() -> None:

def test_fit_lsq() -> None:
model = LinearModel.fit_lsq(np.array([1, 2, 3]), np.array([4, 5, 6]))
assert model.intercept == pytest.approx(3, abs=1e-7)
assert model.slope == pytest.approx(1, abs=1e-7)
assert model.intercept == pytest.approx(3, abs=1e-7)


def test_fit_lsq_when_one_point() -> None:
model = LinearModel.fit_lsq(np.array([1]), np.array([2]))
assert model.slope == pytest.approx(0, 1e-10)
assert model.intercept == pytest.approx(2, 1e-10)


def test_fit_lsq_when_few_points() -> None:
result = LinearModel.fit_lsq(np.array([1, 2, 3]), np.array([2, 3, 4]), min_points=5)
assert result.slope == 0
assert result.intercept == pytest.approx(3, abs=1e-10)


def test_fit_linear_siegelslopes() -> None:
Expand All @@ -66,5 +78,24 @@ def test_fit_linear_siegelslopes() -> None:
np.testing.assert_array_almost_equal(np.array([0, 0.01]), model.coef, decimal=7)


def test_fit_linear_siegelslopes_when_one_point() -> None:
model = LinearModel.fit_siegelslopes(np.array([1]), np.array([2]))
assert model.slope == pytest.approx(0, abs=1e-10)
assert model.intercept == pytest.approx(2, abs=1e-10)


def test_fit_linear_siegelslopes_when_few_points(mocker) -> None:
mocker.patch("scipy.stats.siegelslopes", return_value=(0, 3))
result = LinearModel.fit_siegelslopes(np.array([1, 2, 3]), np.array([2, 3, 4]), min_points=5)
assert result.slope == 0
assert result.intercept == pytest.approx(3, abs=1e-10)


def test_fit_constant() -> None:
model = LinearModel.fit_constant(np.array([1, 2, 3]))
assert model.slope == 0
assert model.intercept == pytest.approx(2, abs=1e-10)


if __name__ == "__main__":
pytest.main()

0 comments on commit b33dce1

Please sign in to comment.