diff --git a/graph.py b/graph.py index 2cd8bb7..ae15e09 100644 --- a/graph.py +++ b/graph.py @@ -11,12 +11,13 @@ IndependenceTestInterface, PartialCorrelationTest, CalculateCorrelations, - ExtendedPartialCorrelationTest, - UnshieldedTriplesTest, + ExtendedPartialCorrelationTestMatrix, PlaceholderTest, - ExtendedPartialCorrelationTest2, + ExtendedPartialCorrelationTestLinearRegression, ) +from orientation_tests import UnshieldedTripleColliderTest + from interfaces import ( BaseGraphInterface, NodeInterface, @@ -173,6 +174,22 @@ def edge_exists(self, u: Node, v: Node): return False if v not in self.edges: return False + if u not in self.edges[v]: + return False + if v not in self.edges[u]: + return False + return True + + def directed_edge_exists(self, u: Node, v: Node): + if u.name not in self.nodes: + return False + if v.name not in self.nodes: + return False + if u not in self.edges: + return False + if v not in self.edges[u]: + return False + return True def edge_value(self, u: Node, v: Node): return self.edges[u][v] @@ -385,8 +402,8 @@ def __init__( CalculateCorrelations(), CorrelationCoefficientTest(threshold=0.1), PartialCorrelationTest(threshold=0.1), - ExtendedPartialCorrelationTest2(threshold=0.1), - UnshieldedTriplesTest(), + ExtendedPartialCorrelationTestMatrix(threshold=0.1), + UnshieldedTripleColliderTest(), # check replacing it with a loop of ExtendedPartialCorrelationTest # Loop( # pipeline_steps=[ diff --git a/independence_tests.py b/independence_tests.py index d8c814a..ae54ec7 100644 --- a/independence_tests.py +++ b/independence_tests.py @@ -1,5 +1,5 @@ from statistics import correlation, covariance # , linear_regression -from typing import Tuple, List, Optional +from typing import Tuple, List import math # Use cupy for GPU support - if available - otherwise use numpy @@ -100,7 +100,7 @@ def test( self, nodes: Tuple[str], graph: BaseGraphInterface ) -> CorrelationTestResult: """ - Test if nodes x,y are independent with z as conditioning variable 
based on a partial correlation test. + Test if nodes x,y are independent given node z based on a partial correlation test. We use this test for all combinations of 3 nodes because it is faster than the extended test (which supports combinations of n nodes). We can use it to remove edges between nodes which are not independent given another node and so reduce the number of combinations for the extended test. :param nodes: the nodes to test @@ -156,7 +156,7 @@ def test( ) -class ExtendedPartialCorrelationTest(IndependenceTestInterface): +class ExtendedPartialCorrelationTestLinearRegression(IndependenceTestInterface): NUM_OF_COMPARISON_ELEMENTS = ComparisonSettings(min=5, max=AS_MANY_AS_FIELDS) CHUNK_SIZE_PARALLEL_PROCESSING = 1 PARALLEL = True @@ -165,9 +165,10 @@ def test( self, nodes: List[str], graph: BaseGraphInterface ) -> CorrelationTestResult: """ - Test if nodes x,y are independent with Z as a set of conditioning variables based on partial correlation using linear regression and a correlation test on the residuals. + Test if nodes x,y are independent given Z (set of nodes) based on partial correlation using linear regression and a correlation test on the residuals. We use this test for all combinations of more than 3 nodes because it is slower. 
- + :param nodes: the nodes to test + :return: A CorrelationTestResult with the action to take """ n = len(nodes) sample_size = len(graph.nodes[nodes[0]].values) @@ -205,58 +206,19 @@ def test( return results -class UnshieldedTriplesTest(IndependenceTestInterface): - NUM_OF_COMPARISON_ELEMENTS = 2 - CHUNK_SIZE_PARALLEL_PROCESSING = 1 - - def test( - self, nodes: Tuple[str], graph: BaseGraphInterface - ) -> List[CorrelationTestResult] | CorrelationTestResult: - # https://github.com/pgmpy/pgmpy/blob/1fe10598df5430295a8fc5cdca85cf2d9e1c4330/pgmpy/estimators/PC.py#L416 - - x = graph.nodes[nodes[0]] - y = graph.nodes[nodes[1]] - - if graph.edge_exists(x, y): - return CorrelationTestResult( - x=x, y=y, action=CorrelationTestResultAction.DO_NOTHING, data={} - ) - - potential_zs = set(graph.edges[x].keys()).intersection( - set(graph.edges[y].keys()) - ) - - for z in potential_zs: - separators = graph.retrieve_edge_history( - x, y, CorrelationTestResultAction.REMOVE_EDGE_UNDIRECTED - ) - - if z not in separators: - return [ - CorrelationTestResult( - x=z, - y=x, - action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, - data={}, - ), - CorrelationTestResult( - x=z, - y=y, - action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, - data={}, - ), - ] - - return CorrelationTestResult( - x=x, y=y, action=CorrelationTestResultAction.DO_NOTHING, data={} - ) - - -class ExtendedPartialCorrelationTest2(IndependenceTestInterface): +class ExtendedPartialCorrelationTestMatrix(IndependenceTestInterface): NUM_OF_COMPARISON_ELEMENTS = ComparisonSettings(min=4, max=AS_MANY_AS_FIELDS) CHUNK_SIZE_PARALLEL_PROCESSING = 1 PARALLEL = True + """ + Test if nodes x,y are independent given Z (set of nodes) based on partial correlation using the inverted covariance matrix (precision matrix). + https://en.wikipedia.org/wiki/Partial_correlation#Using_matrix_inversion + We use this test for all combinations of more than 3 nodes because it is slower. 
+ :param nodes: the nodes to test + :return: A CorrelationTestResult with the action to take + """ + def test( self, nodes: List[str], graph: BaseGraphInterface ) -> CorrelationTestResult: diff --git a/orientation_tests.py b/orientation_tests.py new file mode 100644 index 0000000..dfce64f --- /dev/null +++ b/orientation_tests.py @@ -0,0 +1,224 @@ +from typing import Tuple, List +import itertools + +from interfaces import BaseGraphInterface, CorrelationTestResult, CorrelationTestResultAction, IndependenceTestInterface + + +class UnshieldedTripleColliderTest(IndependenceTestInterface): + NUM_OF_COMPARISON_ELEMENTS = 2 + CHUNK_SIZE_PARALLEL_PROCESSING = 1 + + def test( + self, nodes: Tuple[str], graph: BaseGraphInterface + ) -> List[CorrelationTestResult] | CorrelationTestResult: + """ + For all nodes x and y that are not adjacent but share an adjacent node z, we check if z is in the separating set. + If z is not in the separating set, x and y were separated without conditioning on z, so z must be a collider: the edges are oriented from x to z and from y to z. 
+ :param nodes: list of nodes + :param graph: the current graph + :returns: list of actions that will be executed on graph + """ + # https://github.com/pgmpy/pgmpy/blob/1fe10598df5430295a8fc5cdca85cf2d9e1c4330/pgmpy/estimators/PC.py#L416 + + x = graph.nodes[nodes[0]] + y = graph.nodes[nodes[1]] + + # if x and y are adjacent, do nothing + if graph.edge_exists(x, y): + return CorrelationTestResult( + x=x, y=y, action=CorrelationTestResultAction.DO_NOTHING, data={} + ) + + # if x and y are NOT adjacent, store all shared adjacent nodes + potential_zs = set(graph.edges[x].keys()).intersection( + set(graph.edges[y].keys()) + ) + + # if z is not in the separating set of x and y, record the actions that make z a collider + results = [] + for z in potential_zs: + separators = graph.retrieve_edge_history( + x, y, CorrelationTestResultAction.REMOVE_EDGE_UNDIRECTED + ) + + if z not in separators: + results += [ + CorrelationTestResult( + x=z, + y=x, + action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + ), + CorrelationTestResult( + x=z, + y=y, + action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + ), + ] + return results + + + +class UnshieldedTripleNonColliderTest(IndependenceTestInterface): + NUM_OF_COMPARISON_ELEMENTS = 2 + CHUNK_SIZE_PARALLEL_PROCESSING = 1 + + def test( + self, nodes: Tuple[str], graph: BaseGraphInterface + ) -> List[CorrelationTestResult] | CorrelationTestResult: + """ + Further orientation rule: for non-adjacent x and y, check every shared adjacent node z; if there is an edge from x to z and an edge between y and z in both directions, a REMOVE_EDGE_DIRECTED action with x=z, y=y is emitted. 
+ :param nodes: list of nodes + :param graph: the current graph + :returns: list of actions that will be executed on graph + """ + + x = graph.nodes[nodes[0]] + y = graph.nodes[nodes[1]] + + # if x and y are adjacent, do nothing + if graph.edge_exists(x, y): + return CorrelationTestResult( + x=x, y=y, action=CorrelationTestResultAction.DO_NOTHING, data={} + ) + + # if x and y are NOT adjacent, store all shared adjacent nodes + potential_zs = set(graph.edges[x].keys()).intersection( + set(graph.edges[y].keys()) + ) + results = [] + for z in potential_zs: + if graph.directed_edge_exists(x, z): + if graph.edge_exists(y, z): + results.append(CorrelationTestResult( + x=z, + y=y, + action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + )) + return results + + + +class FurtherOrientTripleTest(IndependenceTestInterface): + NUM_OF_COMPARISON_ELEMENTS = 2 + CHUNK_SIZE_PARALLEL_PROCESSING = 1 + + def test( + self, nodes: Tuple[str], graph: BaseGraphInterface + ) -> List[CorrelationTestResult] | CorrelationTestResult: + """ + Further orientation rule: for adjacent x and y, check every shared adjacent node z; if there are edges from x to z and from z to y, a REMOVE_EDGE_DIRECTED action with x=y, y=x is emitted, and if there are edges from y to z and from z to x, one with x=x, y=y is emitted. 
+ :param nodes: list of nodes + :param graph: the current graph + :returns: list of actions that will be executed on graph + """ + + x = graph.nodes[nodes[0]] + y = graph.nodes[nodes[1]] + + potential_zs = set(graph.edges[x].keys()).intersection( + set(graph.edges[y].keys()) + ) + + results = [] + for z in potential_zs: + if graph.edge_exists(x,y) and graph.directed_edge_exists(x,z) and graph.directed_edge_exists(z,y): + results.append(CorrelationTestResult( + x=y, + y=x, + action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + )) + if graph.edge_exists(x,y) and graph.directed_edge_exists(y,z) and graph.directed_edge_exists(z,x): + results.append(CorrelationTestResult( + x=x, + y=y, + action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + )) + return results + + +class OrientQuadrupleTest(IndependenceTestInterface): + NUM_OF_COMPARISON_ELEMENTS = 2 + CHUNK_SIZE_PARALLEL_PROCESSING = 1 + + def test( + self, nodes: Tuple[str], graph: BaseGraphInterface + ) -> List[CorrelationTestResult] | CorrelationTestResult: + """ + Further orientation rule: for adjacent x and y, check every shared adjacent node z; if there are edges from x to z and from z to y, a REMOVE_EDGE_DIRECTED action with x=y, y=x is emitted, and if there are edges from y to z and from z to x, one with x=x, y=y is emitted. NOTE(review): a second class named OrientQuadrupleTest is defined later in this module and shadows this definition. 
+ :param nodes: list of nodes + :param graph: the current graph + :returns: list of actions that will be executed on graph + """ + + x = graph.nodes[nodes[0]] + y = graph.nodes[nodes[1]] + + potential_zs = set(graph.edges[x].keys()).intersection( + set(graph.edges[y].keys()) + ) + + results = [] + for z in potential_zs: + if graph.edge_exists(x,y) and graph.directed_edge_exists(x,z) and graph.directed_edge_exists(z,y): + results.append(CorrelationTestResult( + x=y, + y=x, + action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + )) + if graph.edge_exists(x,y) and graph.directed_edge_exists(y,z) and graph.directed_edge_exists(z,x): + results.append(CorrelationTestResult( + x=x, + y=y, + action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + )) + return results + +class OrientQuadrupleTest(IndependenceTestInterface): + NUM_OF_COMPARISON_ELEMENTS = 2 + CHUNK_SIZE_PARALLEL_PROCESSING = 1 + + def test( + self, nodes: Tuple[str], graph: BaseGraphInterface + ) -> List[CorrelationTestResult] | CorrelationTestResult: + """ + Further orientation rule: for non-adjacent x and y, check every pair (z, w) of shared adjacent nodes; if there are edges from x to z and from y to z, and w is connected to both x and y in both directions, a REMOVE_EDGE_DIRECTED action with x=z, y=w is emitted, and the same check is repeated with z and w swapped. 
+ :param nodes: list of nodes + :param graph: the current graph + :returns: list of actions that will be executed on graph + """ + + x = graph.nodes[nodes[0]] + y = graph.nodes[nodes[1]] + + potential_zs = set(graph.edges[x].keys()).intersection( + set(graph.edges[y].keys()) + ) + + results = [] + for zs in itertools.combinations(potential_zs, 2): + z = zs[0] + w = zs[1] + if not graph.edge_exists(x,y) and graph.directed_edge_exists(x,z) and graph.directed_edge_exists(y,z) \ + and graph.edge_exists(x,w) and graph.edge_exists(y,w): + results.append(CorrelationTestResult( + x=z, + y=w, + action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + )) + if not graph.edge_exists(x,y) and graph.directed_edge_exists(x,w) and graph.directed_edge_exists(y,w) \ + and graph.edge_exists(x,z) and graph.edge_exists(y,z): + results.append(CorrelationTestResult( + x=w, + y=z, + action=CorrelationTestResultAction.REMOVE_EDGE_DIRECTED, + data={}, + )) + return results \ No newline at end of file diff --git a/pipelines/pc.json b/pipelines/pc.json index 5718537..7acf2d3 100644 --- a/pipelines/pc.json +++ b/pipelines/pc.json @@ -19,13 +19,28 @@ } }, { - "step": "independence_tests.ExtendedPartialCorrelationTest2", + "step": "independence_tests.ExtendedPartialCorrelationTestMatrix", "params": { "threshold": 0.01 } }, { - "step": "independence_tests.UnshieldedTriplesTest", + "step": "orientation_tests.UnshieldedTripleColliderTest", + "params": { + } + }, + { + "step": "orientation_tests.UnshieldedTripleNonColliderTest", + "params": { + } + }, + { + "step": "orientation_tests.FurtherOrientTripleTest", + "params": { + } + }, + { + "step": "orientation_tests.OrientQuadrupleTest", "params": { } }