Skip to content

Commit

Permalink
fix(pearson_coefficient): ensure that it works the same as the python…
Browse files Browse the repository at this point in the history
… correlation function + tests
  • Loading branch information
LilithWittmann committed Oct 22, 2023
1 parent 53cd4a9 commit 3db5e7f
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
3 changes: 3 additions & 0 deletions causy/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def __init__(
if shuffle_combinations is not None:
self.shuffle_combinations = shuffle_combinations

if chunked is not None:
self.chunked = chunked

def serialize(self):
result = super().serialize()
result["params"]["shuffle_combinations"] = self.shuffle_combinations
Expand Down
4 changes: 2 additions & 2 deletions causy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,6 @@ def retrieve_edges(graph):

def pearson_correlation(x, y):
cov_xy = torch.mean((x - x.mean()) * (y - y.mean()))
std_x = x.std()
std_y = y.std()
std_x = x.std(unbiased=False)
std_y = y.std(unbiased=False)
return cov_xy / (std_x * std_y)
41 changes: 40 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,47 @@
import unittest
from statistics import correlation

from causy.independence_tests import CalculateCorrelations
from causy.utils import (
pearson_correlation,
serialize_module_name,
load_pipeline_artefact_by_definition,
load_pipeline_steps_by_definition,
)
import torch


class UtilsTestCase(unittest.TestCase):
pass
def test_pearson_correlation(self):
self.assertEqual(
pearson_correlation(
torch.tensor([1, 2, 3], dtype=torch.float64),
torch.tensor([1, 2, 3], dtype=torch.float64),
).item(),
1,
)
self.assertEqual(
pearson_correlation(
torch.tensor([1, 2, 3], dtype=torch.float64),
torch.tensor([3, 2, 1], dtype=torch.float64),
).item(),
-1,
)

def test_serialize_module_name(self):
self.assertEqual(serialize_module_name(self), "tests.test_utils.UtilsTestCase")

def test_load_pipeline_artefact_by_definition(self):
step = {"name": "causy.independence_tests.CalculateCorrelations"}
self.assertIsInstance(
load_pipeline_artefact_by_definition(step), CalculateCorrelations
)

def load_pipeline_steps_by_definition(self):
steps = [{"name": "causy.independence_tests.CalculateCorrelations"}]
self.assertIsInstance(
load_pipeline_steps_by_definition(steps)[0], CalculateCorrelations
)


if __name__ == "__main__":
Expand Down

0 comments on commit 3db5e7f

Please sign in to comment.