Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

D separation check #62

Merged
merged 2 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CorrelationCoefficientTest(
parallel: BoolParameter = False

def process(
self, nodes: List[str], graph: BaseGraphInterface
self, nodes: List[str], graph: BaseGraphInterface
) -> Optional[TestResult]:
"""
Test if u and v are independent and delete edge in graph if they are.
Expand Down Expand Up @@ -204,7 +204,9 @@ def process(
helper = torch.mm(torch.sqrt(diagonal_matrix), inverse_cov_matrix)
precision_matrix = torch.mm(helper, torch.sqrt(diagonal_matrix))

par_corr = (-1 * precision_matrix[0][1]) / torch.sqrt(precision_matrix[0][0] * precision_matrix[1][1])
par_corr = (-1 * precision_matrix[0][1]) / torch.sqrt(
precision_matrix[0][0] * precision_matrix[1][1]
)

sample_size = len(graph.nodes[nodes[0]].values)
nb_of_control_vars = len(nodes) - 2
Expand Down
116 changes: 89 additions & 27 deletions causy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,95 @@ def directed_path_exists(self, u: Union[Node, str], v: Union[Node, str]) -> bool
return True
return False

def edge_of_type_exists(
self,
u: Union[Node, str],
v: Union[Node, str],
edge_type: EdgeTypeInterface = DirectedEdge(),
) -> bool:
"""
Check if an edge of a specific type exists between u and v.
:param u: node u
:param v: node v
:param edge_type: the type of the edge to check for
:return: True if an edge of this type exists, False otherwise
"""

if isinstance(u, Node):
u = u.id
if isinstance(v, Node):
v = v.id

if not self.directed_edge_exists(u, v):
return False

if self.edges[u][v].edge_type != edge_type:
return False

return True

def are_nodes_d_separated(
self,
u: Union[Node, str],
v: Union[Node, str],
conditioning_set: List[Union[Node, str]],
) -> bool:
"""
Check if nodes u and v are d-separated given a conditioning set. We check whether there is an open path, i.e. a path on which all colliders are in the conditioning set and all non-colliders are not in the conditioning set. If there is no open path, u and v are d-separated.

:param u: First node
:param v: Second node
:param conditioning_set: Set of nodes to condition on
:return: True if u and v are d-separated given conditioning_set, False otherwise
"""

# Convert Node instances to their IDs
if isinstance(u, Node):
u = u.id
if isinstance(v, Node):
v = v.id

# u and v may not be in the conditioning set, throw error
if u in conditioning_set or v in conditioning_set:
raise ValueError("Nodes u and v may not be in the conditioning set")

# check whether there is an open path on which all colliders are in the conditioning set and all non-colliders are not in the conditioning set
list_of_results_for_paths = []
for path in self.all_paths_on_underlying_undirected_graph(u, v):
if len(path) == 2:
list_of_results_for_paths.append(False)

for i in range(1, len(path) - 1):
is_path_blocked = False

# paths are d-separated if a collider is not in the conditioning set
if path[i] not in conditioning_set:
if self.edge_of_type_exists(
path[i - 1].id, path[i].id, DirectedEdge()
) and self.edge_of_type_exists(
path[i + 1].id, path[i].id, DirectedEdge()
):
is_path_blocked = True

# paths are d-separated if a non-collider is in the conditioning set
elif path[i] in conditioning_set:
if not (
self.edge_of_type_exists(
path[i - 1].id, path[i].id, DirectedEdge()
)
and self.edge_of_type_exists(
path[i + 1].id, path[i].id, DirectedEdge()
)
):
is_path_blocked = True

list_of_results_for_paths.append(is_path_blocked)

# if there is at least one open path, u and v are not d-separated
if False in list_of_results_for_paths:
return False
return True

def all_paths_on_underlying_undirected_graph(
self, u: Union[Node, str], v: Union[Node, str], visited=None, path=None
) -> List[List[Node]]:
Expand Down Expand Up @@ -307,33 +396,6 @@ def all_paths_on_underlying_undirected_graph(
path.pop()
visited.remove(u)

def edge_of_type_exists(
self,
u: Union[Node, str],
v: Union[Node, str],
edge_type: EdgeTypeInterface = DirectedEdge(),
) -> bool:
"""
Check if an edge of a specific type exists between u and v.
:param u: node u
:param v: node v
:param edge_type: the type of the edge to check for
:return: True if an edge of this type exists, False otherwise
"""

if isinstance(u, Node):
u = u.id
if isinstance(v, Node):
v = v.id

if not self.edge_exists(u, v):
return False

if self.edges[u][v].edge_type != edge_type:
return False

return True

def retrieve_edges(self) -> List[Edge]:
"""
Retrieve all edges
Expand Down
233 changes: 233 additions & 0 deletions tests/test_path_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
from causy.edge_types import DirectedEdge
from causy.graph import GraphBaseAccessMixin, GraphManager
from tests.utils import CausyTestCase


class GraphTestCase(CausyTestCase):
def test_all_paths_on_underlying_undirected_graph(self):
new_graph_manager = GraphManager
new_graph_manager.__bases__ = (
GraphBaseAccessMixin,
DirectedEdge.GraphAccessMixin,
)
graph = new_graph_manager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node2, node3, {"test": "test"})
self.assertIn(
[node1, node2, node3],
[
x
for x in [
l
for l in graph.all_paths_on_underlying_undirected_graph(
node1, node3
)
]
],
)
self.assertNotIn(
[node1, node2],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)
self.assertNotIn(
[node2, node3],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)
self.assertNotIn(
[node1, node3],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)

