Skip to content

Commit

Permalink
feat(orientation_tests): add orientation tests UnshieldedTripleCollid…
Browse files Browse the repository at this point in the history
…erTest, UnshieldedTripleNonColliderTest, FurtherOrientTripleTest, OrientQuadrupleTest
  • Loading branch information
this-is-sofia committed Oct 1, 2023
1 parent bf1831a commit 5b2b2b0
Show file tree
Hide file tree
Showing 4 changed files with 278 additions and 60 deletions.
27 changes: 22 additions & 5 deletions graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
IndependenceTestInterface,
PartialCorrelationTest,
CalculateCorrelations,
ExtendedPartialCorrelationTest,
UnshieldedTriplesTest,
ExtendedPartialCorrelationTestMatrix,
PlaceholderTest,
ExtendedPartialCorrelationTest2,
ExtendedPartialCorrelationTestLinearRegression,
)

from orientation_tests import UnshieldedTripleColliderTest

from interfaces import (
BaseGraphInterface,
NodeInterface,
Expand Down Expand Up @@ -173,6 +174,22 @@ def edge_exists(self, u: Node, v: Node):
return False
if v not in self.edges:
return False
if u not in self.edges[v]:
return False
if v not in self.edges[u]:
return False
return True

def directed_edge_exists(self, u: Node, v: Node):
    """
    Check whether the graph stores an edge pointing from u to v.
    Only the direction u -> v is looked up; nothing is said about the reverse direction v -> u.
    :param u: the node the edge starts at
    :param v: the node the edge points to
    :return: True if both node names are known to the graph and v is listed under the edges of u, otherwise False
    """
    both_known = u.name in self.nodes and v.name in self.nodes
    if not both_known:
        return False
    # check for u's own entry first so a node without any stored edges cannot fail the lookup
    return u in self.edges and v in self.edges[u]

def edge_value(self, u: Node, v: Node):
    """
    Return whatever is stored for the edge from u to v.
    No existence check is done here, so a missing entry surfaces as the lookup error raised by self.edges.
    :param u: the node the edge starts at
    :param v: the node the edge points to
    :return: the value stored under self.edges[u][v]
    """
    outgoing = self.edges[u]
    return outgoing[v]
Expand Down Expand Up @@ -385,8 +402,8 @@ def __init__(
CalculateCorrelations(),
CorrelationCoefficientTest(threshold=0.1),
PartialCorrelationTest(threshold=0.1),
ExtendedPartialCorrelationTest2(threshold=0.1),
UnshieldedTriplesTest(),
ExtendedPartialCorrelationTestMatrix(threshold=0.1),
UnshieldedTripleColliderTest(),
# check replacing it with a loop of ExtendedPartialCorrelationTest
# Loop(
# pipeline_steps=[
Expand Down
68 changes: 15 additions & 53 deletions independence_tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from statistics import correlation, covariance # , linear_regression
from typing import Tuple, List, Optional
from typing import Tuple, List
import math

# Use cupy for GPU support - if available - otherwise use numpy
Expand Down Expand Up @@ -100,7 +100,7 @@ def test(
self, nodes: Tuple[str], graph: BaseGraphInterface
) -> CorrelationTestResult:
"""
Test if nodes x,y are independent with z as conditioning variable based on a partial correlation test.
Test if nodes x,y are independent given node z based on a partial correlation test.
We use this test for all combinations of 3 nodes because it is faster than the extended test (which supports combinations of n nodes). We can
use it to remove edges between nodes which are not independent given another node and so reduce the number of combinations for the extended test.
:param nodes: the nodes to test
Expand Down Expand Up @@ -156,7 +156,7 @@ def test(
)


class ExtendedPartialCorrelationTest(IndependenceTestInterface):
class ExtendedPartialCorrelationTestLinearRegression(IndependenceTestInterface):
NUM_OF_COMPARISON_ELEMENTS = ComparisonSettings(min=5, max=AS_MANY_AS_FIELDS)
CHUNK_SIZE_PARALLEL_PROCESSING = 1
PARALLEL = True
Expand All @@ -165,9 +165,10 @@ def test(
self, nodes: List[str], graph: BaseGraphInterface
) -> CorrelationTestResult:
"""
Test if nodes x,y are independent with Z as a set of conditioning variables based on partial correlation using linear regression and a correlation test on the residuals.
Test if nodes x,y are independent given Z (set of nodes) based on partial correlation using linear regression and a correlation test on the residuals.
We use this test for all combinations of more than 3 nodes because it is slower.
:param nodes: the nodes to test
:return: A CorrelationTestResult with the action to take
"""
n = len(nodes)
sample_size = len(graph.nodes[nodes[0]].values)
Expand Down Expand Up @@ -205,58 +206,19 @@ def test(
return results


class UnshieldedTriplesTest(IndependenceTestInterface):
    NUM_OF_COMPARISON_ELEMENTS = 2
    CHUNK_SIZE_PARALLEL_PROCESSING = 1

    def test(
        self, nodes: Tuple[str], graph: BaseGraphInterface
    ) -> List[CorrelationTestResult] | CorrelationTestResult:
        """
        Look for a node z that is adjacent to both x and y while x and y themselves are not connected.
        The first such z that is not part of the history of the removed undirected edge between x and y
        leads to two actions which remove the directed edges z -> x and z -> y.
        :param nodes: the names of the two nodes x and y
        :param graph: the current graph
        :return: a list with two REMOVE_EDGE_DIRECTED results for the first matching z, otherwise a single DO_NOTHING result
        """
        # https://github.com/pgmpy/pgmpy/blob/1fe10598df5430295a8fc5cdca85cf2d9e1c4330/pgmpy/estimators/PC.py#L416

        first, second = graph.nodes[nodes[0]], graph.nodes[nodes[1]]

        if not graph.edge_exists(first, second):
            # nodes which have an edge to both endpoints
            shared = set(graph.edges[first].keys()) & set(graph.edges[second].keys())

            for candidate in shared:
                history = graph.retrieve_edge_history(
                    first, second, CorrelationTestResultAction.REMOVE_EDGE_UNDIRECTED
                )
                if candidate in history:
                    continue

                # stop at the first candidate which is missing from the history
                return [
                    CorrelationTestResult(
                        x=candidate,
                        y=endpoint,
                        action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED,
                        data={},
                    )
                    for endpoint in (first, second)
                ]

        # either x and y are connected or no candidate matched
        return CorrelationTestResult(
            x=first,
            y=second,
            action=CorrelationTestResultAction.DO_NOTHING,
            data={},
        )


class ExtendedPartialCorrelationTest2(IndependenceTestInterface):
class ExtendedPartialCorrelationTestMatrix(IndependenceTestInterface):
NUM_OF_COMPARISON_ELEMENTS = ComparisonSettings(min=4, max=AS_MANY_AS_FIELDS)
CHUNK_SIZE_PARALLEL_PROCESSING = 1
PARALLEL = True

"""
Test if nodes x,y are independent given Z (set of nodes) based on partial correlation using the inverted covariance matrix (precision matrix).
https://en.wikipedia.org/wiki/Partial_correlation#Using_matrix_inversion
We use this test for all combinations of more than 3 nodes because it is slower.
:param nodes: the nodes to test
:return: A CorrelationTestResult with the action to take
"""

def test(
self, nodes: List[str], graph: BaseGraphInterface
) -> CorrelationTestResult:
Expand Down
224 changes: 224 additions & 0 deletions orientation_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from typing import Tuple, List
import itertools

from interfaces import BaseGraphInterface, CorrelationTestResult, CorrelationTestResultAction, IndependenceTestInterface


class UnshieldedTripleColliderTest(IndependenceTestInterface):
    NUM_OF_COMPARISON_ELEMENTS = 2
    CHUNK_SIZE_PARALLEL_PROCESSING = 1

    def test(
        self, nodes: Tuple[str], graph: BaseGraphInterface
    ) -> List[CorrelationTestResult] | CorrelationTestResult:
        """
        Orient unshielded triples x - z - y as colliders x -> z <- y.
        For all nodes x and y that are not adjacent but share an adjacent node z, we check if z is in the
        separating set of x and y (looked up in the history of the removed undirected edge between x and y).
        If z is not in the separating set, x and y did not become independent by conditioning on z, so z
        has to be a collider: the edges are oriented from x to z and from y to z by removing the directed
        edges z -> x and z -> y.
        :param nodes: the names of the two nodes x and y
        :param graph: the current graph
        :returns: a single DO_NOTHING result if x and y are adjacent, otherwise a (possibly empty) list of actions that will be executed on graph
        """
        # https://github.com/pgmpy/pgmpy/blob/1fe10598df5430295a8fc5cdca85cf2d9e1c4330/pgmpy/estimators/PC.py#L416

        x = graph.nodes[nodes[0]]
        y = graph.nodes[nodes[1]]

        # if x and y are adjacent, do nothing
        if graph.edge_exists(x, y):
            return CorrelationTestResult(
                x=x, y=y, action=CorrelationTestResultAction.DO_NOTHING, data={}
            )

        # if x and y are NOT adjacent, store all shared adjacent nodes
        potential_zs = set(graph.edges[x].keys()).intersection(
            set(graph.edges[y].keys())
        )
        if not potential_zs:
            return []

        # the history only depends on x and y, so it is looked up once instead of once per z
        # NOTE(review): assumes retrieve_edge_history only reads from the graph - confirm
        separators = graph.retrieve_edge_history(
            x, y, CorrelationTestResultAction.REMOVE_EDGE_UNDIRECTED
        )

        # if z did not separate x and y, save the actions which make z a collider
        results = []
        for z in potential_zs:
            if z in separators:
                continue
            results.extend(
                CorrelationTestResult(
                    x=z,
                    y=parent,
                    action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED,
                    data={},
                )
                for parent in (x, y)
            )
        return results



class UnshieldedTripleNonColliderTest(IndependenceTestInterface):
    NUM_OF_COMPARISON_ELEMENTS = 2
    CHUNK_SIZE_PARALLEL_PROCESSING = 1

    def test(
        self, nodes: Tuple[str], graph: BaseGraphInterface
    ) -> List[CorrelationTestResult] | CorrelationTestResult:
        """
        Further orientation rule for unshielded triples whose middle node is not a collider.
        For all nodes x and y that are not adjacent but share an adjacent node z: if the edge between x and z
        is already oriented from x to z and the edge between z and y is still undirected, the edge is oriented
        from z to y by removing the directed edge y -> z. Orienting it the other way round would create a new
        collider x -> z <- y.
        :param nodes: the names of the two nodes x and y
        :param graph: the current graph
        :returns: a single DO_NOTHING result if x and y are adjacent, otherwise a (possibly empty) list of actions that will be executed on graph
        """

        x = graph.nodes[nodes[0]]
        y = graph.nodes[nodes[1]]

        # if x and y are adjacent, do nothing
        # an edge in only one direction makes them adjacent as well, edge_exists alone would miss it
        if graph.directed_edge_exists(x, y) or graph.directed_edge_exists(y, x):
            return CorrelationTestResult(
                x=x, y=y, action=CorrelationTestResultAction.DO_NOTHING, data={}
            )

        # if x and y are NOT adjacent, store all shared adjacent nodes
        potential_zs = set(graph.edges[x].keys()).intersection(
            set(graph.edges[y].keys())
        )

        results = []
        for z in potential_zs:
            # directed_edge_exists is also True for undirected edges, so the reverse direction has to be missing
            x_points_to_z = graph.directed_edge_exists(
                x, z
            ) and not graph.directed_edge_exists(z, x)

            # edge_exists requires both directions, i.e. z - y is not oriented yet
            if x_points_to_z and graph.edge_exists(z, y):
                # remove y -> z so that only z -> y is left
                results.append(
                    CorrelationTestResult(
                        x=y,
                        y=z,
                        action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED,
                        data={},
                    )
                )
        return results



class FurtherOrientTripleTest(IndependenceTestInterface):
    NUM_OF_COMPARISON_ELEMENTS = 2
    CHUNK_SIZE_PARALLEL_PROCESSING = 1

    def test(
        self, nodes: Tuple[str], graph: BaseGraphInterface
    ) -> List[CorrelationTestResult] | CorrelationTestResult:
        """
        Further orientation rule which avoids directed cycles.
        If x and y are connected by an undirected edge and there is a directed path x -> z -> y, the edge is
        oriented from x to y by removing the directed edge y -> x (and the other way round for a directed
        path y -> z -> x).
        :param nodes: the names of the two nodes x and y
        :param graph: the current graph
        :returns: list of actions that will be executed on graph (empty if the rule does not apply)
        """

        x = graph.nodes[nodes[0]]
        y = graph.nodes[nodes[1]]

        # the rule only applies to an edge between x and y which is not oriented yet
        if not graph.edge_exists(x, y):
            return []

        def points_to(u, v):
            # directed_edge_exists is also True for undirected edges, so the reverse direction has to be missing
            return graph.directed_edge_exists(u, v) and not graph.directed_edge_exists(
                v, u
            )

        # z has to be reachable from x (x -> z -> y) or from y (y -> z -> x)
        # the intersection would drop every z with a strictly directed edge into x or y
        potential_zs = set(graph.edges[x].keys()).union(set(graph.edges[y].keys()))

        orient_x_to_y = any(
            points_to(x, z) and points_to(z, y) for z in potential_zs
        )
        orient_y_to_x = any(
            points_to(y, z) and points_to(z, x) for z in potential_zs
        )

        # neither path exists, or both exist and contradict each other: keep the edge undirected
        # removing both directions would delete the edge between x and y
        if orient_x_to_y == orient_y_to_x:
            return []

        # one action is enough, no matter how many nodes z support the orientation
        removed_from, removed_to = (y, x) if orient_x_to_y else (x, y)
        return [
            CorrelationTestResult(
                x=removed_from,
                y=removed_to,
                action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED,
                data={},
            )
        ]


class OrientQuadrupleTest(IndependenceTestInterface):
    NUM_OF_COMPARISON_ELEMENTS = 2
    CHUNK_SIZE_PARALLEL_PROCESSING = 1

    def test(
        self, nodes: Tuple[str], graph: BaseGraphInterface
    ) -> List[CorrelationTestResult] | CorrelationTestResult:
        """
        Further orientation rule on four nodes.
        For nodes x and y that are not adjacent and two shared adjacent nodes z and w: if x -> z and y -> z are
        already oriented while x - w, y - w and z - w are still undirected, the edge between w and z is oriented
        from w to z by removing the directed edge z -> w.
        :param nodes: the names of the two nodes x and y
        :param graph: the current graph
        :returns: list of actions that will be executed on graph (empty if the rule does not apply)
        """

        x = graph.nodes[nodes[0]]
        y = graph.nodes[nodes[1]]

        # the rule requires x and y to be non adjacent, an edge in only one direction counts as adjacent as well
        if graph.directed_edge_exists(x, y) or graph.directed_edge_exists(y, x):
            return []

        def points_to(u, v):
            # directed_edge_exists is also True for undirected edges, so the reverse direction has to be missing
            return graph.directed_edge_exists(u, v) and not graph.directed_edge_exists(
                v, u
            )

        potential_zs = set(graph.edges[x].keys()).intersection(
            set(graph.edges[y].keys())
        )

        results = []
        # permutations yields every pair in both orders, so z and w each get to play both roles
        for z, w in itertools.permutations(potential_zs, 2):
            if not (points_to(x, z) and points_to(y, z)):
                continue
            if not (graph.edge_exists(x, w) and graph.edge_exists(y, w)):
                continue
            # without an undirected edge between z and w there is nothing to orient
            if not graph.edge_exists(z, w):
                continue
            results.append(
                CorrelationTestResult(
                    x=z,
                    y=w,
                    action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED,
                    data={},
                )
            )
        return results
Loading

0 comments on commit 5b2b2b0

Please sign in to comment.