From 8cb92e243203511c05607fb9488cd873ccf9211e Mon Sep 17 00:00:00 2001 From: Sofia Faltenbacher Date: Tue, 28 Jan 2025 16:00:43 +0100 Subject: [PATCH 1/2] feat(graph): are_nodes_d_separated_cpdag following the criterion by Perkovic, 2020 --- causy/graph.py | 121 ++++++++++- tests/test_graph.py | 321 ++++++++++++++++++++++++++++++ tests/test_path_classification.py | 52 ++--- tests/test_pc_e2e.py | 22 ++ 4 files changed, 485 insertions(+), 31 deletions(-) diff --git a/causy/graph.py b/causy/graph.py index b7d0696..901a861 100644 --- a/causy/graph.py +++ b/causy/graph.py @@ -342,9 +342,7 @@ def _is_a_collider_blocking(self, path, conditioning_set) -> bool: """ is_path_blocked = False for i in range(1, len(path) - 1): - if self.edge_of_type_exists( - path[i - 1], path[i], DirectedEdge() - ) and self.edge_of_type_exists(path[i + 1], path[i], DirectedEdge()): + if self._is_collider(path[i - 1], path[i], path[i + 1]): # if the node is a collider, check if the node or any of its descendants are in the conditioning set is_path_blocked = True for descendant in self.descendants_of_node(path[i]): @@ -373,14 +371,14 @@ def _is_a_non_collider_in_conditioning_set(self, path, conditioning_set) -> bool is_path_blocked = True return is_path_blocked - def are_nodes_d_separated( + def are_nodes_d_separated_dag( self, u: Union[Node, str], v: Union[Node, str], conditioning_set: List[Union[Node, str]], ) -> bool: """ - Check if nodes u and v are d-separated given a conditioning set. We check whether there is an open path, i.e. a path on which all colliders are in the conditioning set and all non-colliders are not in the conditioning set. If there is no open path, u and v are d-separated. + Check if nodes u and v are d-separated given a conditioning set in a DAG. We check whether there is an open path, i.e. a path on which all colliders are in the conditioning set and all non-colliders are not in the conditioning set. If there is no open path, u and v are d-separated. :param u: First node :param v: Second node @@ -422,6 +420,119 @@ def are_nodes_d_separated( return False return True + def are_nodes_d_separated_cpdag( + self, + u: Union[Node, str], + v: Union[Node, str], + conditioning_set: List[Union[Node, str]], + ) -> bool: + """ + Check if nodes u and v are d-separated given a conditioning set in a CPDAG following Perkovic, 2020. (Identifying causal effects in maximally oriented partially directed acyclic graphs) + + :param u: First node + :param v: Second node + :param conditioning_set: Set of nodes to condition on + :return: True if u and v are d-separated given conditioning_set, False otherwise + """ + + # Convert Node instances to their IDs + if isinstance(u, Node): + u = u.id + if isinstance(v, Node): + v = v.id + + # u and v may not be in the conditioning set, throw error + if u in conditioning_set or v in conditioning_set: + raise ValueError("Nodes u and v may not be in the conditioning set") + + # If there are no paths between u and v, they are d-separated + if list(self.all_paths_on_underlying_undirected_graph(u, v)) == []: + return True + + list_of_results_for_paths = [] + for path in self.all_paths_on_underlying_undirected_graph(u, v): + # If the path only has two nodes, it cannot be blocked and is open. Therefore, u and v are not d-separated + if len(path) == 2: + return False + is_path_blocked = False + if self._is_path_of_definite_status(path): + if self._is_definite_noncollider_in_conditioning_set( + path, conditioning_set + ): + is_path_blocked = True + if self._is_a_collider_blocking(path, conditioning_set): + is_path_blocked = True + list_of_results_for_paths.append(is_path_blocked) + if False in list_of_results_for_paths: + return False + return True + + def _is_collider(self, u: Node, v: Node, w: Node) -> bool: + """ + Check if a node is a collider in a triple u -> v <- w + :param u: Node u + :param v: Node v + :param w: Node w + :return: True if the node is a collider, False otherwise + """ + if self.edge_of_type_exists(u, v, DirectedEdge()) and self.edge_of_type_exists( + w, v, DirectedEdge() + ): + return True + return False + + def _is_definite_noncollider(self, u: Node, v: Node, w: Node) -> bool: + """ + Check if a node is a definite non-collider in a triple u - v - w + :param u: Node u + :param v: Node v + :param w: Node w + :return: True if the node is a definite non-collider, False otherwise + """ + # if there is an outgoing edge from the middle node, it is a definite non-collider + if self.edge_of_type_exists(v, u, DirectedEdge()) or self.edge_of_type_exists( + v, w, DirectedEdge() + ): + return True + # if there is no edge between the outer nodes and it is not a collider, it is a definite non-collider + if (not self.edge_exists(u, w)) and (not self._is_collider(u, v, w)): + return True + return False + + def _is_path_of_definite_status(self, path: List[Node]): + """ + Check if a path is of definite status, i.e. if every node is either the start or end node, a collider or a definite non-collider (i.e. has an outgoing edge or is an unshielded triple on the path) + :param path: The path to check + :return: True if the path is of definite status, False otherwise + """ + + for i in range(1, len(path) - 1): + # is node a collider? + if self._is_collider(path[i - 1], path[i], path[i + 1]): + continue + # is node a definite non-collider because it has an outgoing edge? + elif self._is_definite_noncollider(path[i - 1], path[i], path[i + 1]): + continue + return False + return True + + def _is_definite_noncollider_in_conditioning_set( + self, path, conditioning_set + ) -> bool: + """ + Check if a path is blocked by a definite non-collider which is in the conditioning set. + :param path: + :param conditioning_set: + :return: + """ + is_path_blocked = False + for i in range(1, len(path) - 1): + if path[i] in conditioning_set and self._is_definite_noncollider( + path[i - 1], path[i], path[i + 1] + ): + is_path_blocked = True + return is_path_blocked + def all_paths_on_underlying_undirected_graph( self, u: Union[Node, str], v: Union[Node, str], visited=None, path=None ) -> List[List[Node]]: diff --git a/tests/test_graph.py b/tests/test_graph.py index cdf36f0..709d310 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -332,3 +332,324 @@ def test_descendants_of_node_not_in_descendants_cycles(self): self.assertIn(node2, result) self.assertIn(node3, result) self.assertIn(node1, result) + + def test_is_path_of_definite_status_directed_edges_three_nodes(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + self.assertTrue(graph._is_path_of_definite_status([node1, node2, node3])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node1, node2])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2])) + self.assertTrue(graph._is_path_of_definite_status([node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node2, node3])) + + def test_is_path_of_definite_status_undirected_edges_three_nodes_true(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + self.assertTrue(graph._is_path_of_definite_status([node1, node2, node3])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node1, node2])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2])) + self.assertTrue(graph._is_path_of_definite_status([node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node2, node3])) + + def test_is_path_of_definite_status_undirected_edges_three_nodes_false(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + graph.add_edge(node1, node3, {"test": "test"}) + self.assertFalse(graph._is_path_of_definite_status([node1, node2, node3])) + self.assertFalse(graph._is_path_of_definite_status([node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node1, node2])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2])) + self.assertTrue(graph._is_path_of_definite_status([node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node2, node3])) + + def test_is_path_of_definite_status_undirected_edges_four_nodes(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + graph.add_edge(node3, node4, {"test": "test"}) + self.assertTrue(graph._is_path_of_definite_status([node1, node2, node3, node4])) + self.assertTrue(graph._is_path_of_definite_status([node4, node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node1, node2, node3])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node2, node3, node4])) + self.assertTrue(graph._is_path_of_definite_status([node4, node3, node2])) + self.assertTrue(graph._is_path_of_definite_status([node1, node2])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2])) + self.assertTrue(graph._is_path_of_definite_status([node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node2, node3])) + self.assertTrue(graph._is_path_of_definite_status([node3, node4])) + self.assertTrue(graph._is_path_of_definite_status([node4, node3])) + + def test_is_path_of_definite_status_undirected_and_directed_edges_three_nodes_true( + self, + ): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + self.assertTrue(graph._is_path_of_definite_status([node1, node2, node3])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node1, node2])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2])) + self.assertTrue(graph._is_path_of_definite_status([node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node2, node3])) + + def test_is_path_of_definite_status_undirected_and_directed_edges_three_nodes_false( + self, + ): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_edge(node1, node3, {"test": "test"}) + graph.add_directed_edge(node3, node2, {"test": "test"}) + self.assertFalse(graph._is_path_of_definite_status([node1, node2, node3])) + self.assertFalse(graph._is_path_of_definite_status([node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node1, node2])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2])) + self.assertTrue(graph._is_path_of_definite_status([node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node2, node3])) + + def test_is_path_of_definite_status_undirected_and_directed_edges_four_nodes_true( + self, + ): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_edge(node2, node3, {"test": "test"}) + graph.add_directed_edge(node2, node1, {"test": "test"}) + graph.add_directed_edge(node3, node4, {"test": "test"}) + self.assertTrue(graph._is_path_of_definite_status([node1, node2, node3, node4])) + self.assertTrue(graph._is_path_of_definite_status([node4, node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node1, node2, node3])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node2, node3, node4])) + self.assertTrue(graph._is_path_of_definite_status([node4, node3, node2])) + + def test_is_path_of_definite_status_undirected_and_directed_edges_four_nodes_false( + self, + ): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + graph.add_edge(node1, node3, {"test": "test"}) + graph.add_directed_edge(node3, node4, {"test": "test"}) + self.assertFalse( + graph._is_path_of_definite_status([node1, node2, node3, node4]) + ) + self.assertFalse( + graph._is_path_of_definite_status([node4, node3, node2, node1]) + ) + self.assertFalse(graph._is_path_of_definite_status([node1, node2, node3])) + self.assertFalse(graph._is_path_of_definite_status([node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node2, node3, node4])) + self.assertTrue(graph._is_path_of_definite_status([node4, node3, node2])) + + def test_is_path_of_definite_status_directed_edges_four_nodes(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node3, node2, {"test": "test"}) + graph.add_directed_edge(node3, node4, {"test": "test"}) + self.assertTrue(graph._is_path_of_definite_status([node1, node2, node3, node4])) + self.assertTrue(graph._is_path_of_definite_status([node4, node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node1, node2, node3])) + self.assertTrue(graph._is_path_of_definite_status([node3, node2, node1])) + self.assertTrue(graph._is_path_of_definite_status([node2, node3, node4])) + self.assertTrue(graph._is_path_of_definite_status([node4, node3, node2])) + + def test_is_definite_noncollider_in_conditioning_set_three_nodes_directed(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + self.assertTrue( + graph._is_definite_noncollider_in_conditioning_set( + [node1, node2, node3], [node2] + ) + ) + + def test_is_definite_noncollider_in_conditioning_set_three_nodes_directed_false( + self, + ): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + self.assertFalse( + graph._is_definite_noncollider_in_conditioning_set( + [node1, node2, node3], [] + ) + ) + + def test_is_definite_noncollider_in_conditioning_set_three_nodes_undirected(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + self.assertTrue( + graph._is_definite_noncollider_in_conditioning_set( + [node1, node2, node3], [node2] + ) + ) + + def test_is_definite_noncollider_in_conditioning_set_three_nodes_undirected_false( + self, + ): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + self.assertFalse( + graph._is_definite_noncollider_in_conditioning_set( + [node1, node2, node3], [] + ) + ) + + def test_is_definite_noncollider_in_conditioning_set_four_nodes_directed(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + graph.add_directed_edge(node3, node4, {"test": "test"}) + self.assertTrue( + graph._is_definite_noncollider_in_conditioning_set( + [node1, node2, node3, node4], [node2] + ) + ) + self.assertTrue( + graph._is_definite_noncollider_in_conditioning_set( + [node1, node2, node3, node4], [node3] + ) + ) + self.assertTrue( + graph._is_definite_noncollider_in_conditioning_set( + [node4, node3, node2, node1], [node2] + ) + ) + self.assertTrue( + graph._is_definite_noncollider_in_conditioning_set( + [node4, node3, node2, node1], [node3] + ) + ) + + def test_is_definite_noncollider_outgoing_edge(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + self.assertTrue(graph._is_definite_noncollider(node1, node2, node3)) + + def test_is_definite_noncollider_unshielded_triple(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + graph.add_edge(node3, node4, {"test": "test"}) + self.assertTrue(graph._is_definite_noncollider(node1, node2, node3)) + self.assertTrue(graph._is_definite_noncollider(node2, node3, node4)) + + def test_is_collider(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node3, node2, {"test": "test"}) + self.assertTrue(graph._is_collider(node1, node2, node3)) + + def test_are_nodes_d_separated_cpdag_three_nodes_directed(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node2, node3, {"test": "test"}) + self.assertTrue(graph.are_nodes_d_separated_cpdag(node1, node3, [node2])) + self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [])) + + def test_are_nodes_d_separated_cpdag_four_nodes_no_colliders(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_edge(node2, node3, {"test": "test"}) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node3, node4, {"test": "test"}) + self.assertTrue(graph.are_nodes_d_separated_cpdag(node1, node4, [node2, node3])) + self.assertTrue(graph.are_nodes_d_separated_cpdag(node1, node4, [node2])) + self.assertTrue(graph.are_nodes_d_separated_cpdag(node1, node4, [node3])) + self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node4, [])) + + def test_are_nodes_d_separated_four_nodes_with_colliders(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node3, node2, {"test": "test"}) + graph.add_directed_edge(node2, node4, {"test": "test"}) + self.assertFalse(graph.are_nodes_d_separated_dag(node1, node3, [node4])) + self.assertFalse(graph.are_nodes_d_separated_dag(node1, node3, [node2])) + self.assertTrue(graph.are_nodes_d_separated_dag(node1, node3, [])) + + def test_are_nodes_d_separated_cpdag_four_nodes_with_colliders(self): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + node4 = graph.add_node("test4", [1, 2, 3]) + graph.add_directed_edge(node1, node2, {"test": "test"}) + graph.add_directed_edge(node3, node2, {"test": "test"}) + graph.add_directed_edge(node2, node4, {"test": "test"}) + self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [node4])) + self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [node2])) + self.assertTrue(graph.are_nodes_d_separated_cpdag(node1, node3, [])) diff --git a/tests/test_path_classification.py b/tests/test_path_classification.py index abe5eee..500b3ad 100644 --- a/tests/test_path_classification.py +++ b/tests/test_path_classification.py @@ -163,7 +163,7 @@ def test_all_paths_on_underlying_undirected_graph_several_paths(self): [l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)], ) - def test_are_nodes_d_separated_open_path_mediated(self): + def test_are_nodes_d_separated_dag_open_path_mediated(self): new_graph_manager = GraphManager new_graph_manager.__bases__ = ( GraphBaseAccessMixin, @@ -175,9 +175,9 @@ def test_are_nodes_d_separated_open_path_mediated(self): node3 = graph.add_node("test3", [1, 2, 3]) graph.add_directed_edge(node1, node2, {"test": "test"}) graph.add_directed_edge(node2, node3, {"test": "test"}) - self.assertFalse(graph.are_nodes_d_separated(node1, node3, [])) + self.assertFalse(graph.are_nodes_d_separated_dag(node1, node3, [])) - def test_are_nodes_d_separated_blocked_path_mediated(self): + def test_are_nodes_d_separated_dag_blocked_path_mediated(self): new_graph_manager = GraphManager new_graph_manager.__bases__ = ( GraphBaseAccessMixin, @@ -189,9 +189,9 @@ def test_are_nodes_d_separated_blocked_path_mediated(self): node3 = graph.add_node("test3", [1, 2, 3]) graph.add_directed_edge(node1, node2, {"test": "test"}) graph.add_directed_edge(node2, node3, {"test": "test"}) - (self.assertTrue(graph.are_nodes_d_separated(node1, node3, [node2]))) + (self.assertTrue(graph.are_nodes_d_separated_dag(node1, node3, [node2]))) - def test_are_nodes_d_separated_blocked_path_mediated(self): + def test_are_nodes_d_separated_dag_blocked_path_mediated(self): rdnv = self.seeded_random.normalvariate model = IIDSampleGenerator( edges=[ @@ -212,14 +212,14 @@ def test_are_nodes_d_separated_blocked_path_mediated(self): self.assertEqual( True, - tst.graph.are_nodes_d_separated( + tst.graph.are_nodes_d_separated_dag( tst.graph.node_by_id("X"), tst.graph.node_by_id("Z"), [tst.graph.node_by_id("Y")], ), ) - def test_are_nodes_d_separated_open_path_cpdag(self): + def test_are_nodes_d_separated_dag_open_path_cpdag(self): new_graph_manager = GraphManager new_graph_manager.__bases__ = ( GraphBaseAccessMixin, @@ -232,9 +232,9 @@ def test_are_nodes_d_separated_open_path_cpdag(self): # TODO: check add edge again to see if this is the correct way to test cpdags with undirected edges graph.add_edge(node1, node2, {"test": "test"}) graph.add_edge(node3, node2, {"test": "test"}) - (self.assertTrue(graph.are_nodes_d_separated(node1, node3, [node2]))) + (self.assertTrue(graph.are_nodes_d_separated_dag(node1, node3, [node2]))) - def test_are_nodes_d_separated_open_path_confounder(self): + def test_are_nodes_d_separated_dag_open_path_confounder(self): new_graph_manager = GraphManager new_graph_manager.__bases__ = ( GraphBaseAccessMixin, @@ -246,9 +246,9 @@ def test_are_nodes_d_separated_open_path_confounder(self): node3 = graph.add_node("test3", [1, 2, 3]) graph.add_directed_edge(node2, node1, {"test": "test"}) graph.add_directed_edge(node2, node3, {"test": "test"}) - self.assertFalse(graph.are_nodes_d_separated(node1, node3, [])) + self.assertFalse(graph.are_nodes_d_separated_dag(node1, node3, [])) - def test_are_nodes_d_separated_by_conditioning_on_noncollider(self): + def test_are_nodes_d_separated_dag_by_conditioning_on_noncollider(self): new_graph_manager = GraphManager new_graph_manager.__bases__ = ( GraphBaseAccessMixin, @@ -260,9 +260,9 @@ def test_are_nodes_d_separated_by_conditioning_on_noncollider(self): node3 = graph.add_node("test3", [1, 2, 3]) graph.add_directed_edge(node1, node2, {"test": "test"}) graph.add_directed_edge(node2, node3, {"test": "test"}) - self.assertTrue(graph.are_nodes_d_separated(node1, node3, [node2])) + self.assertTrue(graph.are_nodes_d_separated_dag(node1, node3, [node2])) - def test_are_nodes_d_separated_by_a_collider(self): + def test_are_nodes_d_separated_dag_by_a_collider(self): new_graph_manager = GraphManager new_graph_manager.__bases__ = ( GraphBaseAccessMixin, @@ -274,9 +274,9 @@ def test_are_nodes_d_separated_by_a_collider(self): node3 = graph.add_node("test3", [1, 2, 3]) graph.add_directed_edge(node1, node2, {"test": "test"}) graph.add_directed_edge(node3, node2, {"test": "test"}) - self.assertTrue(graph.are_nodes_d_separated(node1, node3, [])) + self.assertTrue(graph.are_nodes_d_separated_dag(node1, node3, [])) - def test_are_nodes_d_separated_open_by_conditioning_on_collider(self): + def test_are_nodes_d_separated_dag_open_by_conditioning_on_collider(self): new_graph_manager = GraphManager new_graph_manager.__bases__ = ( GraphBaseAccessMixin, @@ -288,9 +288,9 @@ def test_are_nodes_d_separated_open_by_conditioning_on_collider(self): node3 = graph.add_node("test3", [1, 2, 3], "test3") graph.add_directed_edge(node1, node2, {"test": "test"}) graph.add_directed_edge(node3, node2, {"test": "test"}) - self.assertFalse(graph.are_nodes_d_separated(node1, node3, [node2])) + self.assertFalse(graph.are_nodes_d_separated_dag(node1, node3, [node2])) - def test_are_nodes_d_separated_two_paths_one_open_one_blocked(self): + def test_are_nodes_d_separated_dag_two_paths_one_open_one_blocked(self): new_graph_manager = GraphManager new_graph_manager.__bases__ = ( GraphBaseAccessMixin, @@ -303,9 +303,9 @@ def test_are_nodes_d_separated_two_paths_one_open_one_blocked(self): graph.add_directed_edge(node1, node2, {"test": "test"}) graph.add_directed_edge(node3, node2, {"test": "test"}) graph.add_directed_edge(node1, node3, {"test": "test"}) - self.assertFalse(graph.are_nodes_d_separated(node1, node3, [])) + self.assertFalse(graph.are_nodes_d_separated_dag(node1, node3, [])) - def test_are_nodes_d_separated_no_path_empty_conditioning_set(self): + def test_are_nodes_d_separated_dag_no_path_empty_conditioning_set(self): new_graph_manager = GraphManager new_graph_manager.__bases__ = ( GraphBaseAccessMixin, @@ -316,9 +316,9 @@ def test_are_nodes_d_separated_no_path_empty_conditioning_set(self): node2 = graph.add_node("test2", [1, 2, 3], "test2") node3 = graph.add_node("test3", [1, 2, 3], "test3") graph.add_directed_edge(node1, node2, {"test": "test"}) - self.assertTrue(graph.are_nodes_d_separated(node1, node3, [])) + self.assertTrue(graph.are_nodes_d_separated_dag(node1, node3, [])) - def test_are_nodes_d_separated_no_path_with_nonempty_conditioning_set(self): + def test_are_nodes_d_separated_dag_no_path_with_nonempty_conditioning_set(self): new_graph_manager = GraphManager new_graph_manager.__bases__ = ( GraphBaseAccessMixin, @@ -329,9 +329,9 @@ def test_are_nodes_d_separated_no_path_with_nonempty_conditioning_set(self): node2 = graph.add_node("test2", [1, 2, 3], "test2") node3 = graph.add_node("test3", [1, 2, 3], "test3") graph.add_directed_edge(node1, node2, {"test": "test"}) - self.assertTrue(graph.are_nodes_d_separated(node1, node3, [node2])) + self.assertTrue(graph.are_nodes_d_separated_dag(node1, node3, [node2])) - def test_are_nodes_d_separated_connected_by_conditioning_on_collider_with_more_descendants( + def test_are_nodes_d_separated_dag_connected_by_conditioning_on_collider_with_more_descendants( self, ): new_graph_manager = GraphManager @@ -349,9 +349,9 @@ def test_are_nodes_d_separated_connected_by_conditioning_on_collider_with_more_d graph.add_directed_edge(node3, node2, {"test": "test"}) graph.add_directed_edge(node2, node4, {"test": "test"}) graph.add_directed_edge(node2, node5, {"test": "test"}) - self.assertFalse(graph.are_nodes_d_separated(node1, node3, [node2])) + self.assertFalse(graph.are_nodes_d_separated_dag(node1, node3, [node2])) - def test_are_nodes_d_separated_connected_by_conditioning_on_descendant_of_collider( + def test_are_nodes_d_separated_dag_connected_by_conditioning_on_descendant_of_collider( self, ): new_graph_manager = GraphManager @@ -369,4 +369,4 @@ def test_are_nodes_d_separated_connected_by_conditioning_on_descendant_of_collid graph.add_directed_edge(node3, node2, {"test": "test"}) graph.add_directed_edge(node2, node4, {"test": "test"}) graph.add_directed_edge(node2, node5, {"test": "test"}) - self.assertFalse(graph.are_nodes_d_separated(node1, node3, [node4])) + self.assertFalse(graph.are_nodes_d_separated_dag(node1, node3, [node4])) diff --git a/tests/test_pc_e2e.py b/tests/test_pc_e2e.py index f223968..97f72a2 100644 --- a/tests/test_pc_e2e.py +++ b/tests/test_pc_e2e.py @@ -486,3 +486,25 @@ def test_track_triples_three_nodes_pc_unconditionally_independent(self): triples.append(action.data["triple"]) # TODO: find issue with tracking in partial correlation test in this setting pass + + def test_d_separation_on_output_of_pc(self): + rdnv = self.seeded_random.normalvariate + sample_generator = IIDSampleGenerator( + edges=[ + SampleEdge(NodeReference("X"), NodeReference("Y"), 5), + SampleEdge(NodeReference("Y"), NodeReference("Z"), 6), + ], + random=lambda: rdnv(0, 1), + ) + test_data, graph = sample_generator.generate(10000) + tst = PC() + tst.create_graph_from_data(test_data) + tst.create_all_possible_edges() + tst.execute_pipeline_steps() + + self.assertGraphStructureIsEqual(tst.graph, graph) + x = tst.graph.node_by_id("X") + y = tst.graph.node_by_id("Y") + z = tst.graph.node_by_id("Z") + self.assertEqual(tst.graph.are_nodes_d_separated_cpdag(x, z, []), False) + self.assertEqual(tst.graph.are_nodes_d_separated_cpdag(x, z, [y]), True) From 76da8a9da1724a3705b7c440fbc7abb85b148593 Mon Sep 17 00:00:00 2001 From: Sofia Faltenbacher Date: Tue, 28 Jan 2025 16:09:02 +0100 Subject: [PATCH 2/2] test(are_nodes_d_separated_cpdag): add one more edge case to tests --- tests/test_graph.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_graph.py b/tests/test_graph.py index 709d310..0122d0e 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -653,3 +653,16 @@ def test_are_nodes_d_separated_cpdag_four_nodes_with_colliders(self): self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [node4])) self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [node2])) self.assertTrue(graph.are_nodes_d_separated_cpdag(node1, node3, [])) + + def test_are_nodes_d_separated_cpdag_three_nodes_fully_connected_undirected_false( + self, + ): + graph = GraphManager() + node1 = graph.add_node("test1", [1, 2, 3]) + node2 = graph.add_node("test2", [1, 2, 3]) + node3 = graph.add_node("test3", [1, 2, 3]) + graph.add_edge(node1, node2, {"test": "test"}) + graph.add_edge(node2, node3, {"test": "test"}) + graph.add_edge(node1, node3, {"test": "test"}) + self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [])) + self.assertFalse(graph.are_nodes_d_separated_cpdag(node1, node3, [node2]))