def test_all_paths_on_underlying_undirected_graph_2(self):
new_graph_manager = GraphManager
new_graph_manager.__bases__ = (
GraphBaseAccessMixin,
DirectedEdge.GraphAccessMixin,
)
graph = new_graph_manager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_directed_edge(node3, node2, {"test": "test"})
graph.add_directed_edge(node2, node1, {"test": "test"})
self.assertIn(
[node1, node2, node3],
[
x
for x in [
l
for l in graph.all_paths_on_underlying_undirected_graph(
node1, node3
)
]
],
)
self.assertNotIn(
[node1, node2],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)
self.assertNotIn(
[node2, node3],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)
self.assertNotIn(
[node1, node3],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)

def test_all_paths_on_underlying_undirected_graph_collider_path(self):
new_graph_manager = GraphManager
new_graph_manager.__bases__ = (
GraphBaseAccessMixin,
DirectedEdge.GraphAccessMixin,
)
graph = new_graph_manager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_directed_edge(node3, node2, {"test": "test"})
graph.add_directed_edge(node1, node2, {"test": "test"})
self.assertIn(
[node1, node2, node3],
[
x
for x in [
l
for l in graph.all_paths_on_underlying_undirected_graph(
node1, node3
)
]
],
)
self.assertNotIn(
[node1, node2],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)
self.assertNotIn(
[node2, node3],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)
self.assertNotIn(
[node1, node3],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)

def test_all_paths_on_underlying_undirected_graph_several_paths(self):
new_graph_manager = GraphManager
new_graph_manager.__bases__ = (
GraphBaseAccessMixin,
DirectedEdge.GraphAccessMixin,
)
graph = new_graph_manager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node2, node3, {"test": "test"})
graph.add_directed_edge(node1, node3, {"test": "test"})
self.assertIn(
[node1, node2, node3],
[
x
for x in [
l
for l in graph.all_paths_on_underlying_undirected_graph(
node1, node3
)
]
],
)
self.assertIn(
[node1, node3],
[
x
for x in [
l
for l in graph.all_paths_on_underlying_undirected_graph(
node1, node3
)
]
],
)
self.assertNotIn(
[node1, node2],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)
self.assertNotIn(
[node2, node3],
[l for l in graph.all_paths_on_underlying_undirected_graph(node1, node3)],
)

def test_are_nodes_d_separated(self):
new_graph_manager = GraphManager
new_graph_manager.__bases__ = (
GraphBaseAccessMixin,
DirectedEdge.GraphAccessMixin,
)
graph = new_graph_manager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node2, node3, {"test": "test"})
self.assertFalse(graph.are_nodes_d_separated(node1, node3, []))

def test_are_nodes_d_separated_2(self):
new_graph_manager = GraphManager
new_graph_manager.__bases__ = (
GraphBaseAccessMixin,
DirectedEdge.GraphAccessMixin,
)
graph = new_graph_manager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node2, node3, {"test": "test"})
self.assertTrue(graph.are_nodes_d_separated(node1, node3, [node2]))

def test_are_nodes_d_separated_3(self):
new_graph_manager = GraphManager
new_graph_manager.__bases__ = (
GraphBaseAccessMixin,
DirectedEdge.GraphAccessMixin,
)
graph = new_graph_manager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node3, node2, {"test": "test"})
self.assertTrue(graph.are_nodes_d_separated(node1, node3, []))

def test_are_nodes_d_separated_4(self):
new_graph_manager = GraphManager
new_graph_manager.__bases__ = (
GraphBaseAccessMixin,
DirectedEdge.GraphAccessMixin,
)
graph = new_graph_manager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node3, node2, {"test": "test"})
self.assertFalse(graph.are_nodes_d_separated(node1, node3, [node2]))

def test_are_nodes_d_separated_3(self):
new_graph_manager = GraphManager
new_graph_manager.__bases__ = (
GraphBaseAccessMixin,
DirectedEdge.GraphAccessMixin,
)
graph = new_graph_manager()
node1 = graph.add_node("test1", [1, 2, 3])
node2 = graph.add_node("test2", [1, 2, 3])
node3 = graph.add_node("test3", [1, 2, 3])
graph.add_directed_edge(node1, node2, {"test": "test"})
graph.add_directed_edge(node3, node2, {"test": "test"})
graph.add_directed_edge(node1, node3, {"test": "test"})
self.assertFalse(graph.are_nodes_d_separated(node1, node3, []))
Loading
Loading