Skip to content

Commit

Permalink
Merge pull request #72 from causy-dev/add_d_separation_for_cpdags
Browse files Browse the repository at this point in the history
Add d separation for cpdags
  • Loading branch information
this-is-sofia authored Jan 28, 2025
2 parents 2dfab11 + 76da8a9 commit 743570d
Show file tree
Hide file tree
Showing 4 changed files with 498 additions and 31 deletions.
121 changes: 116 additions & 5 deletions causy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,7 @@ def _is_a_collider_blocking(self, path, conditioning_set) -> bool:
"""
is_path_blocked = False
for i in range(1, len(path) - 1):
if self.edge_of_type_exists(
path[i - 1], path[i], DirectedEdge()
) and self.edge_of_type_exists(path[i + 1], path[i], DirectedEdge()):
if self._is_collider(path[i - 1], path[i], path[i + 1]):
# if the node is a collider, check if the node or any of its descendants are in the conditioning set
is_path_blocked = True
for descendant in self.descendants_of_node(path[i]):
Expand Down Expand Up @@ -373,14 +371,14 @@ def _is_a_non_collider_in_conditioning_set(self, path, conditioning_set) -> bool
is_path_blocked = True
return is_path_blocked

def are_nodes_d_separated(
def are_nodes_d_separated_dag(
self,
u: Union[Node, str],
v: Union[Node, str],
conditioning_set: List[Union[Node, str]],
) -> bool:
"""
Check if nodes u and v are d-separated given a conditioning set. We check whether there is an open path, i.e. a path on which all colliders are in the conditioning set and all non-colliders are not in the conditioning set. If there is no open path, u and v are d-separated.
Check if nodes u and v are d-separated given a conditioning set in a DAG. We check whether there is an open path, i.e. a path on which all colliders are in the conditioning set and all non-colliders are not in the conditioning set. If there is no open path, u and v are d-separated.
:param u: First node
:param v: Second node
Expand Down Expand Up @@ -422,6 +420,119 @@ def are_nodes_d_separated(
return False
return True

def are_nodes_d_separated_cpdag(
self,
u: Union[Node, str],
v: Union[Node, str],
conditioning_set: List[Union[Node, str]],
) -> bool:
"""
Check if nodes u and v are d-separated given a conditioning set in a CPDAG following Perkovic, 2020. (Identifying causal effects in maximally oriented partially directed acyclic graphs)
:param u: First node
:param v: Second node
:param conditioning_set: Set of nodes to condition on
:return: True if u and v are d-separated given conditioning_set, False otherwise
"""

# Convert Node instances to their IDs
if isinstance(u, Node):
u = u.id
if isinstance(v, Node):
v = v.id

# u and v may not be in the conditioning set, throw error
if u in conditioning_set or v in conditioning_set:
raise ValueError("Nodes u and v may not be in the conditioning set")

# If there are no paths between u and v, they are d-separated
if list(self.all_paths_on_underlying_undirected_graph(u, v)) == []:
return True

list_of_results_for_paths = []
for path in self.all_paths_on_underlying_undirected_graph(u, v):
# If the path only has two nodes, it cannot be blocked and is open. Therefore, u and v are not d-separated
if len(path) == 2:
return False
is_path_blocked = False
if self._is_path_of_definite_status(path):
if self._is_definite_noncollider_in_conditioning_set(
path, conditioning_set
):
is_path_blocked = True
if self._is_a_collider_blocking(path, conditioning_set):
is_path_blocked = True
list_of_results_for_paths.append(is_path_blocked)
if False in list_of_results_for_paths:
return False
return True

def _is_collider(self, u: Node, v: Node, w: Node) -> bool:
"""
Check if a node is a collider in a triple u -> v <- w
:param u: Node u
:param v: Node v
:param w: Node w
:return: True if the node is a collider, False otherwise
"""
if self.edge_of_type_exists(u, v, DirectedEdge()) and self.edge_of_type_exists(
w, v, DirectedEdge()
):
return True
return False

def _is_definite_noncollider(self, u: Node, v: Node, w: Node) -> bool:
"""
Check if a node is a definite non-collider in a triple u - v - w
:param u: Node u
:param v: Node v
:param w: Node w
:return: True if the node is a definite non-collider, False otherwise
"""
# if there is an outgoing edge from the middle node, it is a definite non-collider
if self.edge_of_type_exists(v, u, DirectedEdge()) or self.edge_of_type_exists(
v, w, DirectedEdge()
):
return True
# if there is no edge between the outer nodes and it is not a collider, it is a definite non-collider
if (not self.edge_exists(u, w)) and (not self._is_collider(u, v, w)):
return True
return False

def _is_path_of_definite_status(self, path: List[Node]):
"""
Check if a path is of definite status, i.e. if every node is either the start or end node, a collider or a definite non-collider (i.e. has an outgoing edge or is an unshielded triple on the path)
:param path: The path to check
:return: True if the path is of definite status, False otherwise
"""

for i in range(1, len(path) - 1):
# is node a collider?
if self._is_collider(path[i - 1], path[i], path[i + 1]):
continue
# is node a definite non-collider because it has an outgoing edge?
elif self._is_definite_noncollider(path[i - 1], path[i], path[i + 1]):
continue
return False
return True

def _is_definite_noncollider_in_conditioning_set(
self, path, conditioning_set
) -> bool:
"""
Check if a path is blocked by a definite non-collider which is in the conditioning set.
:param path:
:param conditioning_set:
:return:
"""
is_path_blocked = False
for i in range(1, len(path) - 1):
if path[i] in conditioning_set and self._is_definite_noncollider(
path[i - 1], path[i], path[i + 1]
):
is_path_blocked = True
return is_path_blocked

def all_paths_on_underlying_undirected_graph(
self, u: Union[Node, str], v: Union[Node, str], visited=None, path=None
) -> List[List[Node]]:
Expand Down
Loading

0 comments on commit 743570d

Please sign in to comment.