Skip to content

Commit

Permalink
Merge pull request #62 from causy-dev/d-separation_check
Browse files Browse the repository at this point in the history
D separation check
  • Loading branch information
this-is-sofia authored Nov 4, 2024
2 parents 25ec80d + edc19cf commit b4ff852
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CorrelationCoefficientTest(
parallel: BoolParameter = False

def process(
self, nodes: List[str], graph: BaseGraphInterface
self, nodes: List[str], graph: BaseGraphInterface
) -> Optional[TestResult]:
"""
Test if u and v are independent and delete edge in graph if they are.
Expand Down Expand Up @@ -204,7 +204,9 @@ def process(
helper = torch.mm(torch.sqrt(diagonal_matrix), inverse_cov_matrix)
precision_matrix = torch.mm(helper, torch.sqrt(diagonal_matrix))

par_corr = (-1 * precision_matrix[0][1]) / torch.sqrt(precision_matrix[0][0] * precision_matrix[1][1])
par_corr = (-1 * precision_matrix[0][1]) / torch.sqrt(
precision_matrix[0][0] * precision_matrix[1][1]
)

sample_size = len(graph.nodes[nodes[0]].values)
nb_of_control_vars = len(nodes) - 2
Expand Down
116 changes: 89 additions & 27 deletions causy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,95 @@ def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str]) -> bool
return True
return False

def edge_of_type_exists(
    self,
    u: Union[Node, str],
    v: Union[Node, str],
    edge_type: EdgeTypeInterface = DirectedEdge(),
) -> bool:
    """
    Tell whether u and v are joined by an edge of the requested type.
    :param u: node u, given as Node instance or as node id
    :param v: node v, given as Node instance or as node id
    :param edge_type: the edge type the stored edge is compared against
    :return: True if the edge is present and has this type, False otherwise
    """
    # Work on plain ids from here on, regardless of how the nodes were passed.
    source_id = u.id if isinstance(u, Node) else u
    target_id = v.id if isinstance(v, Node) else v

    # The lookup below is only performed once the existence check has passed.
    if not self.directed_edge_exists(source_id, target_id):
        return False

    stored_type = self.edges[source_id][target_id].edge_type
    return not (stored_type != edge_type)

def are_nodes_d_separated(
    self,
    u: Union[Node, str],
    v: Union[Node, str],
    conditioning_set: List[Union[Node, str]],
) -> bool:
    """
    Check if nodes u and v are d-separated given a conditioning set.

    Every path between u and v on the underlying undirected graph is inspected.
    A path is blocked as soon as one of its inner nodes is either
    - a collider that is not in the conditioning set, or
    - a non-collider that is in the conditioning set.
    A path without inner nodes (a direct edge) can never be blocked.
    u and v are d-separated if and only if every path is blocked.

    NOTE(review): only the collider itself is looked up in the conditioning
    set. The textbook definition also opens a collider when one of its
    descendants is conditioned on; that case is not handled here.

    :param u: First node
    :param v: Second node
    :param conditioning_set: Nodes to condition on, given as Node instances or node ids
    :return: True if u and v are d-separated given conditioning_set, False otherwise
    :raises ValueError: if u or v is contained in the conditioning set
    """

    # Convert Node instances to their IDs
    if isinstance(u, Node):
        u = u.id
    if isinstance(v, Node):
        v = v.id

    # Normalise the conditioning set to ids as well, so that membership tests
    # give the same answer whether the caller passed Node instances or ids.
    conditioning_ids = {
        node.id if isinstance(node, Node) else node for node in conditioning_set
    }

    # u and v may not be in the conditioning set, throw error
    if u in conditioning_ids or v in conditioning_ids:
        raise ValueError("Nodes u and v may not be in the conditioning set")

    for path in self.all_paths_on_underlying_undirected_graph(u, v):
        path_is_blocked = False

        # Slide over all (predecessor, inner node, successor) triples; a path
        # of length two yields no triple and therefore stays open.
        for previous, current, following in zip(path, path[1:], path[2:]):
            is_collider = self.edge_of_type_exists(
                previous.id, current.id, DirectedEdge()
            ) and self.edge_of_type_exists(
                following.id, current.id, DirectedEdge()
            )
            is_conditioned_on = current.id in conditioning_ids

            # Blocking cases: collider outside the conditioning set, or
            # non-collider inside it. One blocking node blocks the whole path.
            if is_collider != is_conditioned_on:
                path_is_blocked = True
                break

        # if there is at least one open path, u and v are not d-separated
        if not path_is_blocked:
            return False

    return True

def all_paths_on_underlying_undirected_graph(
self, u: Union[Node, str], v: Union[Node, str], visited=None, path=None
) -> List[List[Node]]:
Expand Down Expand Up @@ -307,33 +396,6 @@ def all_paths_on_underlying_undirected_graph(
path.pop()
visited.remove(u)

def edge_of_type_exists(
    self,
    u: Union[Node, str],
    v: Union[Node, str],
    edge_type: EdgeTypeInterface = DirectedEdge(),
) -> bool:
    """
    Check if an edge of a specific type exists between u and v.
    :param u: node u (Node instance or node id)
    :param v: node v (Node instance or node id)
    :param edge_type: the type of the edge to check for
    :return: True if an edge of this type exists, False otherwise
    """

    if isinstance(u, Node):
        u = u.id
    if isinstance(v, Node):
        v = v.id

    # The lookup below is keyed u first, then v, so the existence check has to
    # be the direction-specific one. NOTE(review): edge_exists presumably also
    # reports an edge stored only as v -> u, which would make the lookup raise
    # KeyError - confirm against its implementation.
    if not self.directed_edge_exists(u, v):
        return False

    return self.edges[u][v].edge_type == edge_type

def retrieve_edges(self) -> List[Edge]:
"""
Retrieve all edges
Expand Down
233 changes: 233 additions & 0 deletions tests/test_path_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
from causy.edge_types import DirectedEdge
from causy.graph import GraphBaseAccessMixin, GraphManager
from tests.utils import CausyTestCase


class GraphTestCase(CausyTestCase):
    """Tests for path enumeration and d-separation checks on small three-node graphs."""

    def _new_graph_with_three_nodes(self):
        """
        Build a fresh graph holding the nodes test1, test2 and test3 and no edges.
        :return: tuple (graph, node1, node2, node3)
        """
        # Same setup every test used inline: re-base GraphManager on the base
        # access mixin plus the directed-edge access mixin before instantiating.
        new_graph_manager = GraphManager
        new_graph_manager.__bases__ = (
            GraphBaseAccessMixin,
            DirectedEdge.GraphAccessMixin,
        )
        graph = new_graph_manager()
        node1 = graph.add_node("test1", [1, 2, 3])
        node2 = graph.add_node("test2", [1, 2, 3])
        node3 = graph.add_node("test3", [1, 2, 3])
        return graph, node1, node2, node3

    def test_all_paths_on_underlying_undirected_graph(self):
        """Chain test1 -> test2 -> test3: the only path runs through test2."""
        graph, node1, node2, node3 = self._new_graph_with_three_nodes()
        graph.add_directed_edge(node1, node2, {"test": "test"})
        graph.add_directed_edge(node2, node3, {"test": "test"})
        paths = list(graph.all_paths_on_underlying_undirected_graph(node1, node3))
        self.assertIn([node1, node2, node3], paths)
        self.assertNotIn([node1, node2], paths)
        self.assertNotIn([node2, node3], paths)
        self.assertNotIn([node1, node3], paths)

    def test_all_paths_on_underlying_undirected_graph_2(self):
        """Reversed chain test3 -> test2 -> test1: edge direction is ignored."""
        graph, node1, node2, node3 = self._new_graph_with_three_nodes()
        graph.add_directed_edge(node3, node2, {"test": "test"})
        graph.add_directed_edge(node2, node1, {"test": "test"})
        paths = list(graph.all_paths_on_underlying_undirected_graph(node1, node3))
        self.assertIn([node1, node2, node3], paths)
        self.assertNotIn([node1, node2], paths)
        self.assertNotIn([node2, node3], paths)
        self.assertNotIn([node1, node3], paths)

    def test_all_paths_on_underlying_undirected_graph_collider_path(self):
        """Collider test1 -> test2 <- test3: the path through test2 is still listed."""
        graph, node1, node2, node3 = self._new_graph_with_three_nodes()
        graph.add_directed_edge(node3, node2, {"test": "test"})
        graph.add_directed_edge(node1, node2, {"test": "test"})
        paths = list(graph.all_paths_on_underlying_undirected_graph(node1, node3))
        self.assertIn([node1, node2, node3], paths)
        self.assertNotIn([node1, node2], paths)
        self.assertNotIn([node2, node3], paths)
        self.assertNotIn([node1, node3], paths)

    def test_all_paths_on_underlying_undirected_graph_several_paths(self):
        """Chain plus direct edge test1 -> test3: both paths are listed."""
        graph, node1, node2, node3 = self._new_graph_with_three_nodes()
        graph.add_directed_edge(node1, node2, {"test": "test"})
        graph.add_directed_edge(node2, node3, {"test": "test"})
        graph.add_directed_edge(node1, node3, {"test": "test"})
        paths = list(graph.all_paths_on_underlying_undirected_graph(node1, node3))
        self.assertIn([node1, node2, node3], paths)
        self.assertIn([node1, node3], paths)
        self.assertNotIn([node1, node2], paths)
        self.assertNotIn([node2, node3], paths)

    def test_are_nodes_d_separated(self):
        """Chain, empty conditioning set: the path is open."""
        graph, node1, node2, node3 = self._new_graph_with_three_nodes()
        graph.add_directed_edge(node1, node2, {"test": "test"})
        graph.add_directed_edge(node2, node3, {"test": "test"})
        self.assertFalse(graph.are_nodes_d_separated(node1, node3, []))

    def test_are_nodes_d_separated_2(self):
        """Chain, conditioning on the middle node: the path is blocked."""
        graph, node1, node2, node3 = self._new_graph_with_three_nodes()
        graph.add_directed_edge(node1, node2, {"test": "test"})
        graph.add_directed_edge(node2, node3, {"test": "test"})
        self.assertTrue(graph.are_nodes_d_separated(node1, node3, [node2]))

    def test_are_nodes_d_separated_3(self):
        """Collider, empty conditioning set: the path is blocked."""
        graph, node1, node2, node3 = self._new_graph_with_three_nodes()
        graph.add_directed_edge(node1, node2, {"test": "test"})
        graph.add_directed_edge(node3, node2, {"test": "test"})
        self.assertTrue(graph.are_nodes_d_separated(node1, node3, []))

    def test_are_nodes_d_separated_4(self):
        """Collider, conditioning on the collider: the path is opened."""
        graph, node1, node2, node3 = self._new_graph_with_three_nodes()
        graph.add_directed_edge(node1, node2, {"test": "test"})
        graph.add_directed_edge(node3, node2, {"test": "test"})
        self.assertFalse(graph.are_nodes_d_separated(node1, node3, [node2]))

    # Previously also named test_are_nodes_d_separated_3, which replaced the
    # collider test above in the class namespace so that it was never run.
    def test_are_nodes_d_separated_5(self):
        """Collider plus direct edge test1 -> test3: the direct edge keeps the nodes connected."""
        graph, node1, node2, node3 = self._new_graph_with_three_nodes()
        graph.add_directed_edge(node1, node2, {"test": "test"})
        graph.add_directed_edge(node3, node2, {"test": "test"})
        graph.add_directed_edge(node1, node3, {"test": "test"})
        self.assertFalse(graph.are_nodes_d_separated(node1, node3, []))
Loading

0 comments on commit b4ff852

Please sign in to comment.