diff --git a/tests/estimation/test_msm_weighting.py b/tests/estimation/test_msm_weighting.py index c15375a59..c87fd7a56 100644 --- a/tests/estimation/test_msm_weighting.py +++ b/tests/estimation/test_msm_weighting.py @@ -22,7 +22,7 @@ def expected_values(): cov_np = np.diag([1, 2, 3]) cov_pd = pd.DataFrame(cov_np) -test_cases = itertools.product([cov_np, cov_pd], ["diagonal", "optimal"]) +test_cases = itertools.product([cov_np, cov_pd], ["diagonal", "optimal", "identity"]) @pytest.mark.parametrize("moments_cov, method", test_cases) @@ -38,7 +38,11 @@ def test_get_weighting_matrix(moments_cov, method): assert calculated.columns.equals(moments_cov.columns) calculated = calculated.to_numpy() - expected = np.diag(1 / np.array([1, 2, 3])) + if method == "identity": + expected = np.identity(cov_np.shape[0]) + else: + expected = np.diag(1 / np.array([1, 2, 3])) + aaae(calculated, expected)