-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of github.com:LilithWittmann/causality
- Loading branch information
Showing
6 changed files
with
211 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from typing import Tuple, List, Optional | ||
|
||
from causy.generators import AllCombinationsGenerator | ||
from causy.interfaces import ( | ||
TestResultAction, | ||
IndependenceTestInterface, | ||
ComparisonSettings, | ||
BaseGraphInterface, | ||
TestResult, | ||
) | ||
|
||
|
||
class ColliderRuleFCI(IndependenceTestInterface): | ||
generator = AllCombinationsGenerator( | ||
comparison_settings=ComparisonSettings(min=2, max=2) | ||
) | ||
chunk_size_parallel_processing = 1 | ||
parallel = False | ||
|
||
def test( | ||
self, nodes: Tuple[str], graph: BaseGraphInterface | ||
) -> Optional[List[TestResult] | TestResult]: | ||
""" | ||
Some notes on how we implment FCI: After the independence tests, we have a graph with undirected edges which are | ||
implemented as two directed edges, one in each direction. We initialize the graph by adding values to all these edges, | ||
in the beginning, they get the value "either directed or undirected". Then we perform the collider test. Unlike in PC, | ||
we do not delete directed edges from z to x and from z to y in order to obtain the structure (x -> z <- y). Instead, we | ||
delete the information "either directed or undirected" from the directed edges from x to z and from y to z. That means, | ||
the directed edges from x to z and from y to z are now truly directed edges. The edges from z to x and from z to y can | ||
still stand for a directed edge or no directed edge. In the literature, this is portrayed by the meta symbol * and we | ||
obtain x *-> z <-* y. There might be ways to implement these similar but still subtly different orientation rules more consistently. | ||
TODO: write tests | ||
We call triples x, y, z of nodes v structures if x and y that are NOT adjacent but share an adjacent node z. | ||
V structures looks like this in the undirected skeleton: (x - z - y). | ||
We now check if z is in the separating set. If so, the edges must be oriented from x to z and from y to z: | ||
(x *-> z <-* y), where * indicated that there can be an arrowhead or none, we do not know, at least until | ||
applying further rules. | ||
:param nodes: list of nodes | ||
:param graph: the current graph | ||
:returns: list of actions that will be executed on graph | ||
""" | ||
|
||
x = graph.nodes[nodes[0]] | ||
y = graph.nodes[nodes[1]] | ||
|
||
# if x and y are adjacent, do nothing | ||
if graph.undirected_edge_exists(x, y): | ||
return TestResult(x=x, y=y, action=TestResultAction.DO_NOTHING, data={}) | ||
|
||
# if x and y are NOT adjacent, store all shared adjacent nodes | ||
potential_zs = set(graph.edges[x.id].keys()).intersection( | ||
set(graph.edges[y.id].keys()) | ||
) | ||
|
||
actions = graph.retrieve_edge_history( | ||
x, y, TestResultAction.REMOVE_EDGE_UNDIRECTED | ||
) | ||
|
||
# if x and y are not independent given z, safe action: make z a collider | ||
results = [] | ||
for z in potential_zs: | ||
z = graph.nodes[z] | ||
|
||
separators = [] | ||
for action in actions: | ||
if "separatedBy" in action.data: | ||
separators += [a.id for a in action.data["separatedBy"]] | ||
|
||
if z.id not in separators: | ||
results += [ | ||
TestResult( | ||
x=x, | ||
y=z, | ||
action=TestResultAction.UPDATE_EDGE_DIRECTED, | ||
data={"edge_type": None}, | ||
), | ||
TestResult( | ||
x=y, | ||
y=z, | ||
action=TestResultAction.UPDATE_EDGE_DIRECTED, | ||
data={"edge_type": None}, | ||
), | ||
] | ||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import unittest | ||
|
||
from causy.graph import graph_model_factory, Graph | ||
from causy.interfaces import TestResult, TestResultAction | ||
from causy.orientation_rules.fci import ColliderRuleFCI | ||
|
||
|
||
class OrientationTestCase(unittest.TestCase): | ||
def test_collider_rule_fci(self): | ||
pipeline = [ColliderRuleFCI()] | ||
model = graph_model_factory(pipeline_steps=pipeline)() | ||
model.graph = Graph() | ||
x = model.graph.add_node("X", [0, 1, 2]) | ||
y = model.graph.add_node("Y", [3, 4, 5]) | ||
z = model.graph.add_node("Z", [6, 7, 8]) | ||
model.graph.add_edge(x, y, {"edge_type": "either directed or undirected"}) | ||
model.graph.add_edge(z, y, {"edge_type": "either directed or undirected"}) | ||
model.graph.add_edge_history( | ||
x, | ||
y, | ||
TestResult( | ||
x=x, | ||
y=z, | ||
action=TestResultAction.REMOVE_EDGE_UNDIRECTED, | ||
data={"separatedBy": []}, | ||
), | ||
) | ||
model.execute_pipeline_steps() | ||
self.assertEqual(model.graph.edge_value(x, y), {"edge_type": None}) | ||
self.assertEqual(model.graph.edge_value(z, y), {"edge_type": None}) | ||
self.assertEqual( | ||
model.graph.edge_value(y, x), {"edge_type": "either directed or undirected"} | ||
) | ||
self.assertEqual( | ||
model.graph.edge_value(y, z), {"edge_type": "either directed or undirected"} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters