Skip to content

Commit

Permalink
Merge pull request #54 from causy-dev/simplify-requirements
Browse files Browse the repository at this point in the history
fix(packaging): simplify dependencies + error handling
  • Loading branch information
LilithWittmann authored Aug 6, 2024
2 parents 480cb45 + e31e8a8 commit ebd66d8
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 773 deletions.
37 changes: 26 additions & 11 deletions causy/causal_discovery/constraint/independence_tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,17 @@ def process(

if not set(nodes[2:]).issubset(set([on for on in list(other_neighbours)])):
return
inverse_cov_matrix = torch.inverse(
torch.cov(torch.stack([graph.nodes[node].values for node in nodes]))
cov_matrix = torch.cov(
torch.stack([graph.nodes[node].values for node in nodes])
)
# check if the covariance matrix is ill-conditioned
if torch.det(cov_matrix) == 0:
logger.warning(
"The covariance matrix is ill-conditioned. The precision matrix is not reliable."
)
return

inverse_cov_matrix = torch.inverse(cov_matrix)

n = inverse_cov_matrix.size(0)
diagonal = torch.diag(inverse_cov_matrix)
Expand All @@ -186,15 +194,22 @@ def process(
sample_size = len(graph.nodes[nodes[0]].values)
nb_of_control_vars = len(nodes) - 2

t, critical_t = get_t_and_critical_t(
sample_size,
nb_of_control_vars,
(
(-1 * precision_matrix[0][1])
/ torch.sqrt(precision_matrix[0][0] * precision_matrix[1][1])
).item(),
self.threshold,
)
# prevent math domain error
try:
t, critical_t = get_t_and_critical_t(
sample_size,
nb_of_control_vars,
(
(-1 * precision_matrix[0][1])
/ torch.sqrt(precision_matrix[0][0] * precision_matrix[1][1])
).item(),
self.threshold,
)
except ValueError:
logger.warning(
"Math domain error. The covariance matrix is ill-conditioned. The precision matrix is not reliable."
)
return

if abs(t) < critical_t:
logger.debug(
Expand Down
Loading

0 comments on commit ebd66d8

Please sign in to comment.