diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationPrecision.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationPrecision.java index 41d8ec6262..8d5e67cdd8 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationPrecision.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationPrecision.java @@ -12,7 +12,7 @@ * @author bryanandrews, osephramsey * @version $Id: $Id */ -public class OrientationPrecision implements Statistic { +public class OrientationPrecision implements Statistic { // TODO VBC: is this one we want to use? @Serial private static final long serialVersionUID = 23L; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationRecall.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationRecall.java index 3c9797736c..f18233479f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationRecall.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/OrientationRecall.java @@ -13,7 +13,7 @@ * compared to the true graph. It calculates the ratio of true positive orientations to the sum of true positive and * false negative orientations. */ -public class OrientationRecall implements Statistic { +public class OrientationRecall implements Statistic { // TODO VBC: use this? @Serial private static final long serialVersionUID = 23L; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/OrientationConfusion.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/OrientationConfusion.java index eadb2438be..3c15738b15 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/OrientationConfusion.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/OrientationConfusion.java @@ -39,7 +39,7 @@ public class OrientationConfusion { * @param truth a {@link edu.cmu.tetrad.graph.Graph} object * @param est a {@link edu.cmu.tetrad.graph.Graph} object */ - public OrientationConfusion(Graph truth, Graph est) { + public OrientationConfusion(Graph truth, Graph est) { // TODO VBC: is this one we want to use? this.tp = 0; this.fp = 0; this.fn = 0; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java index 8a5abd0ca6..4f60819abf 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java @@ -342,7 +342,7 @@ public final String toString() { _type = new StringBuilder("no edge"); break; case ta: - _type = new StringBuilder("-->"); + _type = new StringBuilder("-->"); // TODO VBC: or shall we just compare edge by strings break; case at: _type = new StringBuilder("<--"); @@ -439,7 +439,7 @@ public final boolean equals(Object o) { * @param _edge a {@link edu.cmu.tetrad.graph.Edge} object * @return a int */ - public int compareTo(Edge _edge) { + public int compareTo(Edge _edge) { // TODO VBC: seems only comparing the edpoint not the direction? int comp1 = getNode1().compareTo(_edge.getNode1()); if (comp1 != 0) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java index 8a76b3f297..f5eada618d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java @@ -612,14 +612,15 @@ public List getAdjacentNodes(Node node) { Set edges = this.edgeLists.get(node); Set adj = new HashSet<>(); - for (Edge edge : edges) { - if (edge == null) { - continue; - } + if (edges != null) { + for (Edge edge : edges) { + if (edge == null) { + continue; + } - adj.add(edge.getDistalNode(node)); + adj.add(edge.getDistalNode(node)); + } } - return new ArrayList<>(adj); } @@ -741,13 +742,25 @@ public boolean addEdge(Edge edge) { // Someoone may have changed the name of one of these variables, in which // case we need to reconstitute the edgeLists map, since the name of a // node is used part of the definition of node equality. - if (!edgeLists.containsKey(edge.getNode1()) || !edgeLists.containsKey(edge.getNode2())) { + Node node1 = edge.getNode1(); + Node node2 = edge.getNode2(); + // System.out.println("Real Before: " + edgeLists); + if (!edgeLists.containsKey(node1) || !edgeLists.containsKey(node2)) { this.edgeLists = new HashMap<>(this.edgeLists); } - - this.edgeLists.get(edge.getNode1()).add(edge); - this.edgeLists.get(edge.getNode2()).add(edge); + // System.out.println("Before Adding: " + edgeLists); + if (this.edgeLists.get(node1) == null ) { + // System.out.println("Missing node1 is not in edgeLists: " + node1); + this.edgeLists.put(node1, new HashSet<>()); + } + if (this.edgeLists.get(node2) == null ) { + // System.out.println("Missing node2 is not in edgeLists: " + node2); + this.edgeLists.put(node2, new HashSet<>()); + } + this.edgeLists.get(node1).add(edge); + this.edgeLists.get(node2).add(edge); this.edgesSet.add(edge); + // System.out.println("After: " + edgeLists); this.parentsHash.remove(edge.getNode1()); this.parentsHash.remove(edge.getNode2()); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java index 4074fc28d2..46f0d96c9c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java @@ -103,7 +103,7 @@ public static boolean isClique(Collection set, Graph graph) { * @param graph a DAG, CPDAG, MAG, or PAG. * @return a {@link edu.cmu.tetrad.graph.Graph} object */ - public static Graph markovBlanketSubgraph(Node target, Graph graph) { + public static Graph markovBlanketSubgraph(Node target, Graph graph) { // TODO VBC: @Joe is this the more general method you recommended? Set mb = markovBlanket(target, graph); Graph mbGraph = new EdgeListGraph(); @@ -117,6 +117,8 @@ public static Graph markovBlanketSubgraph(Node target, Graph graph) { for (int i = 0; i < mbList.size(); i++) { for (int j = i + 1; j < mbList.size(); j++) { + List edges = graph.getEdges(mbList.get(i), mbList.get(j)); + // System.out.println("Add edges between!!!! " + mbList.get(i) + " " + mbList.get(j)); for (Edge e : graph.getEdges(mbList.get(i), mbList.get(j))) { mbGraph.addEdge(e); } @@ -2253,7 +2255,7 @@ public static Graph trimGraph(List targets, Graph graph, int trimmingStyle graph = trimAdjacentToTarget(targets, graph); break; case 3: - graph = trimMarkovBlanketGraph(targets, graph); + graph = trimMarkovBlanketGraph(targets, graph); // TODO VBC currently using this break; case 4: graph = trimSemidirected(targets, graph); @@ -2298,7 +2300,7 @@ private static Graph trimAdjacentToTarget(List targets, Graph graph) { * @param graph the original graph from which the Markov blanket graph is derived * @return the trimmed Markov blanket graph */ - private static Graph trimMarkovBlanketGraph(List targets, Graph graph) { + private static Graph trimMarkovBlanketGraph(List targets, Graph graph) { // TODO vbc this is Graph mbDag = new EdgeListGraph(graph); M: diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index 11893865e8..c99807efef 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -2,10 +2,7 @@ import edu.cmu.tetrad.data.GeneralAndersonDarlingTest; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphUtils; -import edu.cmu.tetrad.graph.IndependenceFact; -import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.test.*; import edu.cmu.tetrad.util.SublistGenerator; import edu.cmu.tetrad.util.TetradLogger; @@ -251,6 +248,119 @@ public Double checkAgainstAndersonDarlingTest(List pValues) { return generalAndersonDarlingTest.getP(); } + public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(IndependenceTest independenceTest, Graph graph, Double threshold) { + // when calling, default reject null as <=0.05 + List> accepts_rejects = new ArrayList<>(); + List accepts = new ArrayList<>(); + List rejects = new ArrayList<>(); + List allNodes = graph.getNodes(); + for (Node x : allNodes) { + List localIndependenceFacts = getLocalIndependenceFacts(x); + List localPValues = getLocalPValues(independenceTest, localIndependenceFacts); + Double ADTest = checkAgainstAndersonDarlingTest(localPValues); + if (ADTest <= threshold) { + rejects.add(x); + } else { + accepts.add(x); + } + } + accepts_rejects.add(accepts); + accepts_rejects.add(rejects); + return accepts_rejects; + } + + private Graph getMarkovBlanketSubgraph(Graph graph, Node targetNode) { + EdgeListGraph g = new EdgeListGraph(graph); + Set mbNodes = GraphUtils.markovBlanket(targetNode, g); + mbNodes.add(targetNode); + return g.subgraph(new ArrayList<>(mbNodes)); + } + + + + public Double getPrecisionOrRecallOnMarkovBlanketGraph(Node x, Graph estimatedGraph, Graph trueGraph, boolean getPrecision) { + List singleNode = Arrays.asList(x); + // Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes. + Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes()); + System.out.println("@@@@@@@@@@@@@@@@"); + System.out.println("Node: " + x); + System.out.println("True Graph:" + trueGraph); + System.out.println("LookupGraph:" + lookupGraph); // print should look the same as the true graph + + // TODO VBC: The Trim method is the most accurate in terms of all nodes and the edges + Graph RecommendedxMBLookupGraph = getMarkovBlanketSubgraph(lookupGraph, x); // Recommended, not working + Graph xMBLookupGraph = GraphUtils.markovBlanketSubgraph(x, lookupGraph); // TODO VBC: this one should include the target node + Graph TrimxMBLookupGraph = GraphUtils.trimGraph(singleNode, lookupGraph, 3); // Best + Set xMBLookupGraphEdges = xMBLookupGraph.getEdges(); + + System.out.println("xMBLookupGraphEdges size: " + xMBLookupGraphEdges.size()); + System.out.println("xMBLookupGraph Nodes size: " + xMBLookupGraph.getNodes().size()); + System.out.println("xMBLookupGraph:" + xMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print + System.out.println("RecommendedxMBLookupGraph:" + RecommendedxMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print + System.out.println("TrimxMBLookupGraph:" + TrimxMBLookupGraph); // The MB trim of the lookup graph, so it should be a subset of the lookup graph print + + // Get Markov Blanket Subgraph for this node x. + // Graph xMBEstimatedGraph = getMarkovBlanketSubgraph(estimatedGraph, x); + Graph xMBEstimatedGraph = GraphUtils.markovBlanketSubgraph(x, estimatedGraph); + // Graph xMBEstimatedGraph = GraphUtils.trimGraph(singleNode, estimatedGraph, 3); + Set xMBEstimatedGraphEdges = xMBEstimatedGraph.getEdges(); + System.out.println("xMBEstimatedGraphEdges size: " + xMBEstimatedGraphEdges.size()); + System.out.println("xMBEstimatedGraph Nodes size: " + xMBEstimatedGraph.getNodes().size()); + System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph); // This should be compared with the xMBLookupGraph + System.out.println("@@@@@@@@@@@@@@@@"); + + HashSet truePositive = new HashSet<>(); + HashSet falsePositive = new HashSet<>(); + HashSet falseNegative = new HashSet<>(); + Set trueGraphEdgesEdges = trueGraph.getEdges(); + Set estimatedGraphEdgesEdges = estimatedGraph.getEdges(); + if (trueGraphEdgesEdges != null && estimatedGraphEdgesEdges != null) { + for (Edge te: trueGraphEdgesEdges) { + for (Edge ee: estimatedGraphEdgesEdges) { + // True Graph's Edge info + Node teNode1 = te.getNode1(); + Node teNode2 = te.getNode1(); + Endpoint teEndpoint1 = te.getEndpoint1(); + Endpoint teEndpoint2 = te.getEndpoint2(); + // Estimated Graph's Edge info + Node eeNode1 = te.getNode1(); + Node eeNode2 = te.getNode1(); + Endpoint eeEndpoint1 = ee.getEndpoint1(); + Endpoint eeEndpoint2 = ee.getEndpoint2(); + boolean isSameNode1 = areSame(teNode1, eeNode1); + boolean isSameNode2 = areSame(teNode2, eeNode2); + + // EdgeTypeProbability.EdgeType teType = te.getEdgeTypeProbabilities().getFirst().getEdgeType(); + + // If both n1 n2 are the same, compare the endpoint1 endpoint2 + if (isSameNode1 && isSameNode2) { + // if (teEndpoint1.compareTo(eeEndpoint1)) + // QUESTION: // TODO VBC: seems Edge#compareTo() only comparing the node itself not the endpoint? + // QUESTION: do we only care about edge type here? + + } + + + + } + } + } + // TODO VBC: + // Logic of comparing true graph with estimated graph + + double precision = (double) truePositive.size() / (truePositive.size() + falsePositive.size()); + double recall = (double) truePositive.size() / (truePositive.size() + falseNegative.size()); + return getPrecision ? precision : recall; + } + + private boolean areSame(Node n1, Node n2) { + // TODO VBC: Compare the Nodes are of the same. + // QUESTION: the compareTo() method in Node class is very complicated, involves Lag etc. is that what we want to use? + // or shall we just compare by names of these nodes + + return n1.getName().equals(n2.getName()); + } + /** * Returns the variables of the independence test. diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java index 0d2dbb54f1..a24b305707 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java @@ -14,8 +14,10 @@ import edu.cmu.tetrad.sem.SemIm; import edu.cmu.tetrad.sem.SemPm; import edu.cmu.tetrad.util.NumberFormatUtil; +import edu.cmu.tetrad.util.Parameters; import org.junit.Test; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -110,4 +112,44 @@ public void test2() { System.out.println(markovCheck.getMarkovCheckRecordString()); } + + @Test + public void testPrecissionRecallForLocal() { + // TODO also use randome graph then convert to cpday learn from Test Graph Utils. write a diff test case for this. + Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false); + System.out.println("Test True Graph: " + trueGraph); + System.out.println("Test True Graph size: " + trueGraph.getNodes().size()); + + SemPm pm = new SemPm(trueGraph); + SemIm im = new SemIm(pm, new Parameters()); + DataSet data = im.simulateData(1000, false); + edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false); + score.setPenaltyDiscount(2); + Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search(); // Estimated graph + System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag); + System.out.println("Test Estimated Graph size: " + estimatedCpdag.getNodes().size()); + System.out.println("====================================="); + + IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); + MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + List accepts = accepts_rejects.get(0); + List rejects = accepts_rejects.get(1); + System.out.println("Accepts size: " + accepts.size()); + System.out.println("Rejects size: " + rejects.size()); + + List acceptsPrecision = new ArrayList<>(); + List acceptsRecall = new ArrayList<>(); + for(Node a: accepts) { + double precision = markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph, true); + double recall = markovCheck.getPrecisionOrRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph, false); + acceptsPrecision.add(precision); + acceptsRecall.add(recall); + } + System.out.println("Accepts Precisions: " + acceptsPrecision); + System.out.println("Accepts Recall: " + acceptsRecall); + System.out.println("****************************************************"); + + + } }