Skip to content

Commit

Permalink
Merge pull request #46 from causy-dev/fix-collider-rule
Browse files Browse the repository at this point in the history
Fix collider rule
  • Loading branch information
this-is-sofia authored Jul 6, 2024
2 parents 0ec9210 + 5039a7b commit f04328a
Show file tree
Hide file tree
Showing 3 changed files with 384 additions and 17 deletions.
125 changes: 110 additions & 15 deletions causy/causal_discovery/constraint/orientation_rules/pc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
from typing import Tuple, List, Optional, Generic
import itertools

Expand All @@ -9,7 +10,7 @@
PipelineStepInterfaceType,
)
from causy.models import ComparisonSettings, TestResultAction, TestResult
from causy.variables import IntegerParameter, BoolParameter
from causy.variables import IntegerParameter, BoolParameter, StringParameter


# theory for all orientation rules with pictures:
Expand All @@ -18,6 +19,58 @@
# TODO: refactor ColliderTest -> ColliderRule and move to folder orientation_rules (after checking for duplicates)


def filter_unapplied_actions(actions, u, v):
"""
Filter out actions that have not been applied to the graph yet.
:param actions: list of actions
:param u: node u
:param v: node v
:return: list of actions that have not been applied to the graph yet
"""
filtered = []
for result_set in actions:
if result_set is None:
continue
for result in result_set:
if result.u == u and result.v == v:
filtered.append(result)
return filtered


def generate_restores(unapplied_actions):
"""
Generate restore actions for unapplied actions.
:param unapplied_actions: list of unapplied actions
:param x: node x
:param y: node y
:return: list of restore actions
"""
results = []
for action in unapplied_actions:
if action.action == TestResultAction.REMOVE_EDGE_DIRECTED:
results.append(
TestResult(
u=action.u,
v=action.v,
action=TestResultAction.RESTORE_EDGE_DIRECTED,
data={},
)
)
return results


class ColliderTestConflictResolutionStrategies(enum.StrEnum):
"""
Enum for the conflict resolution strategies for the ColliderTest.
"""

# If a conflict occurs, the edge that was removed first is kept.
KEEP_FIRST = "KEEP_FIRST"

# If a conflict occurs, the edge that was removed last is kept.
KEEP_LAST = "KEEP_LAST"


class ColliderTest(
PipelineStepInterface[PipelineStepInterfaceType], Generic[PipelineStepInterfaceType]
):
Expand All @@ -27,6 +80,10 @@ class ColliderTest(
chunk_size_parallel_processing: IntegerParameter = 1
parallel: BoolParameter = False

conflict_resolution_strategy: StringParameter = (
ColliderTestConflictResolutionStrategies.KEEP_FIRST
)

needs_unapplied_actions: BoolParameter = True

def process(
Expand Down Expand Up @@ -75,20 +132,58 @@ def process(
separators += [a.id for a in action.data["separatedBy"]]

if z.id not in separators:
results += [
TestResult(
u=z,
v=x,
action=TestResultAction.REMOVE_EDGE_DIRECTED,
data={},
),
TestResult(
u=z,
v=y,
action=TestResultAction.REMOVE_EDGE_DIRECTED,
data={},
),
]
unapplied_actions_x_z = filter_unapplied_actions(
unapplied_actions, x, z
)
unapplied_actions_y_z = filter_unapplied_actions(
unapplied_actions, y, z
)
if len(unapplied_actions_y_z) > 0 or len(unapplied_actions_x_z) > 0:
if (
ColliderTestConflictResolutionStrategies.KEEP_FIRST
is self.conflict_resolution_strategy
):
# We keep the first edge that was removed
continue
elif (
ColliderTestConflictResolutionStrategies.KEEP_LAST
is self.conflict_resolution_strategy
):
# We keep the last edge that was removed and restore the other edges
results.extend(generate_restores(unapplied_actions_x_z))
results.extend(generate_restores(unapplied_actions_y_z))
results.append(
TestResult(
u=z,
v=x,
action=TestResultAction.REMOVE_EDGE_DIRECTED,
data={},
)
)
results.append(
TestResult(
u=z,
v=y,
action=TestResultAction.REMOVE_EDGE_DIRECTED,
data={},
)
)

else:
results += [
TestResult(
u=z,
v=x,
action=TestResultAction.REMOVE_EDGE_DIRECTED,
data={},
),
TestResult(
u=z,
v=y,
action=TestResultAction.REMOVE_EDGE_DIRECTED,
data={},
),
]
return results


Expand Down
7 changes: 6 additions & 1 deletion causy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,12 @@ def resolve_variable_to_object(obj: Any, variables):
obj.__dict__[attribute] = variables[value.name]
else:
raise ValueError(f'Variable "{value.name}" not found in the variables.')
elif hasattr(value, "__dict__"):
elif (
hasattr(value, "__dict__")
and not isinstance(value, NoneType)
and not hasattr(value, "value")
):
# we check for value because we don't want to resolve the variable if it's a variable object itself or a Enum
obj.__dict__[attribute] = resolve_variable_to_object(value, variables)
return obj

Expand Down
Loading

0 comments on commit f04328a

Please sign in to comment.