From 0fe7866534ea06cb2dc47b02247bc68d2e039182 Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Sun, 7 Apr 2024 15:02:46 -0400 Subject: [PATCH 1/6] partite all nodes into pass or not pass the AndersonDarling Test --- .../edu/cmu/tetrad/search/MarkovCheck.java | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index 11893865e8..57ebb3ae4c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -251,6 +251,27 @@ public Double checkAgainstAndersonDarlingTest(List pValues) { return generalAndersonDarlingTest.getP(); } + public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(IndependenceTest independenceTest, Graph graph, Double threshold) { + // when calling, default reject null as <=0.05 + List> accepts_rejects = new ArrayList<>(); + List accepts = new ArrayList<>(); + List rejects = new ArrayList<>(); + List allNodes = graph.getNodes(); + for (Node x : allNodes) { + List localIndependenceFacts = getLocalIndependenceFacts(x); + List localPValues = getLocalPValues(independenceTest, localIndependenceFacts); + Double ADTest = checkAgainstAndersonDarlingTest(localPValues); + if (ADTest <= threshold) { + rejects.add(x); + } else { + accepts.add(x); + } + } + accepts_rejects.add(accepts); + accepts_rejects.add(rejects); + return accepts_rejects; + } + /** * Returns the variables of the independence test. From 3443c1d5ec1cb82de000e86909583a380bc94a32 Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Sun, 7 Apr 2024 18:23:43 -0400 Subject: [PATCH 2/6] get markov banket for each node and calculate against true graph for precission and recall --- .../java/edu/cmu/tetrad/graph/GraphNode.java | 2 +- .../java/edu/cmu/tetrad/graph/GraphUtils.java | 4 +- .../edu/cmu/tetrad/search/MarkovCheck.java | 28 +++++++++++-- .../edu/cmu/tetrad/test/TestCheckMarkov.java | 39 +++++++++++++++++++ 4 files changed, 66 insertions(+), 7 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphNode.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphNode.java index d621dd50b1..b624c4a66b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphNode.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphNode.java @@ -47,7 +47,7 @@ public class GraphNode implements Node { /** * The name of the node. */ - private String name = "??"; + private String name = "??"; // TODO VBC: can we change this into a seriel number /** * The type of the node. * diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java index 4074fc28d2..43bae6d55b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java @@ -2253,7 +2253,7 @@ public static Graph trimGraph(List targets, Graph graph, int trimmingStyle graph = trimAdjacentToTarget(targets, graph); break; case 3: - graph = trimMarkovBlanketGraph(targets, graph); + graph = trimMarkovBlanketGraph(targets, graph); // TODO VBC this is what i want to use break; case 4: graph = trimSemidirected(targets, graph); @@ -2298,7 +2298,7 @@ private static Graph trimAdjacentToTarget(List targets, Graph graph) { * @param graph the original graph from which the Markov blanket graph is derived * @return the trimmed Markov blanket graph */ - private static Graph trimMarkovBlanketGraph(List targets, Graph graph) { + private static Graph trimMarkovBlanketGraph(List targets, Graph graph) { // TODO vbc this is Graph mbDag = new EdgeListGraph(graph); M: diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index 57ebb3ae4c..e67efc88f6 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -2,10 +2,7 @@ import edu.cmu.tetrad.data.GeneralAndersonDarlingTest; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphUtils; -import edu.cmu.tetrad.graph.IndependenceFact; -import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.test.*; import edu.cmu.tetrad.util.SublistGenerator; import edu.cmu.tetrad.util.TetradLogger; @@ -272,6 +269,29 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind return accepts_rejects; } + public Double getPrecisionOrRecallOnMarkovBlanketGraph(Node x, Graph estimatedGraph, Graph trueGraph, boolean getPrecision) { + List singleNode = Arrays.asList(x); + String nodeName = x.getName(); + Node originalNode = trueGraph.getNode(nodeName); + List singleOriginalNode = Arrays.asList(originalNode); + + // Get Markov Blanket Subgraph for this node x. + Graph xMBGraphForEstimatedGraph = GraphUtils.trimGraph(singleNode, estimatedGraph, 3); + + Graph xMBGraphForTrueGraph = GraphUtils.trimGraph(singleOriginalNode, trueGraph, 3); // TODO VBC, this is always 0 because singleNode is not in trueGraph! + // TODO VBC: is using name to find x's corresponding node in trueGraph the right way? + Set xMBGraphForEstimatedGraphEdges = xMBGraphForEstimatedGraph.getEdges(); // TODO: this is often 435 for 30 nodes + Set xMBGraphForTrueGraphEdges = xMBGraphForTrueGraph.getEdges(); // TODO: Here the output is often 0/1 + System.out.println("xMBGraphForTrueGraphEdges size: " + xMBGraphForTrueGraphEdges.size()); // TODO VBC + System.out.println("xMBGraphTrue Nodes size: " + xMBGraphForTrueGraph.getNodes().size()); // TODO VBC this is always 0. + + HashSet truePositive = new HashSet<>(xMBGraphForEstimatedGraphEdges); + truePositive.retainAll(xMBGraphForTrueGraphEdges); + double precision = (double) truePositive.size() / xMBGraphForTrueGraphEdges.size(); + double recall = (double) truePositive.size() / xMBGraphForEstimatedGraphEdges.size(); + return getPrecision ? precision : recall; + } + /** * Returns the variables of the independence test. diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java index 0d2dbb54f1..4f92cd098e 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java @@ -16,6 +16,7 @@ import edu.cmu.tetrad.util.NumberFormatUtil; import org.junit.Test; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -110,4 +111,42 @@ public void test2() { System.out.println(markovCheck.getMarkovCheckRecordString()); } + + @Test + public void testPrecissionRecallForLocal() { + Graph dag = RandomGraph.randomDag(30, 0, 10, 100, 100, + 100, false); // truegraph + SemPm pm = new SemPm(dag); + SemIm im = new SemIm(pm); + DataSet data = im.simulateData(500, false); + SemBicScore score = new SemBicScore(data, true); + PermutationSearch search = new PermutationSearch(new Boss(score)); + Graph cpdag = search.search(); // estimatedgraph + IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); + MarkovCheck markovCheck = new MarkovCheck(cpdag, fisherZTest, ConditioningSetType.LOCAL_MARKOV); + + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, cpdag, 0.05); + List accepts = accepts_rejects.get(0); + List rejects = accepts_rejects.get(1); + System.out.println("Accepts size: " + accepts.size()); + System.out.println("Rejects size: " + rejects.size()); + System.out.println("Estimated Graph size: " + cpdag.getNodes().size()); + System.out.println("True Graph size: " + dag.getNodes().size()); + + + + + List acceptsPrecision = new ArrayList<>(); + List acceptsRecall = new ArrayList<>(); + for(Node a: accepts) { + double precision = markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(a, cpdag, dag, true); + double recall = markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(a, cpdag, dag, false); + acceptsPrecision.add(precision); + acceptsRecall.add(recall); + } + System.out.println("Accepts Precissions: " + acceptsPrecision); + System.out.println("Accepts Recall: " + acceptsRecall); + + + } } From 2355c6ae961b1f13f9ca43a9d8f3ab0582afd2f8 Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Mon, 8 Apr 2024 21:11:34 -0400 Subject: [PATCH 3/6] Accepts and Recalls with process prints --- .../java/edu/cmu/tetrad/graph/GraphUtils.java | 4 +- .../edu/cmu/tetrad/search/MarkovCheck.java | 44 ++++++++++++------- .../edu/cmu/tetrad/test/TestCheckMarkov.java | 40 +++++++++-------- 3 files changed, 51 insertions(+), 37 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java index 43bae6d55b..56de316a58 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java @@ -103,7 +103,7 @@ public static boolean isClique(Collection set, Graph graph) { * @param graph a DAG, CPDAG, MAG, or PAG. * @return a {@link edu.cmu.tetrad.graph.Graph} object */ - public static Graph markovBlanketSubgraph(Node target, Graph graph) { + public static Graph markovBlanketSubgraph(Node target, Graph graph) { // TODO VBC: @Joe is this the more general method you recommended? Set mb = markovBlanket(target, graph); Graph mbGraph = new EdgeListGraph(); @@ -2253,7 +2253,7 @@ public static Graph trimGraph(List targets, Graph graph, int trimmingStyle graph = trimAdjacentToTarget(targets, graph); break; case 3: - graph = trimMarkovBlanketGraph(targets, graph); // TODO VBC this is what i want to use + graph = trimMarkovBlanketGraph(targets, graph); // TODO VBC currently using this break; case 4: graph = trimSemidirected(targets, graph); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index e67efc88f6..b31cf0e689 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -271,24 +271,36 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind public Double getPrecisionOrRecallOnMarkovBlanketGraph(Node x, Graph estimatedGraph, Graph trueGraph, boolean getPrecision) { List singleNode = Arrays.asList(x); - String nodeName = x.getName(); - Node originalNode = trueGraph.getNode(nodeName); - List singleOriginalNode = Arrays.asList(originalNode); + // Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes. + Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes()); + + // TODO VBC use the recurssion method + Graph xMBLookupGraph = GraphUtils.trimGraph(singleNode, lookupGraph, 3); + Set xMBLookupGraphEdges = xMBLookupGraph.getEdges(); + System.out.println("@@@@@@@@@@@@@@@@"); + System.out.println("True Graph:" + trueGraph); + System.out.println("LookupGraph:" + lookupGraph); // this print should be the same as the true graph + + System.out.println("xMBLookupGraphEdges size: " + xMBLookupGraphEdges.size()); + System.out.println("xMBLookupGraph Nodes size: " + xMBLookupGraph.getNodes().size()); + System.out.println("xMBLookupGraph:" + xMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print // Get Markov Blanket Subgraph for this node x. - Graph xMBGraphForEstimatedGraph = GraphUtils.trimGraph(singleNode, estimatedGraph, 3); - - Graph xMBGraphForTrueGraph = GraphUtils.trimGraph(singleOriginalNode, trueGraph, 3); // TODO VBC, this is always 0 because singleNode is not in trueGraph! - // TODO VBC: is using name to find x's corresponding node in trueGraph the right way? - Set xMBGraphForEstimatedGraphEdges = xMBGraphForEstimatedGraph.getEdges(); // TODO: this is often 435 for 30 nodes - Set xMBGraphForTrueGraphEdges = xMBGraphForTrueGraph.getEdges(); // TODO: Here the output is often 0/1 - System.out.println("xMBGraphForTrueGraphEdges size: " + xMBGraphForTrueGraphEdges.size()); // TODO VBC - System.out.println("xMBGraphTrue Nodes size: " + xMBGraphForTrueGraph.getNodes().size()); // TODO VBC this is always 0. - - HashSet truePositive = new HashSet<>(xMBGraphForEstimatedGraphEdges); - truePositive.retainAll(xMBGraphForTrueGraphEdges); - double precision = (double) truePositive.size() / xMBGraphForTrueGraphEdges.size(); - double recall = (double) truePositive.size() / xMBGraphForEstimatedGraphEdges.size(); + Graph xMBEstimatedGraph = GraphUtils.trimGraph(singleNode, estimatedGraph, 3); + Set xMBEstimatedGraphEdges = xMBEstimatedGraph.getEdges(); + System.out.println("xMBEstimatedGraphEdges size: " + xMBEstimatedGraphEdges.size()); + System.out.println("xMBEstimatedGraph Nodes size: " + xMBEstimatedGraph.getNodes().size()); + System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph); // This should be compared with the xMBLookupGraph + System.out.println("@@@@@@@@@@@@@@@@"); + + HashSet truePositive = new HashSet<>(xMBEstimatedGraphEdges); + // TODO VBC: QUESTION FOr DISCUSSION + // Here it would only be retained if the points and direction of an edge is exactly the same. + // Do we want to only check for points? If not, a lot wrong/no direction edges would be filtered out at this step. + truePositive.retainAll(xMBLookupGraphEdges); + + double precision = (double) truePositive.size() / xMBLookupGraphEdges.size(); + double recall = (double) truePositive.size() / xMBEstimatedGraphEdges.size(); return getPrecision ? precision : recall; } diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java index 4f92cd098e..c6399b0d46 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java @@ -14,6 +14,7 @@ import edu.cmu.tetrad.sem.SemIm; import edu.cmu.tetrad.sem.SemPm; import edu.cmu.tetrad.util.NumberFormatUtil; +import edu.cmu.tetrad.util.Parameters; import org.junit.Test; import java.util.ArrayList; @@ -114,38 +115,39 @@ public void test2() { @Test public void testPrecissionRecallForLocal() { - Graph dag = RandomGraph.randomDag(30, 0, 10, 100, 100, - 100, false); // truegraph - SemPm pm = new SemPm(dag); - SemIm im = new SemIm(pm); - DataSet data = im.simulateData(500, false); - SemBicScore score = new SemBicScore(data, true); - PermutationSearch search = new PermutationSearch(new Boss(score)); - Graph cpdag = search.search(); // estimatedgraph - IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); - MarkovCheck markovCheck = new MarkovCheck(cpdag, fisherZTest, ConditioningSetType.LOCAL_MARKOV); + Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false); + System.out.println("Test True Graph: " + trueGraph); + System.out.println("Test True Graph size: " + trueGraph.getNodes().size()); + + SemPm pm = new SemPm(trueGraph); + SemIm im = new SemIm(pm, new Parameters()); + DataSet data = im.simulateData(1000, false); + edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false); + score.setPenaltyDiscount(2); + Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search(); // Estimated graph + System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag); + System.out.println("Test Estimated Graph size: " + estimatedCpdag.getNodes().size()); + System.out.println("====================================="); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, cpdag, 0.05); + IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); + MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.LOCAL_MARKOV); // TODO Also try MB for settype + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); System.out.println("Rejects size: " + rejects.size()); - System.out.println("Estimated Graph size: " + cpdag.getNodes().size()); - System.out.println("True Graph size: " + dag.getNodes().size()); - - - List acceptsPrecision = new ArrayList<>(); List acceptsRecall = new ArrayList<>(); for(Node a: accepts) { - double precision = markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(a, cpdag, dag, true); - double recall = markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(a, cpdag, dag, false); + double precision = markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph, true); + double recall = markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph, false); acceptsPrecision.add(precision); acceptsRecall.add(recall); } - System.out.println("Accepts Precissions: " + acceptsPrecision); + System.out.println("Accepts Precisions: " + acceptsPrecision); System.out.println("Accepts Recall: " + acceptsRecall); + System.out.println("****************************************************"); } From d3db254fd0e40c23c0c6f646676c1f4a21438187 Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Mon, 8 Apr 2024 21:24:22 -0400 Subject: [PATCH 4/6] nit --- tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphNode.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphNode.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphNode.java index b624c4a66b..d621dd50b1 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphNode.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphNode.java @@ -47,7 +47,7 @@ public class GraphNode implements Node { /** * The name of the node. */ - private String name = "??"; // TODO VBC: can we change this into a seriel number + private String name = "??"; /** * The type of the node. * From e7ea86f70c3fd62b03ae23672ef6383469cab37b Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Mon, 8 Apr 2024 21:28:36 -0400 Subject: [PATCH 5/6] nit fix imports --- .../src/main/java/edu/cmu/tetrad/search/MarkovCheck.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index b31cf0e689..56c858d051 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -2,7 +2,11 @@ import edu.cmu.tetrad.data.GeneralAndersonDarlingTest; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.graph.IndependenceFact; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.Edge; import edu.cmu.tetrad.search.test.*; import edu.cmu.tetrad.util.SublistGenerator; import edu.cmu.tetrad.util.TetradLogger; From 49c96b0d9fd7e7040d55c26fa18e557e544b1e9c Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Wed, 17 Apr 2024 15:29:29 -0400 Subject: [PATCH 6/6] MB part done --- .../statistic/OrientationPrecision.java | 2 +- .../statistic/OrientationRecall.java | 2 +- .../statistic/utils/OrientationConfusion.java | 2 +- .../main/java/edu/cmu/tetrad/graph/Edge.java | 4 +- .../edu/cmu/tetrad/graph/EdgeListGraph.java | 33 ++++--- .../java/edu/cmu/tetrad/graph/GraphUtils.java | 2 + .../edu/cmu/tetrad/search/MarkovCheck.java | 89 +++++++++++++++---- .../edu/cmu/tetrad/test/TestCheckMarkov.java | 3 +- 8 files changed, 103 insertions(+), 34 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationPrecision.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationPrecision.java index 41d8ec6262..8d5e67cdd8 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationPrecision.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationPrecision.java @@ -12,7 +12,7 @@ * @author bryanandrews, osephramsey * @version $Id: $Id */ -public class OrientationPrecision implements Statistic { +public class OrientationPrecision implements Statistic { // TODO VBC: is this one we want to use? @Serial private static final long serialVersionUID = 23L; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationRecall.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationRecall.java index 3c9797736c..f18233479f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationRecall.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationRecall.java @@ -13,7 +13,7 @@ * compared to the true graph. It calculates the ratio of true positive orientations to the sum of true positive and * false negative orientations. */ -public class OrientationRecall implements Statistic { +public class OrientationRecall implements Statistic { // TODO VBC: use this? @Serial private static final long serialVersionUID = 23L; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/OrientationConfusion.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/OrientationConfusion.java index eadb2438be..3c15738b15 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/OrientationConfusion.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/OrientationConfusion.java @@ -39,7 +39,7 @@ public class OrientationConfusion { * @param truth a {@link edu.cmu.tetrad.graph.Graph} object * @param est a {@link edu.cmu.tetrad.graph.Graph} object */ - public OrientationConfusion(Graph truth, Graph est) { + public OrientationConfusion(Graph truth, Graph est) { // TODO VBC: is this one we want to use? this.tp = 0; this.fp = 0; this.fn = 0; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java index 8a5abd0ca6..4f60819abf 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java @@ -342,7 +342,7 @@ public final String toString() { _type = new StringBuilder("no edge"); break; case ta: - _type = new StringBuilder("-->"); + _type = new StringBuilder("-->"); // TODO VBC: or shall we just compare edge by strings break; case at: _type = new StringBuilder("<--"); @@ -439,7 +439,7 @@ public final boolean equals(Object o) { * @param _edge a {@link edu.cmu.tetrad.graph.Edge} object * @return a int */ - public int compareTo(Edge _edge) { + public int compareTo(Edge _edge) { // TODO VBC: seems only comparing the edpoint not the direction? int comp1 = getNode1().compareTo(_edge.getNode1()); if (comp1 != 0) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java index 8a76b3f297..f5eada618d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java @@ -612,14 +612,15 @@ public List getAdjacentNodes(Node node) { Set edges = this.edgeLists.get(node); Set adj = new HashSet<>(); - for (Edge edge : edges) { - if (edge == null) { - continue; - } + if (edges != null) { + for (Edge edge : edges) { + if (edge == null) { + continue; + } - adj.add(edge.getDistalNode(node)); + adj.add(edge.getDistalNode(node)); + } } - return new ArrayList<>(adj); } @@ -741,13 +742,25 @@ public boolean addEdge(Edge edge) { // Someoone may have changed the name of one of these variables, in which // case we need to reconstitute the edgeLists map, since the name of a // node is used part of the definition of node equality. - if (!edgeLists.containsKey(edge.getNode1()) || !edgeLists.containsKey(edge.getNode2())) { + Node node1 = edge.getNode1(); + Node node2 = edge.getNode2(); + // System.out.println("Real Before: " + edgeLists); + if (!edgeLists.containsKey(node1) || !edgeLists.containsKey(node2)) { this.edgeLists = new HashMap<>(this.edgeLists); } - - this.edgeLists.get(edge.getNode1()).add(edge); - this.edgeLists.get(edge.getNode2()).add(edge); + // System.out.println("Before Adding: " + edgeLists); + if (this.edgeLists.get(node1) == null ) { + // System.out.println("Missing node1 is not in edgeLists: " + node1); + this.edgeLists.put(node1, new HashSet<>()); + } + if (this.edgeLists.get(node2) == null ) { + // System.out.println("Missing node2 is not in edgeLists: " + node2); + this.edgeLists.put(node2, new HashSet<>()); + } + this.edgeLists.get(node1).add(edge); + this.edgeLists.get(node2).add(edge); this.edgesSet.add(edge); + // System.out.println("After: " + edgeLists); this.parentsHash.remove(edge.getNode1()); this.parentsHash.remove(edge.getNode2()); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java index 56de316a58..46f0d96c9c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java @@ -117,6 +117,8 @@ public static Graph markovBlanketSubgraph(Node target, Graph graph) { // TODO VB for (int i = 0; i < mbList.size(); i++) { for (int j = i + 1; j < mbList.size(); j++) { + List edges = graph.getEdges(mbList.get(i), mbList.get(j)); + // System.out.println("Add edges between!!!! " + mbList.get(i) + " " + mbList.get(j)); for (Edge e : graph.getEdges(mbList.get(i), mbList.get(j))) { mbGraph.addEdge(e); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index 56c858d051..c99807efef 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -2,11 +2,7 @@ import edu.cmu.tetrad.data.GeneralAndersonDarlingTest; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphUtils; -import edu.cmu.tetrad.graph.IndependenceFact; -import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.test.*; import edu.cmu.tetrad.util.SublistGenerator; import edu.cmu.tetrad.util.TetradLogger; @@ -273,41 +269,98 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind return accepts_rejects; } + private Graph getMarkovBlanketSubgraph(Graph graph, Node targetNode) { + EdgeListGraph g = new EdgeListGraph(graph); + Set mbNodes = GraphUtils.markovBlanket(targetNode, g); + mbNodes.add(targetNode); + return g.subgraph(new ArrayList<>(mbNodes)); + } + + + public Double getPrecisionOrRecallOnMarkovBlanketGraph(Node x, Graph estimatedGraph, Graph trueGraph, boolean getPrecision) { List singleNode = Arrays.asList(x); // Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes. Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes()); - - // TODO VBC use the recurssion method - Graph xMBLookupGraph = GraphUtils.trimGraph(singleNode, lookupGraph, 3); - Set xMBLookupGraphEdges = xMBLookupGraph.getEdges(); System.out.println("@@@@@@@@@@@@@@@@"); + System.out.println("Node: " + x); System.out.println("True Graph:" + trueGraph); - System.out.println("LookupGraph:" + lookupGraph); // this print should be the same as the true graph + System.out.println("LookupGraph:" + lookupGraph); // print should look the same as the true graph + + // TODO VBC: The Trim method is the most accurate in terms of all nodes and the edges + Graph RecommendedxMBLookupGraph = getMarkovBlanketSubgraph(lookupGraph, x); // Recommended, not working + Graph xMBLookupGraph = GraphUtils.markovBlanketSubgraph(x, lookupGraph); // TODO VBC: this one should include the target node + Graph TrimxMBLookupGraph = GraphUtils.trimGraph(singleNode, lookupGraph, 3); // Best + Set xMBLookupGraphEdges = xMBLookupGraph.getEdges(); System.out.println("xMBLookupGraphEdges size: " + xMBLookupGraphEdges.size()); System.out.println("xMBLookupGraph Nodes size: " + xMBLookupGraph.getNodes().size()); System.out.println("xMBLookupGraph:" + xMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print + System.out.println("RecommendedxMBLookupGraph:" + RecommendedxMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print + System.out.println("TrimxMBLookupGraph:" + TrimxMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print // Get Markov Blanket Subgraph for this node x. - Graph xMBEstimatedGraph = GraphUtils.trimGraph(singleNode, estimatedGraph, 3); + // Graph xMBEstimatedGraph = getMarkovBlanketSubgraph(estimatedGraph, x); + Graph xMBEstimatedGraph = GraphUtils.markovBlanketSubgraph(x, estimatedGraph); + // Graph xMBEstimatedGraph = GraphUtils.trimGraph(singleNode, estimatedGraph, 3); Set xMBEstimatedGraphEdges = xMBEstimatedGraph.getEdges(); System.out.println("xMBEstimatedGraphEdges size: " + xMBEstimatedGraphEdges.size()); System.out.println("xMBEstimatedGraph Nodes size: " + xMBEstimatedGraph.getNodes().size()); System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph); // This should be compared with the xMBLookupGraph System.out.println("@@@@@@@@@@@@@@@@"); - HashSet truePositive = new HashSet<>(xMBEstimatedGraphEdges); - // TODO VBC: QUESTION FOr DISCUSSION - // Here it would only be retained if the points and direction of an edge is exactly the same. - // Do we want to only check for points? If not, a lot wrong/no direction edges would be filtered out at this step. - truePositive.retainAll(xMBLookupGraphEdges); + HashSet truePositive = new HashSet<>(); + HashSet falsePositive = new HashSet<>(); + HashSet falseNegative = new HashSet<>(); + Set trueGraphEdgesEdges = trueGraph.getEdges(); + Set estimatedGraphEdgesEdges = estimatedGraph.getEdges(); + if (trueGraphEdgesEdges != null && estimatedGraphEdgesEdges != null) { + for (Edge te: trueGraphEdgesEdges) { + for (Edge ee: estimatedGraphEdgesEdges) { + // True Graph's Edge info + Node teNode1 = te.getNode1(); + Node teNode2 = te.getNode1(); + Endpoint teEndpoint1 = te.getEndpoint1(); + Endpoint teEndpoint2 = te.getEndpoint2(); + // Estimated Graph's Edge info + Node eeNode1 = te.getNode1(); + Node eeNode2 = te.getNode1(); + Endpoint eeEndpoint1 = ee.getEndpoint1(); + Endpoint eeEndpoint2 = ee.getEndpoint2(); + boolean isSameNode1 = areSame(teNode1, eeNode1); + boolean isSameNode2 = areSame(teNode2, eeNode2); + + // EdgeTypeProbability.EdgeType teType = te.getEdgeTypeProbabilities().getFirst().getEdgeType(); + + // If both n1 n2 are the same, compare the endpoint1 endpoint2 + if (isSameNode1 && isSameNode2) { + // if (teEndpoint1.compareTo(eeEndpoint1)) + // QUESTION: // TODO VBC: seems Edge#compareTo() only comparing the node itself not the endpoint? + // QUESTION: do we only care about edge type here? + + } - double precision = (double) truePositive.size() / xMBLookupGraphEdges.size(); - double recall = (double) truePositive.size() / xMBEstimatedGraphEdges.size(); + + + } + } + } + // TODO VBC: + // Logic of comparing true graph with estimated graph + + double precision = (double) truePositive.size() / (truePositive.size() + falsePositive.size()); + double recall = (double) truePositive.size() / (truePositive.size() + falseNegative.size()); return getPrecision ? precision : recall; } + private boolean areSame(Node n1, Node n2) { + // TODO VBC: Compare the Nodes are of the same. + // QUESTION: the compareTo() method in Node class is very complicated, involves Lag etc. is that what we want to use? + // or shall we just compare by names of these nodes + + return n1.getName().equals(n2.getName()); + } + /** * Returns the variables of the independence test. diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java index c6399b0d46..a24b305707 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java @@ -115,6 +115,7 @@ public void test2() { @Test public void testPrecissionRecallForLocal() { + // TODO also use randome graph then convert to cpday learn from Test Graph Utils. write a diff test case for this. Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false); System.out.println("Test True Graph: " + trueGraph); System.out.println("Test True Graph size: " + trueGraph.getNodes().size()); @@ -130,7 +131,7 @@ public void testPrecissionRecallForLocal() { System.out.println("====================================="); IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); - MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.LOCAL_MARKOV); // TODO Also try MB for settype + MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1);