From d325ed32f7eaadfce37942899b1c3ce8ff36e79e Mon Sep 17 00:00:00 2001 From: Sofia Faltenbacher Date: Mon, 4 Nov 2024 14:15:27 +0100 Subject: [PATCH 1/2] feat(GraphBaseAccessMixin): are_nodes_d_separated --- .../constraint/independence_tests/common.py | 6 +- causy/graph.py | 130 ++++++++-- tests/test_path_classification.py | 233 ++++++++++++++++++ tests/test_underlying_path_retrieval.py | 93 ------- 4 files changed, 340 insertions(+), 122 deletions(-) create mode 100644 tests/test_path_classification.py delete mode 100644 tests/test_underlying_path_retrieval.py diff --git a/causy/causal_discovery/constraint/independence_tests/common.py b/causy/causal_discovery/constraint/independence_tests/common.py index c47312b..13e5f79 100644 --- a/causy/causal_discovery/constraint/independence_tests/common.py +++ b/causy/causal_discovery/constraint/independence_tests/common.py @@ -32,7 +32,7 @@ class CorrelationCoefficientTest( parallel: BoolParameter = False def process( - self, nodes: List[str], graph: BaseGraphInterface + self, nodes: List[str], graph: BaseGraphInterface ) -> Optional[TestResult]: """ Test if u and v are independent and delete edge in graph if they are. @@ -204,7 +204,9 @@ def process( helper = torch.mm(torch.sqrt(diagonal_matrix), inverse_cov_matrix) precision_matrix = torch.mm(helper, torch.sqrt(diagonal_matrix)) - par_corr = (-1 * precision_matrix[0][1]) / torch.sqrt(precision_matrix[0][0] * precision_matrix[1][1]) + par_corr = (-1 * precision_matrix[0][1]) / torch.sqrt( + precision_matrix[0][0] * precision_matrix[1][1] + ) sample_size = len(graph.nodes[nodes[0]].values) nb_of_control_vars = len(nodes) - 2 diff --git a/causy/graph.py b/causy/graph.py index 2011470..605662f 100644 --- a/causy/graph.py +++ b/causy/graph.py @@ -274,6 +274,109 @@ def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str]) -> bool return True return False + def edge_of_type_exists( + self, + u: Union[Node, str], + v: Union[Node, str], + edge_type: EdgeTypeInterface = DirectedEdge(), + ) -> bool: + """ + Check if an edge of a specific type exists between u and v. + :param u: node u + :param v: node v + :param edge_type: the type of the edge to check for + :return: True if an edge of this type exists, False otherwise + """ + + if isinstance(u, Node): + u = u.id + if isinstance(v, Node): + v = v.id + + if not self.directed_edge_exists(u, v): + return False + + if self.edges[u][v].edge_type != edge_type: + return False + + return True + + def are_nodes_d_separated( + self, + u: Union[Node, str], + v: Union[Node, str], + conditioning_set: List[Union[Node, str]], + ) -> bool: + """ + Check if nodes u and v are d-separated given a conditioning set. We check whether there is an open path, i.e. a path on which all colliders are in the conditioning set and all non-colliders are not in the conditioning set. If there is no open path, u and v are d-separated. + + :param u: First node + :param v: Second node + :param conditioning_set: Set of nodes to condition on + :return: True if u and v are d-separated given conditioning_set, False otherwise + """ + + # Convert Node instances to their IDs + if isinstance(u, Node): + u = u.id + if isinstance(v, Node): + v = v.id + + # u and v may not be in the conditioning set, throw error + if u in conditioning_set or v in conditioning_set: + raise ValueError("Nodes u and v may not be in the conditioning set") + + # check whether there is an open path on which all colliders are in the conditioning set and all non-colliders are not in the conditioning set + list_of_results_for_paths = [] + for path in self.all_paths_on_underlying_undirected_graph(u, v): + if len(path) == 2: + list_of_results_for_paths.append(False) + + for i in range(1, len(path) - 1): + is_path_blocked = False + + # paths are d-separated if a collider is not in the conditioning set + if path[i] not in conditioning_set: + if self.edge_of_type_exists( + path[i - 1].id, path[i].id, DirectedEdge() + ) and self.edge_of_type_exists( + path[i + 1].id, path[i].id, DirectedEdge() + ): + is_path_blocked = True + + # paths are d-separated if a non-collider is in the conditioning set + elif path[i] in conditioning_set: + if not ( + self.edge_of_type_exists( + path[i - 1].id, path[i].id, DirectedEdge() + ) + and self.edge_of_type_exists( + path[i + 1].id, path[i].id, DirectedEdge() + ) + ): + is_path_blocked = True + + list_of_results_for_paths.append(is_path_blocked) + + # if there is at least one open path, u and v are not d-separated + if False in list_of_results_for_paths: + return False + return True + + # Check for collider conditions (X -> Z <- Y) + for z in self.edges.get(x, []): + for z_target in self.edges.get(z, []): + if z_target == y and z not in conditioning_set: + # If z is not conditioned on and has descendants, + # we need to check those descendants too. + descendants = self.get_descendants(z) + if not any( + descendant in conditioning_set for descendant in descendants + ): + return False + + return True + def all_paths_on_underlying_undirected_graph( self, u: Union[Node, str], v: Union[Node, str], visited=None, path=None ) -> List[List[Node]]: @@ -307,33 +410,6 @@ def all_paths_on_underlying_undirected_graph( path.pop() visited.remove(u) - def edge_of_type_exists( - self, - u: Union[Node, str], - v: Union[Node, str], - edge_type: EdgeTypeInterface = DirectedEdge(), - ) -> bool: - """ - Check if an edge of a specific type exists between u and v. - :param u: node u - :param v: node v - :param edge_type: the type of the edge to check for - :return: True if an edge of this type exists, False otherwise - """ - - if isinstance(u, Node): - u = u.id - if isinstance(v, Node): - v = v.id - - if not self.edge_exists(u, v): - return False - - if self.edges[u][v].edge_type != edge_type: - return False - - return True - def retrieve_edges(self) -> List[Edge]: """ Retrieve all edges diff --git a/tests/test_path_classification.py b/tests/test_path_classification.py new file mode 100644 index 0000000..3294572 --- /dev/null +++ b/tests/test_path_classification.py @@ -0,0 +1,233 @@ +from causy.edge_types import DirectedEdge +from causy.graph import GraphBaseAccessMixin, GraphManager +from tests.utils import CausyTestCase + + +class GraphTestCase(CausyTestCase): + def test_all_paths_on_underlying_undirected_graph(self): + new_graph_manager = GraphManager + new_graph_manager.__bases__ = ( + GraphBaseAccessMixin, + DirectedEdge.GraphAccessMixin, + ) + graph = new_graph_manager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + self.assertIn( + [node1, node2, node3], + [ + x + for x in [ + l + for l in graph.all_paths_on_underlying_undirected_graph( + node1, node3 + ) + ] + ], + ) + self.assertNotIn( + [node1, node2], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + self.assertNotIn( + [node2, node3], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + self.assertNotIn( + [node1, node3], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + + def test_all_paths_on_underlying_undirected_graph_2(self): + new_graph_manager = GraphManager + new_graph_manager.__bases__ = ( + GraphBaseAccessMixin, + DirectedEdge.GraphAccessMixin, + ) + graph = new_graph_manager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node3, node2, {"test": "test"}) + graph.add_directed_edge(node2, node1, {"test": "test"}) + self.assertIn( + [node1, node2, node3], + [ + x + for x in [ + l + for l in graph.all_paths_on_underlying_undirected_graph( + node1, node3 + ) + ] + ], + ) + self.assertNotIn( + [node1, node2], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + self.assertNotIn( + [node2, node3], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + self.assertNotIn( + [node1, node3], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + + def test_all_paths_on_underlying_undirected_graph_collider_path(self): + new_graph_manager = GraphManager + new_graph_manager.__bases__ = ( + GraphBaseAccessMixin, + DirectedEdge.GraphAccessMixin, + ) + graph = new_graph_manager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node3, node2, {"test": "test"}) + graph.add_directed_edge(node1, node2, {"test": "test"}) + self.assertIn( + [node1, node2, node3], + [ + x + for x in [ + l + for l in graph.all_paths_on_underlying_undirected_graph( + node1, node3 + ) + ] + ], + ) + self.assertNotIn( + [node1, node2], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + self.assertNotIn( + [node2, node3], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + self.assertNotIn( + [node1, node3], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + + def test_all_paths_on_underlying_undirected_graph_several_paths(self): + new_graph_manager = GraphManager + new_graph_manager.__bases__ = ( + GraphBaseAccessMixin, + DirectedEdge.GraphAccessMixin, + ) + graph = new_graph_manager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + graph.add_directed_edge(node1, node3, {"test": "test"}) + self.assertIn( + [node1, node2, node3], + [ + x + for x in [ + l + for l in graph.all_paths_on_underlying_undirected_graph( + node1, node3 + ) + ] + ], + ) + self.assertIn( + [node1, node3], + [ + x + for x in [ + l + for l in graph.all_paths_on_underlying_undirected_graph( + node1, node3 + ) + ] + ], + ) + self.assertNotIn( + [node1, node2], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + self.assertNotIn( + [node2, node3], + [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], + ) + + def test_are_nodes_d_separated(self): + new_graph_manager = GraphManager + new_graph_manager.__bases__ = ( + GraphBaseAccessMixin, + DirectedEdge.GraphAccessMixin, + ) + graph = new_graph_manager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + self.assertFalse(graph.are_nodes_d_separated(node1, node3, [])) + + def test_are_nodes_d_separated_2(self): + new_graph_manager = GraphManager + new_graph_manager.__bases__ = ( + GraphBaseAccessMixin, + DirectedEdge.GraphAccessMixin, + ) + graph = new_graph_manager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + self.assertTrue(graph.are_nodes_d_separated(node1, node3, [node2])) + + def test_are_nodes_d_separated_3(self): + new_graph_manager = GraphManager + new_graph_manager.__bases__ = ( + GraphBaseAccessMixin, + DirectedEdge.GraphAccessMixin, + ) + graph = new_graph_manager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node3, node2, {"test": "test"}) + self.assertTrue(graph.are_nodes_d_separated(node1, node3, [])) + + def test_are_nodes_d_separated_4(self): + new_graph_manager = GraphManager + new_graph_manager.__bases__ = ( + GraphBaseAccessMixin, + DirectedEdge.GraphAccessMixin, + ) + graph = new_graph_manager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node3, node2, {"test": "test"}) + self.assertFalse(graph.are_nodes_d_separated(node1, node3, [node2])) + + def test_are_nodes_d_separated_3(self): + new_graph_manager = GraphManager + new_graph_manager.__bases__ = ( + GraphBaseAccessMixin, + DirectedEdge.GraphAccessMixin, + ) + graph = new_graph_manager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node3, node2, {"test": "test"}) + graph.add_directed_edge(node1, node3, {"test": "test"}) + self.assertFalse(graph.are_nodes_d_separated(node1, node3, [])) diff --git a/tests/test_underlying_path_retrieval.py b/tests/test_underlying_path_retrieval.py deleted file mode 100644 index a468b16..0000000 --- a/tests/test_underlying_path_retrieval.py +++ /dev/null @@ -1,93 +0,0 @@ -from causy.edge_types import DirectedEdge -from causy.graph import GraphBaseAccessMixin, GraphManager -from tests.utils import CausyTestCase - - -class GraphTestCase(CausyTestCase): - def test_all_paths_on_underlying_undirected_graph(self): - new_graph_manager = GraphManager - new_graph_manager.__bases__ = ( - GraphBaseAccessMixin, - DirectedEdge.GraphAccessMixin, - ) - graph = new_graph_manager() - node1 = graph.add_node("test1", [1, 2, 3]) - node2 = graph.add_node("test2", [1, 2, 3]) - node3 = graph.add_node("test3", [1, 2, 3]) - graph.add_directed_edge(node1, node2, {"test": "test"}) - graph.add_directed_edge(node2, node3, {"test": "test"}) - self.assertIn( - [node1, node2, node3], - [ - x - for x in [ - l - for l in graph.all_paths_on_underlying_undirected_graph( - node1, node3 - ) - ] - ], - ) - - def test_all_paths_on_underlying_undirected_graph_2(self): - new_graph_manager = GraphManager - new_graph_manager.__bases__ = ( - GraphBaseAccessMixin, - DirectedEdge.GraphAccessMixin, - ) - graph = new_graph_manager() - node1 = graph.add_node("test1", [1, 2, 3]) - node2 = graph.add_node("test2", [1, 2, 3]) - node3 = graph.add_node("test3", [1, 2, 3]) - graph.add_directed_edge(node3, node2, {"test": "test"}) - graph.add_directed_edge(node2, node1, {"test": "test"}) - self.assertIn( - [node1, node2, node3], - [ - x - for x in [ - l - for l in graph.all_paths_on_underlying_undirected_graph( - node1, node3 - ) - ] - ], - ) - - def test_all_paths_on_underlying_undirected_graph_several_paths(self): - new_graph_manager = GraphManager - new_graph_manager.__bases__ = ( - GraphBaseAccessMixin, - DirectedEdge.GraphAccessMixin, - ) - graph = new_graph_manager() - node1 = graph.add_node("test1", [1, 2, 3]) - node2 = graph.add_node("test2", [1, 2, 3]) - node3 = graph.add_node("test3", [1, 2, 3]) - graph.add_directed_edge(node1, node2, {"test": "test"}) - graph.add_directed_edge(node2, node3, {"test": "test"}) - graph.add_directed_edge(node1, node3, {"test": "test"}) - self.assertIn( - [node1, node2, node3], - [ - x - for x in [ - l - for l in graph.all_paths_on_underlying_undirected_graph( - node1, node3 - ) - ] - ], - ) - self.assertIn( - [node1, node3], - [ - x - for x in [ - l - for l in graph.all_paths_on_underlying_undirected_graph( - node1, node3 - ) - ] - ], - ) From edc19cfe6b930aaf4ed8714356ff15c2e2a117dc Mon Sep 17 00:00:00 2001 From: Sofia Faltenbacher Date: Mon, 4 Nov 2024 14:20:14 +0100 Subject: [PATCH 2/2] fix(GraphBaseAccessMixin): are_nodes_d_separated --- causy/graph.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/causy/graph.py b/causy/graph.py index 605662f..0ff7fc2 100644 --- a/causy/graph.py +++ b/causy/graph.py @@ -363,20 +363,6 @@ def are_nodes_d_separated( return False return True - # Check for collider conditions (X -> Z <- Y) - for z in self.edges.get(x, []): - for z_target in self.edges.get(z, []): - if z_target == y and z not in conditioning_set: - # If z is not conditioned on and has descendants, - # we need to check those descendants too. - descendants = self.get_descendants(z) - if not any( - descendant in conditioning_set for descendant in descendants - ): - return False - - return True - def all_paths_on_underlying_undirected_graph( self, u: Union[Node, str], v: Union[Node, str], visited=None, path=None ) -> List[List[Node]]: