diff --git a/causy/causal_discovery/constraint/orientation_rules/pc.py b/causy/causal_discovery/constraint/orientation_rules/pc.py index ac056dd..8995ee0 100644 --- a/causy/causal_discovery/constraint/orientation_rules/pc.py +++ b/causy/causal_discovery/constraint/orientation_rules/pc.py @@ -1,3 +1,4 @@ +import enum from typing import Tuple, List, Optional, Generic import itertools @@ -9,7 +10,7 @@ PipelineStepInterfaceType, ) from causy.models import ComparisonSettings, TestResultAction, TestResult -from causy.variables import IntegerParameter, BoolParameter +from causy.variables import IntegerParameter, BoolParameter, StringParameter # theory for all orientation rules with pictures: @@ -18,6 +19,58 @@ # TODO: refactor ColliderTest -> ColliderRule and move to folder orientation_rules (after checking for duplicates) +def filter_unapplied_actions(actions, u, v): + """ + Collect the not yet applied actions that refer to the node pair (u, v). + :param actions: list of result sets of unapplied actions, None entries are skipped + :param u: node that has to equal the u of an action + :param v: node that has to equal the v of an action + :return: list of the unapplied actions whose u and v match the given nodes + """ + filtered = [] + for result_set in actions: + if result_set is None: + continue + for result in result_set: + if result.u == u and result.v == v: + filtered.append(result) + return filtered + + +def generate_restores(unapplied_actions): + """ + Generate restore actions for unapplied directed edge removals. + Only actions of type REMOVE_EDGE_DIRECTED are restored, + all other actions are ignored. + :param unapplied_actions: list of unapplied actions + :return: list of RESTORE_EDGE_DIRECTED results, one for each restored action + """ + results = [] + for action in unapplied_actions: + if action.action == TestResultAction.REMOVE_EDGE_DIRECTED: + results.append( + TestResult( + u=action.u, + v=action.v, + action=TestResultAction.RESTORE_EDGE_DIRECTED, + data={}, + ) + ) + return results + + +class ColliderTestConflictResolutionStrategies(enum.StrEnum): + """ + Enum for the conflict resolution strategies for the ColliderTest. + """ + + # If a conflict occurs, the edge that was removed first is kept. 
+ KEEP_FIRST = "KEEP_FIRST" + + # If a conflict occurs, the edge that was removed last is kept. + KEEP_LAST = "KEEP_LAST" + + class ColliderTest( PipelineStepInterface[PipelineStepInterfaceType], Generic[PipelineStepInterfaceType] ): @@ -27,6 +80,10 @@ class ColliderTest( chunk_size_parallel_processing: IntegerParameter = 1 parallel: BoolParameter = False + conflict_resolution_strategy: StringParameter = ( + ColliderTestConflictResolutionStrategies.KEEP_FIRST + ) + needs_unapplied_actions: BoolParameter = True def process( @@ -75,20 +132,58 @@ def process( separators += [a.id for a in action.data["separatedBy"]] if z.id not in separators: - results += [ - TestResult( - u=z, - v=x, - action=TestResultAction.REMOVE_EDGE_DIRECTED, - data={}, - ), - TestResult( - u=z, - v=y, - action=TestResultAction.REMOVE_EDGE_DIRECTED, - data={}, - ), - ] + unapplied_actions_x_z = filter_unapplied_actions( + unapplied_actions, x, z + ) + unapplied_actions_y_z = filter_unapplied_actions( + unapplied_actions, y, z + ) + if unapplied_actions_y_z or unapplied_actions_x_z: + if ( + ColliderTestConflictResolutionStrategies.KEEP_FIRST + == self.conflict_resolution_strategy + ): + # KEEP_FIRST (compared by value so that plain strings match too): an unapplied action for (x, z) or (y, z) already exists, so this collider is skipped + continue + elif ( + ColliderTestConflictResolutionStrategies.KEEP_LAST + == self.conflict_resolution_strategy + ): + # KEEP_LAST: restore the directed edges removed by the unapplied actions, then remove z -> x and z -> y for this collider + results.extend(generate_restores(unapplied_actions_x_z)) + results.extend(generate_restores(unapplied_actions_y_z)) + results.append( + TestResult( + u=z, + v=x, + action=TestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + ) + ) + results.append( + TestResult( + u=z, + v=y, + action=TestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + ) + ) + + else: + results += [ + TestResult( + u=z, + v=x, + action=TestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + ), + TestResult( + u=z, + v=y, + action=TestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + ), + ] return results diff 
--git a/causy/variables.py b/causy/variables.py index 1b74d31..bf59625 100644 --- a/causy/variables.py +++ b/causy/variables.py @@ -202,7 +202,12 @@ def resolve_variable_to_object(obj: Any, variables): obj.__dict__[attribute] = variables[value.name] else: raise ValueError(f'Variable "{value.name}" not found in the variables.') - elif hasattr(value, "__dict__"): + elif ( + hasattr(value, "__dict__") + and not isinstance(value, NoneType) + and not hasattr(value, "value") + ): + # we check for a "value" attribute because we don't want to resolve the object if it is a variable object itself or an Enum obj.__dict__[attribute] = resolve_variable_to_object(value, variables) return obj diff --git a/tests/test_orientation_tests.py b/tests/test_orientation_tests.py index 1e52751..789c71b 100644 --- a/tests/test_orientation_tests.py +++ b/tests/test_orientation_tests.py @@ -8,6 +8,7 @@ FurtherOrientTripleTest, OrientQuadrupleTest, FurtherOrientQuadrupleTest, + ColliderTestConflictResolutionStrategies, ) from causy.graph_model import graph_model_factory @@ -43,6 +44,133 @@ def test_collider_test(self): model.execute_pipeline_steps() self.assertTrue(model.graph.only_directed_edge_exists(x, y)) self.assertTrue(model.graph.only_directed_edge_exists(z, y)) + self.assertFalse(model.graph.edge_exists(x, z)) + + def test_collider_test_2(self): + pipeline = [ColliderTest()] + model = graph_model_factory( + Algorithm( + pipeline_steps=pipeline, + edge_types=[], + name="TestCollider", + ) + )() + model.graph = GraphManager() + x = model.graph.add_node("X", [0, 1, 2]) + y = model.graph.add_node("Y", [3, 4, 5]) + z = model.graph.add_node("Z", [6, 7, 8]) + model.graph.add_edge(x, y, {}) + model.graph.add_edge(z, y, {}) + model.graph.add_edge_history( + x, + z, + TestResult( + u=x, + v=z, + action=TestResultAction.REMOVE_EDGE_UNDIRECTED, + data={"separatedBy": [y]}, + ), + ) + model.execute_pipeline_steps() + self.assertFalse(model.graph.only_directed_edge_exists(x, y)) + 
self.assertFalse(model.graph.only_directed_edge_exists(z, y)) + self.assertTrue(model.graph.edge_exists(x, y)) + self.assertTrue(model.graph.edge_exists(z, y)) + self.assertFalse(model.graph.edge_exists(x, z)) + + def test_collider_test_3(self): + pipeline = [ColliderTest()] + model = graph_model_factory( + Algorithm( + pipeline_steps=pipeline, + edge_types=[], + name="TestCollider", + ) + )() + model.graph = GraphManager() + x = model.graph.add_node("X", [0, 1, 2]) + y = model.graph.add_node("Y", [3, 4, 5]) + z = model.graph.add_node("Z", [6, 7, 8]) + a = model.graph.add_node("A", [9, 10, 11]) + model.graph.add_edge(x, y, {}) + model.graph.add_edge(z, y, {}) + model.graph.add_edge(x, a, {}) + model.graph.add_edge(z, a, {}) + model.graph.add_edge_history( + x, + z, + TestResult( + u=x, + v=z, + action=TestResultAction.REMOVE_EDGE_UNDIRECTED, + data={"separatedBy": []}, + ), + ) + model.graph.add_edge_history( + y, + a, + TestResult( + u=y, + v=a, + action=TestResultAction.REMOVE_EDGE_UNDIRECTED, + data={"separatedBy": [x, z]}, + ), + ) + model.execute_pipeline_steps() + self.assertTrue(model.graph.only_directed_edge_exists(x, y)) + self.assertTrue(model.graph.only_directed_edge_exists(z, y)) + self.assertTrue(model.graph.only_directed_edge_exists(x, a)) + self.assertTrue(model.graph.only_directed_edge_exists(z, a)) + self.assertFalse(model.graph.edge_exists(x, z)) + + def test_collider_test_4(self): + pipeline = [ColliderTest()] + model = graph_model_factory( + Algorithm( + pipeline_steps=pipeline, + edge_types=[], + name="TestCollider", + ) + )() + model.graph = GraphManager() + x = model.graph.add_node("X", [0, 1, 2]) + y = model.graph.add_node("Y", [3, 4, 5]) + z = model.graph.add_node("Z", [6, 7, 8]) + a = model.graph.add_node("A", [9, 10, 11]) + model.graph.add_edge(x, y, {}) + model.graph.add_edge(z, y, {}) + model.graph.add_edge(x, a, {}) + model.graph.add_edge(z, a, {}) + model.graph.add_edge_history( + x, + z, + TestResult( + u=x, + v=z, + 
action=TestResultAction.REMOVE_EDGE_UNDIRECTED, + data={"separatedBy": [a]}, + ), + ) + model.graph.add_edge_history( + y, + a, + TestResult( + u=y, + v=a, + action=TestResultAction.REMOVE_EDGE_UNDIRECTED, + data={"separatedBy": [x, z]}, + ), + ) + model.execute_pipeline_steps() + self.assertTrue(model.graph.only_directed_edge_exists(x, y)) + self.assertTrue(model.graph.only_directed_edge_exists(z, y)) + self.assertTrue(model.graph.edge_exists(x, y)) + self.assertTrue(model.graph.edge_exists(z, y)) + self.assertFalse(model.graph.only_directed_edge_exists(x, a)) + self.assertFalse(model.graph.only_directed_edge_exists(z, a)) + self.assertTrue(model.graph.edge_exists(x, a)) + self.assertTrue(model.graph.edge_exists(z, a)) + self.assertFalse(model.graph.edge_exists(x, z)) def test_collider_test_multiple_orientation_rules(self): pipeline = [ @@ -139,7 +267,7 @@ def test_collider_test_with_nonempty_separation_set(self): model.graph.add_edge(z, y, {}) model.graph.add_edge_history( x, - y, + z, TestResult( u=x, v=z, @@ -147,9 +275,148 @@ def test_collider_test_with_nonempty_separation_set(self): data={"separatedBy": [y]}, ), ) + model.execute_pipeline_steps() + self.assertFalse(model.graph.only_directed_edge_exists(x, y)) + self.assertFalse(model.graph.only_directed_edge_exists(z, y)) self.assertTrue(model.graph.undirected_edge_exists(x, y)) self.assertTrue(model.graph.undirected_edge_exists(y, z)) + def test_collider_multiple_colliders(self): + pipeline = [ + ColliderTest() + ] + model = graph_model_factory( + Algorithm( + pipeline_steps=pipeline, + edge_types=[], + name="TestCollider", + ) + )() + model.graph = GraphManager() + x = model.graph.add_node("X", []) + y = model.graph.add_node("Y", []) + z = model.graph.add_node("Z", []) + a = model.graph.add_node("A", []) + model.graph.add_edge(x, y, {}) + model.graph.add_edge(z, y, {}) + model.graph.add_edge(x, a, {}) + model.graph.add_edge(z, a, {}) + model.graph.add_edge_history( + x, + z, + TestResult( + u=x, + v=z, + 
action=TestResultAction.REMOVE_EDGE_UNDIRECTED, + data={"separatedBy": []}, + ), + ) + model.execute_pipeline_steps() + self.assertTrue(model.graph.only_directed_edge_exists(x, a)) + self.assertTrue(model.graph.only_directed_edge_exists(z, a)) + self.assertTrue(model.graph.only_directed_edge_exists(x, y)) + self.assertTrue(model.graph.only_directed_edge_exists(z, y)) + + def test_collider_prioritize_collider_rules(self): + pipeline = [ + ColliderTest( + conflict_resolution_strategy=ColliderTestConflictResolutionStrategies.KEEP_LAST + ) + ] + model = graph_model_factory( + Algorithm( + pipeline_steps=pipeline, + edge_types=[], + name="TestCollider", + ) + )() + model.graph = GraphManager() + x = model.graph.add_node("X", []) + y = model.graph.add_node("Y", []) + z = model.graph.add_node("Z", []) + a = model.graph.add_node("A", []) + model.graph.add_edge(x, y, {}) + model.graph.add_edge(z, y, {}) + model.graph.add_edge(x, a, {}) + model.graph.add_edge(z, a, {}) + model.graph.remove_directed_edge(a, x) + model.graph.remove_directed_edge(a, z) + model.graph.add_edge_history( + x, + z, + TestResult( + u=x, + v=z, + action=TestResultAction.REMOVE_EDGE_UNDIRECTED, + data={"separatedBy": []}, + ), + ) + model.graph.add_edge_history( + a, + y, + TestResult( + u=a, + v=y, + action=TestResultAction.REMOVE_EDGE_UNDIRECTED, + data={"separatedBy": []}, + ), + ) + model.execute_pipeline_steps() + self.assertTrue(model.graph.edge_exists(y, z)) + self.assertTrue(model.graph.edge_exists(a, x)) + self.assertTrue(model.graph.edge_exists(a, z)) + self.assertTrue(model.graph.edge_exists(x, y)) + + def test_collider_prioritize_collider_rules_2(self): + pipeline = [ + ColliderTest( + conflict_resolution_strategy=ColliderTestConflictResolutionStrategies.KEEP_FIRST + ) + ] + model = graph_model_factory( + Algorithm( + pipeline_steps=pipeline, + edge_types=[], + name="TestCollider", + ) + )() + model.graph = GraphManager() + x = model.graph.add_node("X", []) + y = model.graph.add_node("Y", 
[]) + z = model.graph.add_node("Z", []) + a = model.graph.add_node("A", []) + model.graph.add_edge(x, y, {}) + model.graph.add_edge(z, y, {}) + model.graph.add_edge(x, a, {}) + model.graph.add_edge(z, a, {}) + model.graph.remove_directed_edge(a, x) + model.graph.remove_directed_edge(a, z) + model.graph.add_edge_history( + x, + z, + TestResult( + u=x, + v=z, + action=TestResultAction.REMOVE_EDGE_UNDIRECTED, + data={"separatedBy": []}, + ), + ) + model.graph.add_edge_history( + a, + y, + TestResult( + u=a, + v=y, + action=TestResultAction.REMOVE_EDGE_UNDIRECTED, + data={"separatedBy": []}, + ), + ) + model.execute_pipeline_steps() + self.assertTrue(model.graph.edge_exists(y, z)) + self.assertTrue(model.graph.edge_exists(a, x)) + self.assertTrue(model.graph.edge_exists(a, z)) + self.assertTrue(model.graph.edge_exists(x, y)) + def test_non_collider_test(self): pipeline = [NonColliderTest()] model = graph_model_factory(