Skip to content

Commit

Permalink
remove pytest from gaussian.py
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxHalford committed Oct 2, 2023
1 parent 487341b commit 2df17f1
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 27 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,6 @@ benchmarks/.asv

# Cargo file
Cargo.lock

# WASM
/*.html
27 changes: 0 additions & 27 deletions river/proba/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np
import pandas as pd
import pytest
from scipy.stats import multivariate_normal

from river import covariance, stats
Expand Down Expand Up @@ -313,29 +312,3 @@ def sample(self) -> dict[str, float]:
@property
def mode(self) -> dict:
return self.mu


@pytest.mark.parametrize(
"p",
[
pytest.param(
p,
id=f"{p=}",
)
for p in [1, 3, 5]
],
)
def test_univariate_multivariate_consistency(p):
X = pd.DataFrame(np.random.random((30, p)), columns=range(p))

multi = MultivariateGaussian()
single = {c: Gaussian() for c in X.columns}

for x in X.to_dict(orient="records"):
multi = multi.update(x)
for c, s in single.items():
s.update(x[c])

for c in X.columns:
assert math.isclose(multi.mu[c], single[c].mu)
assert math.isclose(multi.sigma[c][c], single[c].sigma)
31 changes: 31 additions & 0 deletions river/proba/test_gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
import pandas as pd
import numpy as np
import math
from river import proba


@pytest.mark.parametrize(
"p",
[
pytest.param(
p,
id=f"{p=}",
)
for p in [1, 3, 5]
],
)
def test_univariate_multivariate_consistency(p):
X = pd.DataFrame(np.random.random((30, p)), columns=range(p))

multi = proba.MultivariateGaussian()
single = {c: proba.Gaussian() for c in X.columns}

for x in X.to_dict(orient="records"):
multi = multi.update(x)
for c, s in single.items():
s.update(x[c])

for c in X.columns:
assert math.isclose(multi.mu[c], single[c].mu)
assert math.isclose(multi.sigma[c][c], single[c].sigma)

0 comments on commit 2df17f1

Please sign in to comment.