Skip to content

Commit

Permalink
[FIX] test_term overflow new version bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zihuaihuai committed Sep 26, 2024
1 parent dad2c29 commit 4a3ea80
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python_unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.9]
python-version: [3.9, 3.10, 3.11, 3.12]
os: [ubuntu-latest, windows-latest, macos-latest]

runs-on: ${{ matrix.os }}
Expand Down
12 changes: 8 additions & 4 deletions brainstat/stats/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,16 @@ def __sub__(self, t: Union[ArrayLike, "FixedEffect"]) -> "FixedEffect":
return self
df /= df.abs().sum(0)
df.index = self.m.index

m = self.m / self.m.abs().sum(0)
merged = m.T.merge(df.T, how="outer", indicator=True)
mask = (merged._merge.values == "left_only")[: self.m.shape[1]]

m_normalized_T = m.T
df_normalized_T = df.T

merged = m_normalized_T.merge(df_normalized_T, how="outer", indicator=True)
merged_index = m_normalized_T.iloc[:, [1]].reset_index().merge(merged, how="left")
mask = merged_index.loc[merged_index['_merge'] == 'left_only','index']
return FixedEffect(
self.m[self.m.columns[mask]], add_intercept=False, _check_categorical=False
self.m[mask], add_intercept=False, _check_categorical=False
)

def _mul(
Expand Down
23 changes: 19 additions & 4 deletions brainstat/tests/test_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ def test_fixed_init():

def test_fixed_overload():
"""Tests the overloads of the FixedEffect class."""
np.random.seed(2)
random_data = np.random.random_sample((10, 3))
fix01 = FixedEffect(random_data[:, :2], ["x0", "x1"], add_intercept=False)
fix012 = FixedEffect(random_data, ["x0", "x1", "x2"], add_intercept=False)
fix12 = FixedEffect(random_data[:, 1:], ["x2", "x3"], add_intercept=False)
fix2 = FixedEffect(random_data[:, 2], ["x2"], add_intercept=False)
fixi0 = FixedEffect(random_data[:, 0], ["x0"], add_intercept=True)
Expand All @@ -46,11 +48,24 @@ def test_fixed_overload():
expected = np.concatenate((np.ones((10, 1)), random_data[:, 0:2]), axis=1)
assert np.array_equal(fix_add_intercept.m, expected)

# fix_sub = fix01 - fix12
# assert np.array_equal(fix_sub.m.to_numpy(), random_data[:, 0][:, None])
fix_sub = fix01 - fix12
np.testing.assert_allclose(
fix_sub.m,
random_data[:, 0][:, None],
atol=1e-8,
rtol=1e-5,
)

fix_sub2 = fix012 - fix2
np.testing.assert_allclose(
fix_sub2.m,
random_data[:, :2],
atol=1e-8,
rtol=1e-5,
)

# fix_mul = fix01 * fix2
# assert np.array_equal(fix_mul.m.to_numpy(), random_data[:, :2] * random_data[:, 2][:, None])
fix_mul = fix01 * fix2
assert np.array_equal(fix_mul.m, random_data[:, :2] * random_data[:, 2][:, None])


def test_mixed_init():
Expand Down

0 comments on commit 4a3ea80

Please sign in to comment.