Skip to content

Commit

Permalink
Merge branch 'main' of github.com:LilithWittmann/causality
Browse files Browse the repository at this point in the history
  • Loading branch information
LilithWittmann committed Nov 5, 2023
2 parents 0ecaf36 + 6731efb commit 8dadd41
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 17 deletions.
1 change: 1 addition & 0 deletions causy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(self, comparison_settings: ComparisonSettings, chunked: bool = None


class IndependenceTestInterface(ABC, SerializeMixin):
# TODO: refactor IndependenceTestInterface -> PipelineStepInterface or so
num_of_comparison_elements: int = 0
generator: Optional[GeneratorInterface] = None

Expand Down
Empty file.
86 changes: 86 additions & 0 deletions causy/orientation_rules/fci.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Tuple, List, Optional

from causy.generators import AllCombinationsGenerator
from causy.interfaces import (
TestResultAction,
IndependenceTestInterface,
ComparisonSettings,
BaseGraphInterface,
TestResult,
)


class ColliderRuleFCI(IndependenceTestInterface):
    """Collider (v-structure) orientation rule for the FCI algorithm.

    The rule is evaluated on every combination of exactly two nodes (see
    ``generator``) and is configured to run without parallel processing.
    """

    generator = AllCombinationsGenerator(
        comparison_settings=ComparisonSettings(min=2, max=2)
    )
    chunk_size_parallel_processing = 1
    parallel = False

    def test(
        self, nodes: Tuple[str], graph: BaseGraphInterface
    ) -> Optional[List[TestResult] | TestResult]:
        """
        Orient unshielded triples (x - z - y) as colliders (x *-> z <-* y).

        Some notes on how we implement FCI: After the independence tests, we have a graph with undirected edges
        which are implemented as two directed edges, one in each direction. We initialize the graph by adding values
        to all these edges; in the beginning, they get the value "either directed or undirected". Then we perform
        the collider test. Unlike in PC, we do not delete the directed edges from z to x and from z to y in order to
        obtain the structure (x -> z <- y). Instead, we delete the information "either directed or undirected" from
        the directed edges from x to z and from y to z. That means the directed edges from x to z and from y to z
        are now truly directed edges. The edges from z to x and from z to y can still stand for a directed edge or
        no directed edge. In the literature, this is portrayed by the meta symbol * and we obtain x *-> z <-* y.
        There might be ways to implement these similar but still subtly different orientation rules more
        consistently.
        TODO: write more tests

        We call a triple x, y, z of nodes a v-structure if x and y are NOT adjacent but share an adjacent node z.
        V-structures look like this in the undirected skeleton: (x - z - y).
        We now check whether z is in the separating set of x and y. If it is NOT, the edges must be oriented from
        x to z and from y to z: (x *-> z <-* y), where * indicates that there can be an arrowhead or none; we do
        not know, at least until applying further rules.

        :param nodes: ids of the two nodes (x, y) to check
        :param graph: the current graph
        :returns: a single DO_NOTHING result if x and y are adjacent, otherwise a (possibly empty) list of
            UPDATE_EDGE_DIRECTED actions that will be executed on the graph
        """

        x = graph.nodes[nodes[0]]
        y = graph.nodes[nodes[1]]

        # if x and y are adjacent, do nothing
        if graph.undirected_edge_exists(x, y):
            return TestResult(x=x, y=y, action=TestResultAction.DO_NOTHING, data={})

        # if x and y are NOT adjacent, store all shared adjacent nodes
        potential_zs = set(graph.edges[x.id].keys()).intersection(
            set(graph.edges[y.id].keys())
        )

        actions = graph.retrieve_edge_history(
            x, y, TestResultAction.REMOVE_EDGE_UNDIRECTED
        )

        # without a shared neighbour there is no triple to orient
        if not potential_zs:
            return []

        # The separating set of (x, y) does not depend on z, so it is collected
        # once from the edge-removal history instead of once per candidate z.
        separators = set()
        for action in actions:
            if "separatedBy" in action.data:
                separators.update(node.id for node in action.data["separatedBy"])

        # if x and y are not independent given z, save action: make z a collider
        results = []
        for z_id in potential_zs:
            z = graph.nodes[z_id]

            if z.id not in separators:
                results += [
                    TestResult(
                        x=x,
                        y=z,
                        action=TestResultAction.UPDATE_EDGE_DIRECTED,
                        data={"edge_type": None},
                    ),
                    TestResult(
                        x=y,
                        y=z,
                        action=TestResultAction.UPDATE_EDGE_DIRECTED,
                        data={"edge_type": None},
                    ),
                ]
        return results
2 changes: 2 additions & 0 deletions causy/orientation_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# theory for all orientation rules with pictures:
# https://hpi.de/fileadmin/user_upload/fachgebiete/plattner/teaching/CausalInference/2019/Introduction_to_Constraint-Based_Causal_Structure_Learning.pdf

# TODO: refactor ColliderTest -> ColliderRule and move to folder orientation_rules (after checking for duplicates)


class ColliderTest(IndependenceTestInterface):
generator = AllCombinationsGenerator(
Expand Down
36 changes: 36 additions & 0 deletions tests/test_orientation_rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest

from causy.graph import graph_model_factory, Graph
from causy.interfaces import TestResult, TestResultAction
from causy.orientation_rules.fci import ColliderRuleFCI


class OrientationTestCase(unittest.TestCase):
    """Tests for the orientation rules."""

    def test_collider_rule_fci(self):
        """Run ColliderRuleFCI on a three-node graph with edges X - Y and Z - Y.

        After the pipeline ran, the edge values X -> Y and Z -> Y are expected to
        be {"edge_type": None}, while Y -> X and Y -> Z keep their initial value.
        """
        undetermined = "either directed or undirected"

        model = graph_model_factory(pipeline_steps=[ColliderRuleFCI()])()
        model.graph = Graph()

        x = model.graph.add_node("X", [0, 1, 2])
        y = model.graph.add_node("Y", [3, 4, 5])
        z = model.graph.add_node("Z", [6, 7, 8])

        # a fresh dict per edge, so the edges never share one value object
        model.graph.add_edge(x, y, {"edge_type": undetermined})
        model.graph.add_edge(z, y, {"edge_type": undetermined})

        # NOTE(review): the history entry is attached to the pair (x, y) while the
        # recorded result names (x, z) -- confirm this is intended.
        removal = TestResult(
            x=x,
            y=z,
            action=TestResultAction.REMOVE_EDGE_UNDIRECTED,
            data={"separatedBy": []},
        )
        model.graph.add_edge_history(x, y, removal)

        model.execute_pipeline_steps()

        # edges pointing into Y lost the "either directed or undirected" marker
        for source in (x, z):
            self.assertEqual(model.graph.edge_value(source, y), {"edge_type": None})

        # the opposite directions are left untouched
        for target in (x, z):
            self.assertEqual(
                model.graph.edge_value(y, target), {"edge_type": undetermined}
            )
103 changes: 86 additions & 17 deletions tests/test_pc_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,26 @@ def test_with_minimal_toy_model(self):
for key, node in tst.graph.nodes.items():
node_mapping[node.name] = key

self.assertTrue(tst.graph.only_directed_edge_exists(tst.graph.nodes[node_mapping["V"]], tst.graph.nodes[node_mapping["Z"]]))
self.assertTrue(tst.graph.only_directed_edge_exists(tst.graph.nodes[node_mapping["W"]], tst.graph.nodes[node_mapping["Z"]]))
self.assertTrue(tst.graph.only_directed_edge_exists(tst.graph.nodes[node_mapping["Z"]], tst.graph.nodes[node_mapping["X"]]))
self.assertTrue(tst.graph.only_directed_edge_exists(tst.graph.nodes[node_mapping["Z"]], tst.graph.nodes[node_mapping["Y"]]))
self.assertTrue(
tst.graph.only_directed_edge_exists(
tst.graph.nodes[node_mapping["V"]], tst.graph.nodes[node_mapping["Z"]]
)
)
self.assertTrue(
tst.graph.only_directed_edge_exists(
tst.graph.nodes[node_mapping["W"]], tst.graph.nodes[node_mapping["Z"]]
)
)
self.assertTrue(
tst.graph.only_directed_edge_exists(
tst.graph.nodes[node_mapping["Z"]], tst.graph.nodes[node_mapping["X"]]
)
)
self.assertTrue(
tst.graph.only_directed_edge_exists(
tst.graph.nodes[node_mapping["Z"]], tst.graph.nodes[node_mapping["Y"]]
)
)

def test_with_larger_toy_model(self):
a, b, c, d, e, f, g, sample_size = 1.2, 1.7, 2, 1.5, 3, 4, 1.8, 10000
Expand All @@ -98,19 +114,72 @@ def test_with_larger_toy_model(self):
for key, node in tst.graph.nodes.items():
node_mapping[node.name] = key

self.assertFalse(tst.graph.edge_exists(tst.graph.nodes[node_mapping["A"]], tst.graph.nodes[node_mapping["B"]]))
self.assertTrue(tst.graph.directed_edge_exists(tst.graph.nodes[node_mapping["A"]], tst.graph.nodes[node_mapping["C"]]))
self.assertTrue(tst.graph.directed_edge_exists(tst.graph.nodes[node_mapping["B"]], tst.graph.nodes[node_mapping["C"]]))
self.assertTrue(tst.graph.directed_edge_exists(tst.graph.nodes[node_mapping["A"]], tst.graph.nodes[node_mapping["D"]]))
self.assertTrue(tst.graph.directed_edge_exists(tst.graph.nodes[node_mapping["B"]], tst.graph.nodes[node_mapping["D"]]))
self.assertTrue(tst.graph.directed_edge_exists(tst.graph.nodes[node_mapping["C"]], tst.graph.nodes[node_mapping["D"]]))
self.assertTrue(tst.graph.directed_edge_exists(tst.graph.nodes[node_mapping["B"]], tst.graph.nodes[node_mapping["E"]]))
self.assertTrue(tst.graph.directed_edge_exists(tst.graph.nodes[node_mapping["E"]], tst.graph.nodes[node_mapping["F"]]))
self.assertTrue(tst.graph.directed_edge_exists(tst.graph.nodes[node_mapping["B"]], tst.graph.nodes[node_mapping["F"]]))
self.assertTrue(tst.graph.directed_edge_exists(tst.graph.nodes[node_mapping["C"]], tst.graph.nodes[node_mapping["F"]]))
self.assertTrue(tst.graph.directed_edge_exists(tst.graph.nodes[node_mapping["D"]], tst.graph.nodes[node_mapping["F"]]))
self.assertFalse(tst.graph.edge_exists(tst.graph.nodes[node_mapping["A"]], tst.graph.nodes[node_mapping["E"]]))
self.assertFalse(tst.graph.edge_exists(tst.graph.nodes[node_mapping["A"]], tst.graph.nodes[node_mapping["F"]]))
self.assertFalse(
tst.graph.edge_exists(
tst.graph.nodes[node_mapping["A"]], tst.graph.nodes[node_mapping["B"]]
)
)
self.assertTrue(
tst.graph.directed_edge_exists(
tst.graph.nodes[node_mapping["A"]], tst.graph.nodes[node_mapping["C"]]
)
)
self.assertTrue(
tst.graph.directed_edge_exists(
tst.graph.nodes[node_mapping["B"]], tst.graph.nodes[node_mapping["C"]]
)
)
self.assertTrue(
tst.graph.directed_edge_exists(
tst.graph.nodes[node_mapping["A"]], tst.graph.nodes[node_mapping["D"]]
)
)
self.assertTrue(
tst.graph.directed_edge_exists(
tst.graph.nodes[node_mapping["B"]], tst.graph.nodes[node_mapping["D"]]
)
)
self.assertTrue(
tst.graph.directed_edge_exists(
tst.graph.nodes[node_mapping["C"]], tst.graph.nodes[node_mapping["D"]]
)
)
self.assertTrue(
tst.graph.directed_edge_exists(
tst.graph.nodes[node_mapping["B"]], tst.graph.nodes[node_mapping["E"]]
)
)
self.assertTrue(
tst.graph.directed_edge_exists(
tst.graph.nodes[node_mapping["E"]], tst.graph.nodes[node_mapping["F"]]
)
)
self.assertTrue(
tst.graph.directed_edge_exists(
tst.graph.nodes[node_mapping["B"]], tst.graph.nodes[node_mapping["F"]]
)
)
self.assertTrue(
tst.graph.directed_edge_exists(
tst.graph.nodes[node_mapping["C"]], tst.graph.nodes[node_mapping["F"]]
)
)
self.assertTrue(
tst.graph.directed_edge_exists(
tst.graph.nodes[node_mapping["D"]], tst.graph.nodes[node_mapping["F"]]
)
)
self.assertFalse(
tst.graph.edge_exists(
tst.graph.nodes[node_mapping["A"]], tst.graph.nodes[node_mapping["E"]]
)
)
self.assertFalse(
tst.graph.edge_exists(
tst.graph.nodes[node_mapping["A"]], tst.graph.nodes[node_mapping["F"]]
)
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8dadd41

Please sign in to comment.