getBootstrapGraphs();
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/Dagma.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/Dagma.java
index a47570bffd..b0530b1f90 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/Dagma.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/Dagma.java
@@ -13,6 +13,7 @@
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.search.score.Score;
+import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import edu.cmu.tetrad.util.TetradLogger;
@@ -50,7 +51,7 @@ public Graph search(DataModel dataSet, Parameters parameters) {
search.setCpdag(parameters.getBoolean(Params.CPDAG));
Graph graph = search.search();
TetradLogger.getInstance().forceLogMessage(graph.toString());
-
+ LogUtilsSearch.stampWithBic(graph, dataSet);
return graph;
} else {
Dagma algorithm = new Dagma();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java
index d30c33400d..339201ff8d 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/DirectLingam.java
@@ -56,7 +56,6 @@ public Graph search(DataModel dataSet, Parameters parameters) {
edu.cmu.tetrad.search.DirectLingam search = new edu.cmu.tetrad.search.DirectLingam(data, score);
Graph graph = search.search();
TetradLogger.getInstance().forceLogMessage(graph.toString());
-
LogUtilsSearch.stampWithBic(graph, dataSet);
return graph;
} else {
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesBoss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesBoss.java
new file mode 100644
index 0000000000..9aaa6bdc8c
--- /dev/null
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesBoss.java
@@ -0,0 +1,218 @@
+package edu.cmu.tetrad.algcomparison.algorithm.multi;
+
+import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm;
+import edu.cmu.tetrad.algcomparison.algorithm.oracle.cpdag.Fges;
+import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper;
+import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
+import edu.cmu.tetrad.algcomparison.score.SemBicScore;
+import edu.cmu.tetrad.algcomparison.utils.HasKnowledge;
+import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper;
+import edu.cmu.tetrad.annotation.AlgType;
+import edu.cmu.tetrad.annotation.Bootstrapping;
+import edu.cmu.tetrad.data.*;
+import edu.cmu.tetrad.graph.EdgeListGraph;
+import edu.cmu.tetrad.graph.Graph;
+import edu.cmu.tetrad.search.Boss;
+import edu.cmu.tetrad.search.PermutationSearch;
+import edu.cmu.tetrad.search.score.ImagesScore;
+import edu.cmu.tetrad.search.score.Score;
+import edu.cmu.tetrad.search.utils.TsUtils;
+import edu.cmu.tetrad.util.Parameters;
+import edu.cmu.tetrad.util.Params;
+import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+
+/**
+ * Wraps the IMaGES algorithm for continuous variables. This version uses the BOSS algorithm in place of FGES.
+ *
+ * Requires that the parameter 'randomSelectionSize' be set to indicate how many datasets should be taken at a time
+ * (randomly). This parameter cannot be given multiple values.
+ *
+ * @author josephramsey
+ */
+@edu.cmu.tetrad.annotation.Algorithm(
+ name = "IMaGES-BOSS",
+ command = "images-boss",
+ algoType = AlgType.forbid_latent_common_causes,
+ dataType = DataType.All
+)
+@Bootstrapping
+public class ImagesBoss implements MultiDataSetAlgorithm, HasKnowledge, UsesScoreWrapper {
+
+ private static final long serialVersionUID = 23L;
+ private Knowledge knowledge = new Knowledge();
+
+ private ScoreWrapper score = new SemBicScore();
+
+ public ImagesBoss(ScoreWrapper score) {
+ this.score = score;
+ }
+
+ public ImagesBoss() {
+ }
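+
+ // A minimal usage sketch (hypothetical variable names), assuming the standard
+ // algcomparison workflow:
+ //
+ //   ImagesBoss images = new ImagesBoss(new SemBicScore());
+ //   Parameters params = new Parameters();
+ //   params.set(Params.RANDOM_SELECTION_SIZE, 5);
+ //   params.set(Params.IMAGES_META_ALG, 1);
+ //   Graph cpdag = images.search(dataModels, params); // dataModels: List<DataModel>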
+
+ @Override
+ public Graph search(List<DataModel> dataSets, Parameters parameters) {
+ int meta = parameters.getInt(Params.IMAGES_META_ALG);
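+
+ // IMAGES_META_ALG selects the meta-search over the pooled score: in the code
+ // below both options run BOSS inside a PermutationSearch, but option 1 also
+ // applies the user-supplied Params.SEED while option 2 does not.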
+
+ if (parameters.getInt(Params.NUMBER_RESAMPLING) < 1) {
+ List<DataModel> _dataSets = new ArrayList<>();
+
+ if (parameters.getInt(Params.TIME_LAG) > 0) {
+ for (DataModel dataSet : dataSets) {
+ DataSet timeSeries = TsUtils.createLagData((DataSet) dataSet, parameters.getInt(Params.TIME_LAG));
+ if (dataSet.getName() != null) {
+ timeSeries.setName(dataSet.getName());
+ }
+ _dataSets.add(timeSeries);
+ }
+
+ dataSets = _dataSets;
+ }
+
+ List<Score> scores = new ArrayList<>();
+
+ for (DataModel dataModel : dataSets) {
+ Score s = score.getScore(dataModel, parameters);
+ scores.add(s);
+ }
+
+ ImagesScore score = new ImagesScore(scores);
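+
+ // ImagesScore pools the per-dataset scores so that a single permutation search
+ // optimizes all of the datasets jointly. (Note: this local variable shadows the
+ // ScoreWrapper field of the same name.)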
+
+ if (meta == 1) {
+ PermutationSearch search = new PermutationSearch(new Boss(score));
+ search.setSeed(parameters.getLong(Params.SEED));
+// edu.cmu.tetrad.search.Fges search = new edu.cmu.tetrad.search.Fges(score);
+ search.setKnowledge(this.knowledge);
+// search.setVerbose(parameters.getBoolean(Params.VERBOSE));
+ return search.search();
+ } else if (meta == 2) {
+ PermutationSearch search = new PermutationSearch(new Boss(score));
+ search.setKnowledge(this.knowledge);
+ return search.search();
+ } else {
+ throw new IllegalArgumentException("Unrecognized meta option: " + meta);
+ }
+ } else {
+ ImagesBoss imagesSemBic = new ImagesBoss();
+
+ List<DataSet> dataSets2 = new ArrayList<>();
+
+ for (DataModel dataModel : dataSets) {
+ dataSets2.add((DataSet) dataModel);
+ }
+
+ List<DataSet> _dataSets = new ArrayList<>();
+
+ if (parameters.getInt(Params.TIME_LAG) > 0) {
+ for (DataSet dataSet : dataSets2) {
+ DataSet timeSeries = TsUtils.createLagData(dataSet, parameters.getInt(Params.TIME_LAG));
+ if (dataSet.getName() != null) {
+ timeSeries.setName(dataSet.getName());
+ }
+ _dataSets.add(timeSeries);
+ }
+
+ dataSets2 = _dataSets;
+ }
+
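+ // Bootstrap path: GeneralResamplingTest resamples the datasets, reruns this
+ // wrapper on each resample, and combines the resulting graphs according to the
+ // selected ensemble rule.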
+ GeneralResamplingTest search = new GeneralResamplingTest(
+ dataSets2,
+ imagesSemBic,
+ parameters.getInt(Params.NUMBER_RESAMPLING),
+ parameters.getDouble(Params.PERCENT_RESAMPLE_SIZE),
+ parameters.getBoolean(Params.RESAMPLING_WITH_REPLACEMENT), parameters.getInt(Params.RESAMPLING_ENSEMBLE), parameters.getBoolean(Params.ADD_ORIGINAL_DATASET));
+ search.setParameters(parameters);
+ search.setVerbose(parameters.getBoolean(Params.VERBOSE));
+ search.setKnowledge(this.knowledge);
+ search.setScoreWrapper(score);
+ return search.search();
+ }
+ }
+
+ @Override
+ public Graph search(DataModel dataSet, Parameters parameters) {
+ if (parameters.getInt(Params.NUMBER_RESAMPLING) < 1) {
+ return search(Collections.singletonList(SimpleDataLoader.getMixedDataSet(dataSet)), parameters);
+ } else {
+ ImagesBoss images = new ImagesBoss();
+
+ List<DataSet> dataSets = Collections.singletonList(SimpleDataLoader.getMixedDataSet(dataSet));
+ GeneralResamplingTest search = new GeneralResamplingTest(dataSets,
+ images,
+ parameters.getInt(Params.NUMBER_RESAMPLING),
+ parameters.getDouble(Params.PERCENT_RESAMPLE_SIZE),
+ parameters.getBoolean(Params.RESAMPLING_WITH_REPLACEMENT),
+ parameters.getInt(Params.RESAMPLING_ENSEMBLE),
+ parameters.getBoolean(Params.ADD_ORIGINAL_DATASET));
+
+ if (score == null) {
+ System.out.println();
+ }
+
+ search.setParameters(parameters);
+ search.setVerbose(parameters.getBoolean(Params.VERBOSE));
+ search.setScoreWrapper(score);
+ return search.search();
+ }
+ }
+
+ @Override
+ public Graph getComparisonGraph(Graph graph) {
+ return new EdgeListGraph(graph);
+ }
+
+ @Override
+ public String getDescription() {
+ return "IMaGES";
+ }
+
+ @Override
+ public DataType getDataType() {
+ return DataType.All;
+ }
+
+ @Override
+ public List<String> getParameters() {
+ List<String> parameters = new LinkedList<>();
+ parameters.addAll(new SemBicScore().getParameters());
+
+ parameters.addAll((new Fges()).getParameters());
+ parameters.add(Params.RANDOM_SELECTION_SIZE);
+ parameters.add(Params.TIME_LAG);
+ parameters.add(Params.IMAGES_META_ALG);
+ parameters.add(Params.SEED);
+ parameters.add(Params.VERBOSE);
+
+ return parameters;
+ }
+
+ @Override
+ public Knowledge getKnowledge() {
+ return this.knowledge;
+ }
+
+ @Override
+ public void setKnowledge(Knowledge knowledge) {
+ this.knowledge = new Knowledge(knowledge);
+ }
+
+ @Override
+ public ScoreWrapper getScoreWrapper() {
+ return this.score;
+ }
+
+ @Override
+ public void setScoreWrapper(ScoreWrapper score) {
+ this.score = score;
+ }
+
+ @Override
+ public void setIndTestWrapper(IndependenceWrapper test) {
+ // Not used.
+ }
+}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java
index a7c92da553..7d9a4a3868 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Boss.java
@@ -42,7 +42,7 @@ public class Boss implements Algorithm, UsesScoreWrapper, HasKnowledge,
private ScoreWrapper score;
private Knowledge knowledge = new Knowledge();
private List<Graph> bootstrapGraphs = new ArrayList<>();
-
+ private long seed = 1;
public Boss() {
// Used in reflection; do not delete.
@@ -54,6 +54,8 @@ public Boss(ScoreWrapper score) {
@Override
public Graph search(DataModel dataModel, Parameters parameters) {
+ this.seed = parameters.getLong(Params.SEED);
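+
+ // The seed read here is forwarded to PermutationSearch below so that repeated
+ // runs with the same Params.SEED give reproducible results.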
+
if (parameters.getInt(Params.NUMBER_RESAMPLING) < 1) {
if (parameters.getInt(Params.TIME_LAG) > 0) {
DataSet dataSet = (DataSet) dataModel;
@@ -68,6 +70,7 @@ public Graph search(DataModel dataModel, Parameters parameters) {
Score score = this.score.getScore(dataModel, parameters);
edu.cmu.tetrad.search.Boss boss = new edu.cmu.tetrad.search.Boss(score);
+
boss.setUseBes(parameters.getBoolean(Params.USE_BES));
boss.setNumStarts(parameters.getInt(Params.NUM_STARTS));
boss.setNumThreads(parameters.getInt(Params.NUM_THREADS));
@@ -75,8 +78,10 @@ public Graph search(DataModel dataModel, Parameters parameters) {
boss.setVerbose(parameters.getBoolean(Params.VERBOSE));
PermutationSearch permutationSearch = new PermutationSearch(boss);
permutationSearch.setKnowledge(this.knowledge);
+ permutationSearch.setSeed(seed);
Graph graph = permutationSearch.search();
- LogUtilsSearch.stampWithScores(graph, dataModel, score);
+ LogUtilsSearch.stampWithScore(graph, score);
+ LogUtilsSearch.stampWithBic(graph, dataModel);
return graph;
} else {
Boss algorithm = new Boss(this.score);
@@ -117,6 +122,7 @@ public List getParameters() {
params.add(Params.TIME_LAG);
params.add(Params.NUM_THREADS);
params.add(Params.USE_DATA_ORDER);
+ params.add(Params.SEED);
params.add(Params.VERBOSE);
return params;
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BossLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BossLingam.java
index 4dc585c651..dac01c2201 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BossLingam.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BossLingam.java
@@ -67,6 +67,7 @@ public Graph search(DataModel dataModel, Parameters parameters) {
boss.setUseDataOrder(parameters.getBoolean(Params.USE_DATA_ORDER));
boss.setVerbose(parameters.getBoolean(Params.VERBOSE));
PermutationSearch permutationSearch = new PermutationSearch(boss);
+ permutationSearch.setSeed(parameters.getLong(Params.SEED));
permutationSearch.setKnowledge(this.knowledge);
Graph cpdag = permutationSearch.search();
@@ -114,6 +115,7 @@ public List getParameters() {
parameters.add(Params.TIME_LAG);
parameters.add(Params.NUM_THREADS);
parameters.add(Params.USE_DATA_ORDER);
+ parameters.add(Params.SEED);
parameters.add(Params.VERBOSE);
return parameters;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java
index 7c34c9f73b..cfbdfa40ca 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java
@@ -65,47 +65,27 @@ public Graph search(DataModel dataModel, Parameters parameters) {
knowledge = timeSeries.getKnowledge();
}
- PcCommon.ConflictRule conflictRule;
-
- switch (parameters.getInt(Params.CONFLICT_RULE)) {
- case 1:
- conflictRule = PcCommon.ConflictRule.PRIORITIZE_EXISTING;
- break;
- case 2:
- conflictRule = PcCommon.ConflictRule.ORIENT_BIDIRECTED;
- break;
- case 3:
- conflictRule = PcCommon.ConflictRule.OVERWRITE_EXISTING;
- break;
- default:
- throw new IllegalArgumentException("Unknown conflict rule: " + parameters.getInt(Params.CONFLICT_RULE));
-
- }
-
-// PcCommon.PcHeuristicType pcHeuristicType;
-//
-// switch (parameters.getInt(Params.PC_HEURISTIC)) {
-// case 0:
-// pcHeuristicType = PcCommon.PcHeuristicType.NONE;
-// break;
-// case 1:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_1;
-// break;
-// case 2:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_2;
-// break;
-// case 3:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_3;
-// break;
-// default:
-// throw new IllegalArgumentException("Unknown conflict rule: " + parameters.getInt(Params.CONFLICT_RULE));
-//
-// }
+ PcCommon.ConflictRule conflictRule = switch (parameters.getInt(Params.CONFLICT_RULE)) {
+ case 1 -> PcCommon.ConflictRule.PRIORITIZE_EXISTING;
+ case 2 -> PcCommon.ConflictRule.ORIENT_BIDIRECTED;
+ case 3 -> PcCommon.ConflictRule.OVERWRITE_EXISTING;
+ default ->
+ throw new IllegalArgumentException("Unknown conflict rule: " + parameters.getInt(Params.CONFLICT_RULE));
+ };
+
+ PcCommon.PcHeuristicType pcHeuristicType = switch (parameters.getInt(Params.PC_HEURISTIC)) {
+ case 0 -> PcCommon.PcHeuristicType.NONE;
+ case 1 -> PcCommon.PcHeuristicType.HEURISTIC_1;
+ case 2 -> PcCommon.PcHeuristicType.HEURISTIC_2;
+ case 3 -> PcCommon.PcHeuristicType.HEURISTIC_3;
+ default ->
+ throw new IllegalArgumentException("Unknown PC heuristic type: " + parameters.getInt(Params.PC_HEURISTIC));
+ };
edu.cmu.tetrad.search.Cpc search = new edu.cmu.tetrad.search.Cpc(getIndependenceWrapper().getTest(dataModel, parameters));
search.setDepth(parameters.getInt(Params.DEPTH));
search.meekPreventCycles(parameters.getBoolean(Params.MEEK_PREVENT_CYCLES));
-// search.setPcHeuristicType(pcHeuristicType);
+ search.setPcHeuristicType(pcHeuristicType);
search.setVerbose(parameters.getBoolean(Params.VERBOSE));
search.setKnowledge(knowledge);
search.setConflictRule(conflictRule);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fas.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fas.java
index bd8df26b75..babb6958dc 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fas.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fas.java
@@ -14,11 +14,13 @@
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphTransforms;
+import edu.cmu.tetrad.search.utils.PcCommon;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
import java.io.PrintStream;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -36,6 +38,7 @@
public class Fas implements Algorithm, HasKnowledge, TakesIndependenceWrapper,
ReturnsBootstrapGraphs {
+ @Serial
private static final long serialVersionUID = 23L;
private IndependenceWrapper test;
private Knowledge knowledge = new Knowledge();
@@ -51,28 +54,18 @@ public Fas(IndependenceWrapper test) {
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
if (parameters.getInt(Params.NUMBER_RESAMPLING) < 1) {
-// PcCommon.PcHeuristicType pcHeuristicType;
-//
-// switch (parameters.getInt(Params.PC_HEURISTIC)) {
-// case 0:
-// pcHeuristicType = PcCommon.PcHeuristicType.NONE;
-// break;
-// case 1:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_1;
-// break;
-// case 2:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_2;
-// break;
-// case 3:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_3;
-// break;
-// default:
-// throw new IllegalArgumentException("Unknown conflict rule: " + parameters.getInt(Params.CONFLICT_RULE));
-// }
+ PcCommon.PcHeuristicType pcHeuristicType = switch (parameters.getInt(Params.PC_HEURISTIC)) {
+ case 0 -> PcCommon.PcHeuristicType.NONE;
+ case 1 -> PcCommon.PcHeuristicType.HEURISTIC_1;
+ case 2 -> PcCommon.PcHeuristicType.HEURISTIC_2;
+ case 3 -> PcCommon.PcHeuristicType.HEURISTIC_3;
+ default ->
+ throw new IllegalArgumentException("Unknown PC heuristic type: " + parameters.getInt(Params.PC_HEURISTIC));
+ };
edu.cmu.tetrad.search.Fas search = new edu.cmu.tetrad.search.Fas(this.test.getTest(dataSet, parameters));
search.setStable(parameters.getBoolean(Params.STABLE_FAS));
-// search.setPcHeuristicType(pcHeuristicType);
+ search.setPcHeuristicType(pcHeuristicType);
search.setDepth(parameters.getInt(Params.DEPTH));
search.setKnowledge(this.knowledge);
search.setVerbose(parameters.getBoolean(Params.VERBOSE));
@@ -119,7 +112,7 @@ public DataType getDataType() {
public List<String> getParameters() {
List<String> parameters = new ArrayList<>();
parameters.add(Params.DEPTH);
-// parameters.add(Params.PC_HEURISTIC);
+ parameters.add(Params.PC_HEURISTIC);
parameters.add(Params.STABLE_FAS);
parameters.add(Params.VERBOSE);
return parameters;
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java
index d3298a8528..6b0c020a68 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java
@@ -100,11 +100,12 @@ public Graph search(DataModel dataModel, Parameters parameters) {
graph = search.search();
- if (!graph.getAllAttributes().containsKey("BIC")) {
+ if (dataModel.isContinuous() && !graph.getAllAttributes().containsKey("BIC")) {
graph.addAttribute("BIC", new BicEst().getValue(null, graph, dataModel));
}
- LogUtilsSearch.stampWithScores(graph, dataModel, score);
+ LogUtilsSearch.stampWithScore(graph, score);
+ LogUtilsSearch.stampWithBic(graph, dataModel);
return graph;
} else {
Fges fges = new Fges(this.score);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java
index 287699409d..796bf0758b 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Grasp.java
@@ -16,6 +16,7 @@
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.search.IndependenceTest;
+import edu.cmu.tetrad.search.score.GraphScore;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.search.utils.TsUtils;
@@ -75,6 +76,7 @@ public Graph search(DataModel dataModel, Parameters parameters) {
test.setVerbose(parameters.getBoolean(Params.VERBOSE));
edu.cmu.tetrad.search.Grasp grasp = new edu.cmu.tetrad.search.Grasp(test, score);
+ grasp.setSeed(parameters.getLong(Params.SEED));
grasp.setDepth(parameters.getInt(Params.GRASP_DEPTH));
grasp.setUncoveredDepth(parameters.getInt(Params.GRASP_SINGULAR_DEPTH));
grasp.setNonSingularDepth(parameters.getInt(Params.GRASP_NONSINGULAR_DEPTH));
@@ -89,7 +91,11 @@ public Graph search(DataModel dataModel, Parameters parameters) {
grasp.setKnowledge(this.knowledge);
grasp.bestOrder(score.getVariables());
Graph graph = grasp.getGraph(parameters.getBoolean(Params.OUTPUT_CPDAG));
- LogUtilsSearch.stampWithScores(graph, dataModel, score);
+ LogUtilsSearch.stampWithScore(graph, score);
+ LogUtilsSearch.stampWithBic(graph, dataModel);
+
return graph;
} else {
Grasp algorithm = new Grasp(this.test, this.score);
@@ -136,6 +142,7 @@ public List getParameters() {
params.add(Params.USE_DATA_ORDER);
params.add(Params.ALLOW_INTERNAL_RANDOMNESS);
params.add(Params.TIME_LAG);
+ params.add(Params.SEED);
params.add(Params.VERBOSE);
// Parameters
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java
index decfad9141..b1607fb99d 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java
@@ -20,6 +20,7 @@
import edu.cmu.tetrad.util.Params;
import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -38,6 +39,7 @@
@Bootstrapping
public class Pc implements Algorithm, HasKnowledge, TakesIndependenceWrapper,
ReturnsBootstrapGraphs {
+ @Serial
private static final long serialVersionUID = 23L;
private IndependenceWrapper test;
private Knowledge knowledge = new Knowledge();
@@ -63,47 +65,28 @@ public Graph search(DataModel dataModel, Parameters parameters) {
knowledge = timeSeries.getKnowledge();
}
- PcCommon.ConflictRule conflictRule;
-
- switch (parameters.getInt(Params.CONFLICT_RULE)) {
- case 1:
- conflictRule = PcCommon.ConflictRule.PRIORITIZE_EXISTING;
- break;
- case 2:
- conflictRule = PcCommon.ConflictRule.ORIENT_BIDIRECTED;
- break;
- case 3:
- conflictRule = PcCommon.ConflictRule.OVERWRITE_EXISTING;
- break;
- default:
- throw new IllegalArgumentException("Unknown conflict rule: " + parameters.getInt(Params.CONFLICT_RULE));
-
- }
-
-// PcCommon.PcHeuristicType pcHeuristicType;
-//
-// switch (parameters.getInt(Params.PC_HEURISTIC)) {
-// case 0:
-// pcHeuristicType = PcCommon.PcHeuristicType.NONE;
-// break;
-// case 1:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_1;
-// break;
-// case 2:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_2;
-// break;
-// case 3:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_3;
-// break;
-// default:
-// throw new IllegalArgumentException("Unknown conflict rule: " + parameters.getInt(Params.CONFLICT_RULE));
-// }
+ PcCommon.ConflictRule conflictRule = switch (parameters.getInt(Params.CONFLICT_RULE)) {
+ case 1 -> PcCommon.ConflictRule.PRIORITIZE_EXISTING;
+ case 2 -> PcCommon.ConflictRule.ORIENT_BIDIRECTED;
+ case 3 -> PcCommon.ConflictRule.OVERWRITE_EXISTING;
+ default ->
+ throw new IllegalArgumentException("Unknown conflict rule: " + parameters.getInt(Params.CONFLICT_RULE));
+ };
+
+ PcCommon.PcHeuristicType pcHeuristicType = switch (parameters.getInt(Params.PC_HEURISTIC)) {
+ case 0 -> PcCommon.PcHeuristicType.NONE;
+ case 1 -> PcCommon.PcHeuristicType.HEURISTIC_1;
+ case 2 -> PcCommon.PcHeuristicType.HEURISTIC_2;
+ case 3 -> PcCommon.PcHeuristicType.HEURISTIC_3;
+ default ->
+ throw new IllegalArgumentException("Unknown PC heuristic type: " + parameters.getInt(Params.PC_HEURISTIC));
+ };
edu.cmu.tetrad.search.Pc search = new edu.cmu.tetrad.search.Pc(getIndependenceWrapper().getTest(dataModel, parameters));
search.setUseMaxPHeuristic(parameters.getBoolean(Params.USE_MAX_P_HEURISTIC));
search.setDepth(parameters.getInt(Params.DEPTH));
search.setMeekPreventCycles(parameters.getBoolean(Params.MEEK_PREVENT_CYCLES));
-// search.setPcHeuristicType(pcHeuristicType);
+ search.setPcHeuristicType(pcHeuristicType);
search.setVerbose(parameters.getBoolean(Params.VERBOSE));
search.setKnowledge(this.knowledge);
search.setStable(parameters.getBoolean(Params.STABLE_FAS));
@@ -149,7 +132,7 @@ public List getParameters() {
parameters.add(Params.USE_MAX_P_HEURISTIC);
parameters.add(Params.CONFLICT_RULE);
parameters.add(Params.MEEK_PREVENT_CYCLES);
-// parameters.add(Params.PC_HEURISTIC);
+ parameters.add(Params.PC_HEURISTIC);
parameters.add(Params.DEPTH);
parameters.add(Params.TIME_LAG);
parameters.add(Params.VERBOSE);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pcd.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pcd.java
index f74f3d8730..91e98402dd 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pcd.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pcd.java
@@ -8,6 +8,7 @@
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphTransforms;
import edu.cmu.tetrad.search.test.ScoreIndTest;
+import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.search.work_in_progress.SemBicScoreDeterministic;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
@@ -53,7 +54,9 @@ public Graph search(DataModel dataSet, Parameters parameters) {
search.setDepth(parameters.getInt(Params.DEPTH));
search.setKnowledge(this.knowledge);
search.setVerbose(parameters.getBoolean(Params.VERBOSE));
- return search.search();
+ Graph search1 = search.search();
+ LogUtilsSearch.stampWithBic(search1, dataSet);
+ return search1;
} else {
Pcd algorithm = new Pcd();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/RestrictedBoss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/RestrictedBoss.java
index 50f0ca4e39..95020839a2 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/RestrictedBoss.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/RestrictedBoss.java
@@ -97,6 +97,7 @@ public Graph search(DataModel dataModel, Parameters parameters) {
boss.setUseBes(parameters.getBoolean(Params.USE_BES));
boss.setNumStarts(parameters.getInt(Params.NUM_STARTS));
PermutationSearch permutationSearch = new PermutationSearch(boss);
+ permutationSearch.setSeed(parameters.getLong(Params.SEED));
permutationSearch.setKnowledge(knowledge);
permutationSearch.search();
@@ -169,6 +170,7 @@ public List getParameters() {
params.add(Params.NUM_STARTS);
params.add(Params.TARGETS);
params.add(Params.TRIMMING_STYLE);
+ params.add(Params.SEED);
return params;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java
index d61d61584e..3eb33d683f 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Sp.java
@@ -70,7 +70,8 @@ public Graph search(DataModel dataModel, Parameters parameters) {
PermutationSearch permutationSearch = new PermutationSearch(new edu.cmu.tetrad.search.Sp(score));
permutationSearch.setKnowledge(this.knowledge);
Graph graph = permutationSearch.search();
- LogUtilsSearch.stampWithScores(graph, dataModel, score);
+ LogUtilsSearch.stampWithScore(graph, score);
+ LogUtilsSearch.stampWithBic(graph, dataModel);
return graph;
} else {
Sp algorithm = new Sp(this.score);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Bfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Bfci.java
index 53a25cc689..9fb5c2d212 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Bfci.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Bfci.java
@@ -22,6 +22,7 @@
import edu.cmu.tetrad.util.Params;
import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -46,6 +47,7 @@
public class Bfci implements Algorithm, UsesScoreWrapper,
TakesIndependenceWrapper, HasKnowledge, ReturnsBootstrapGraphs {
+ @Serial
private static final long serialVersionUID = 23L;
private IndependenceWrapper test;
private ScoreWrapper score;
@@ -59,6 +61,12 @@ public Bfci() {
// Used for reflection; do not delete.
}
+ /**
+ * Constructs a new BFCI algorithm using the given test and score.
+ *
+ * @param test the independence test to use
+ * @param score the score to use
+ */
public Bfci(IndependenceWrapper test, ScoreWrapper score) {
this.test = test;
this.score = score;
@@ -79,6 +87,7 @@ public Graph search(DataModel dataModel, Parameters parameters) {
BFci search = new BFci(this.test.getTest(dataModel, parameters), this.score.getScore(dataModel, parameters));
+ search.setSeed(parameters.getLong(Params.SEED));
search.setBossUseBes(parameters.getBoolean(Params.USE_BES));
search.setMaxPathLength(parameters.getInt(Params.MAX_PATH_LENGTH));
search.setCompleteRuleSetUsed(parameters.getBoolean(Params.COMPLETE_RULE_SET_USED));
@@ -130,6 +139,7 @@ public List getParameters() {
params.add(Params.DO_DISCRIMINATING_PATH_RULE);
params.add(Params.DEPTH);
params.add(Params.TIME_LAG);
+ params.add(Params.SEED);
params.add(Params.VERBOSE);
// Parameters
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Ccd.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Ccd.java
index 9accc6e7e6..5fec0b6156 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Ccd.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Ccd.java
@@ -15,6 +15,7 @@
import edu.cmu.tetrad.util.Params;
import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -29,8 +30,8 @@
algoType = AlgType.forbid_latent_common_causes
)
@Bootstrapping
-//@Experimental
public class Ccd implements Algorithm, TakesIndependenceWrapper, ReturnsBootstrapGraphs {
+ @Serial
private static final long serialVersionUID = 23L;
private IndependenceWrapper test;
private List<Graph> bootstrapGraphs = new ArrayList<>();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java
index 752d387009..85390ea95c 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java
@@ -14,11 +14,13 @@
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphTransforms;
+import edu.cmu.tetrad.search.utils.PcCommon;
import edu.cmu.tetrad.search.utils.TsUtils;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -36,6 +38,7 @@
public class Fci implements Algorithm, HasKnowledge, TakesIndependenceWrapper,
ReturnsBootstrapGraphs {
+ @Serial
private static final long serialVersionUID = 23L;
private IndependenceWrapper test;
private Knowledge knowledge = new Knowledge();
@@ -62,24 +65,14 @@ public Graph search(DataModel dataModel, Parameters parameters) {
knowledge = timeSeries.getKnowledge();
}
-// PcCommon.PcHeuristicType pcHeuristicType;
-//
-// switch (parameters.getInt(Params.PC_HEURISTIC)) {
-// case 0:
-// pcHeuristicType = PcCommon.PcHeuristicType.NONE;
-// break;
-// case 1:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_1;
-// break;
-// case 2:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_2;
-// break;
-// case 3:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_3;
-// break;
-// default:
-// throw new IllegalArgumentException("Unknown conflict rule: " + parameters.getInt(Params.CONFLICT_RULE));
-// }
+ PcCommon.PcHeuristicType pcHeuristicType = switch (parameters.getInt(Params.PC_HEURISTIC)) {
+ case 0 -> PcCommon.PcHeuristicType.NONE;
+ case 1 -> PcCommon.PcHeuristicType.HEURISTIC_1;
+ case 2 -> PcCommon.PcHeuristicType.HEURISTIC_2;
+ case 3 -> PcCommon.PcHeuristicType.HEURISTIC_3;
+ default ->
+ throw new IllegalArgumentException("Unknown PC heuristic type: " + parameters.getInt(Params.PC_HEURISTIC));
+ };
edu.cmu.tetrad.search.Fci search = new edu.cmu.tetrad.search.Fci(this.test.getTest(dataModel, parameters));
search.setDepth(parameters.getInt(Params.DEPTH));
@@ -89,7 +82,7 @@ public Graph search(DataModel dataModel, Parameters parameters) {
search.setPossibleMsepSearchDone(parameters.getBoolean(Params.POSSIBLE_MSEP_DONE));
search.setDoDiscriminatingPathRule(parameters.getBoolean(Params.DO_DISCRIMINATING_PATH_RULE));
search.setVerbose(parameters.getBoolean(Params.VERBOSE));
-// search.setPcHeuristicType(pcHeuristicType);
+ search.setPcHeuristicType(pcHeuristicType);
search.setStable(parameters.getBoolean(Params.STABLE_FAS));
return search.search();
@@ -128,6 +121,7 @@ public List getParameters() {
List<String> parameters = new ArrayList<>();
parameters.add(Params.DEPTH);
parameters.add(Params.STABLE_FAS);
+ parameters.add(Params.PC_HEURISTIC);
parameters.add(Params.MAX_PATH_LENGTH);
parameters.add(Params.POSSIBLE_MSEP_DONE);
parameters.add(Params.DO_DISCRIMINATING_PATH_RULE);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/FciMax.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/FciMax.java
index 4eb8e3c8f0..48a0147fd9 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/FciMax.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/FciMax.java
@@ -14,11 +14,13 @@
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphTransforms;
+import edu.cmu.tetrad.search.utils.PcCommon;
import edu.cmu.tetrad.search.utils.TsUtils;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -36,6 +38,7 @@
public class FciMax implements Algorithm, HasKnowledge, TakesIndependenceWrapper,
ReturnsBootstrapGraphs {
+ @Serial
private static final long serialVersionUID = 23L;
private IndependenceWrapper test;
private Knowledge knowledge = new Knowledge();
@@ -61,24 +64,14 @@ public Graph search(DataModel dataModel, Parameters parameters) {
knowledge = timeSeries.getKnowledge();
}
-// PcCommon.PcHeuristicType pcHeuristicType;
-//
-// switch (parameters.getInt(Params.PC_HEURISTIC)) {
-// case 0:
-// pcHeuristicType = PcCommon.PcHeuristicType.NONE;
-// break;
-// case 1:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_1;
-// break;
-// case 2:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_2;
-// break;
-// case 3:
-// pcHeuristicType = PcCommon.PcHeuristicType.HEURISTIC_3;
-// break;
-// default:
-// throw new IllegalArgumentException("Unknown conflict rule: " + parameters.getInt(Params.CONFLICT_RULE));
-// }
+ PcCommon.PcHeuristicType pcHeuristicType = switch (parameters.getInt(Params.PC_HEURISTIC)) {
+ case 0 -> PcCommon.PcHeuristicType.NONE;
+ case 1 -> PcCommon.PcHeuristicType.HEURISTIC_1;
+ case 2 -> PcCommon.PcHeuristicType.HEURISTIC_2;
+ case 3 -> PcCommon.PcHeuristicType.HEURISTIC_3;
+ default ->
+ throw new IllegalArgumentException("Unknown PC heuristic type: " + parameters.getInt(Params.PC_HEURISTIC));
+ };
edu.cmu.tetrad.search.FciMax search = new edu.cmu.tetrad.search.FciMax(this.test.getTest(dataModel, parameters));
search.setDepth(parameters.getInt(Params.DEPTH));
@@ -87,7 +80,7 @@ public Graph search(DataModel dataModel, Parameters parameters) {
search.setCompleteRuleSetUsed(parameters.getBoolean(Params.COMPLETE_RULE_SET_USED));
search.setDoDiscriminatingPathRule(parameters.getBoolean(Params.DO_DISCRIMINATING_PATH_RULE));
search.setPossibleMsepSearchDone(parameters.getBoolean(Params.POSSIBLE_MSEP_DONE));
-// search.setPcHeuristicType(pcHeuristicType);
+ search.setPcHeuristicType(pcHeuristicType);
search.setVerbose(parameters.getBoolean(Params.VERBOSE));
return search.search();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java
index c4404a117c..8c04f9d55b 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java
@@ -22,6 +22,7 @@
import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
import java.io.PrintStream;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -40,6 +41,7 @@
public class Gfci implements Algorithm, HasKnowledge, UsesScoreWrapper, TakesIndependenceWrapper,
ReturnsBootstrapGraphs {
+ @Serial
private static final long serialVersionUID = 23L;
private IndependenceWrapper test;
private ScoreWrapper score;
@@ -138,7 +140,7 @@ public Knowledge getKnowledge() {
@Override
public void setKnowledge(Knowledge knowledge) {
- this.knowledge = new Knowledge((Knowledge) knowledge);
+ this.knowledge = new Knowledge(knowledge);
}
@Override
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GraspFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GraspFci.java
index 1475820bbb..874b2ba97d 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GraspFci.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GraspFci.java
@@ -81,6 +81,7 @@ public Graph search(DataModel dataModel, Parameters parameters) {
edu.cmu.tetrad.search.GraspFci search = new edu.cmu.tetrad.search.GraspFci(test, score);
// GRaSP
+ search.setSeed(parameters.getLong(Params.SEED));
search.setDepth(parameters.getInt(Params.GRASP_DEPTH));
search.setSingularDepth(parameters.getInt(Params.GRASP_SINGULAR_DEPTH));
search.setNonSingularDepth(parameters.getInt(Params.GRASP_NONSINGULAR_DEPTH));
@@ -154,6 +155,9 @@ public List getParameters() {
// General
params.add(Params.TIME_LAG);
+
+ params.add(Params.SEED);
+
params.add(Params.VERBOSE);
return params;
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/SvarFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/SvarFci.java
index 40c66a4e1b..792c776ad1 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/SvarFci.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/SvarFci.java
@@ -20,6 +20,7 @@
import edu.cmu.tetrad.util.Params;
import edu.pitt.dbmi.algo.resampling.GeneralResamplingTest;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -39,6 +40,7 @@
public class SvarFci implements Algorithm, HasKnowledge, TakesIndependenceWrapper,
ReturnsBootstrapGraphs {
+ @Serial
private static final long serialVersionUID = 23L;
private IndependenceWrapper test;
private Knowledge knowledge;
@@ -117,7 +119,7 @@ public Knowledge getKnowledge() {
@Override
public void setKnowledge(Knowledge knowledge) {
- this.knowledge = new Knowledge((Knowledge) knowledge);
+ this.knowledge = new Knowledge(knowledge);
}
@Override
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java
index 526ba12154..238428f396 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java
@@ -31,8 +31,12 @@
import edu.cmu.tetrad.algcomparison.simulation.SemSimulation;
import edu.cmu.tetrad.algcomparison.simulation.Simulations;
import edu.cmu.tetrad.algcomparison.statistic.*;
+import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.BlockRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
/**
* Test the degenerate Gaussian score.
@@ -41,6 +45,11 @@
*/
public class TestBoss {
public static void main(String... args) {
+ if (true) {
+ testGigaflops();
+ return;
+ }
+
Parameters parameters = new Parameters();
parameters.set(Params.NUM_RUNS, 1);
parameters.set(Params.DIFFERENT_GRAPHS, true);
@@ -94,6 +103,32 @@ public static void main(String... args) {
comparison.compareFromSimulations("comparison", simulations, algorithms, statistics, parameters);
}
+
+ public static void testGigaflops() {
+
+ final long start = MillisecondTimes.timeMillis();// System.currentTimeMillis();
+
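+ // Rough throughput check: one N x N matrix multiply costs about 2*N^3 flops.
+ // Note that 'sec' below measures time since 'start' (set before the loop), so
+ // each printed value is a cumulative average in GFLOP/s, not a per-multiply rate.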
+ for (int i = 0; i < 200; i++) {
+ int N = 1024;
+ RealMatrix A = new BlockRealMatrix(N, N);
+ RealMatrix B = new BlockRealMatrix(N, N);
+
+ MillisecondTimes.type = MillisecondTimes.Type.CPU;
+
+ RealMatrix C = A.multiply(B);
+ final long end = MillisecondTimes.timeMillis();// System.currentTimeMillis();
+
+ double gflop = N * N * N * 2e-9;
+ double sec = (end - start) * 1e-3;
+
+ System.out.println(gflop / sec);
+ }
+
+ final long end = MillisecondTimes.timeMillis();// System.currentTimeMillis();
+
+ System.out.println((end - start) * 1e-3);
+
+ }
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/ChiSquare.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/ChiSquare.java
index d78c8c1147..aa1a7d771a 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/ChiSquare.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/ChiSquare.java
@@ -28,7 +28,9 @@ public class ChiSquare implements IndependenceWrapper {
@Override
public IndependenceTest getTest(DataModel dataSet, Parameters parameters) {
- return new IndTestChiSquare(SimpleDataLoader.getDiscreteDataSet(dataSet), parameters.getDouble("alpha"));
+ IndTestChiSquare test = new IndTestChiSquare(SimpleDataLoader.getDiscreteDataSet(dataSet), parameters.getDouble(Params.ALPHA));
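+ // MIN_COUNT_PER_CELL is the expected-count threshold below which the chi-square
+ // test treats a contingency-table cell as too sparse (see IndTestChiSquare for
+ // the exact handling).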
+ test.setMinCountPerCell(parameters.getDouble(Params.MIN_COUNT_PER_CELL));
+ return test;
}
@Override
@@ -45,7 +47,7 @@ public DataType getDataType() {
public List<String> getParameters() {
List<String> params = new ArrayList<>();
params.add(Params.ALPHA);
+ params.add(Params.MIN_COUNT_PER_CELL);
return params;
}
-
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/FisherZ.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/FisherZ.java
index 48c0ab7fc3..ac07b8ccf0 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/FisherZ.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/FisherZ.java
@@ -11,6 +11,7 @@
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -27,19 +28,25 @@
@LinearGaussian
public class FisherZ implements IndependenceWrapper {
+ @Serial
private static final long serialVersionUID = 23L;
@Override
- public IndependenceTest getTest(DataModel dataSet, Parameters parameters) {
+ public IndependenceTest getTest(DataModel dataModel, Parameters parameters) {
double alpha = parameters.getDouble(Params.ALPHA);
- if (dataSet instanceof ICovarianceMatrix) {
- return new IndTestFisherZ((ICovarianceMatrix) dataSet, alpha);
- } else if (dataSet instanceof DataSet) {
- return new IndTestFisherZ((DataSet) dataSet, alpha);
+ IndTestFisherZ test;
+
+ if (dataModel instanceof ICovarianceMatrix) {
+ test = new IndTestFisherZ((ICovarianceMatrix) dataModel, alpha);
+ } else if (dataModel instanceof DataSet) {
+ test = new IndTestFisherZ((DataSet) dataModel, alpha);
+ } else {
+ throw new IllegalArgumentException("Expecting either a dataset or a covariance matrix.");
}
- throw new IllegalArgumentException("Expecting eithet a data set or a covariance matrix.");
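+ // When enabled, the pseudoinverse is used in place of a matrix inverse, which is
+ // intended to keep the test usable when covariance submatrices are singular.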
+ test.setUsePseudoinverse(parameters.getBoolean(Params.USE_PSEUDOINVERSE));
+ return test;
}
@Override
@@ -56,6 +63,7 @@ public DataType getDataType() {
public List<String> getParameters() {
List<String> params = new ArrayList<>();
params.add(Params.ALPHA);
+ params.add(Params.USE_PSEUDOINVERSE);
return params;
}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/GICScoreTests.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/GICScoreTests.java
index bd02f7529f..419c15561c 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/GICScoreTests.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/GICScoreTests.java
@@ -12,6 +12,7 @@
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -28,6 +29,7 @@
@LinearGaussian
public class GICScoreTests implements IndependenceWrapper {
+ @Serial
private static final long serialVersionUID = 23L;
@Override
@@ -71,7 +73,7 @@ public IndependenceTest getTest(DataModel dataSet, Parameters parameters) {
score.setRuleType(ruleType);
score.setPenaltyDiscount(parameters.getDouble(Params.PENALTY_DISCOUNT));
-
+ score.setUsePseudoInverse(parameters.getBoolean(Params.USE_PSEUDOINVERSE));
return new ScoreIndTest(score, dataSet);
}
@@ -92,6 +94,7 @@ public List getParameters() {
params.add(Params.SEM_GIC_RULE);
params.add(Params.PENALTY_DISCOUNT_ZS);
params.add(Params.PRECOMPUTE_COVARIANCES);
+ params.add(Params.USE_PSEUDOINVERSE);
return params;
}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/Gsquare.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/GSquare.java
similarity index 75%
rename from tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/Gsquare.java
rename to tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/GSquare.java
index e1574d1c92..f7673b15a2 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/Gsquare.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/GSquare.java
@@ -9,6 +9,7 @@
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -22,13 +23,16 @@
command = "g-square-test",
dataType = DataType.Discrete
)
-public class Gsquare implements IndependenceWrapper {
+public class GSquare implements IndependenceWrapper {
+ @Serial
private static final long serialVersionUID = 23L;
@Override
public IndependenceTest getTest(DataModel dataSet, Parameters parameters) {
- return new IndTestGSquare(SimpleDataLoader.getDiscreteDataSet(dataSet), parameters.getDouble("alpha"));
+ IndTestGSquare test = new IndTestGSquare(SimpleDataLoader.getDiscreteDataSet(dataSet), parameters.getDouble(Params.ALPHA));
+ test.setMinCountPerCell(parameters.getDouble(Params.MIN_COUNT_PER_CELL));
+ return test;
}
@Override
@@ -45,6 +49,7 @@ public DataType getDataType() {
public List<String> getParameters() {
List<String> params = new ArrayList<>();
params.add(Params.ALPHA);
+ params.add(Params.MIN_COUNT_PER_CELL);
return params;
}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/Kci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/Kci.java
index 82436d94b0..ab2646b278 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/Kci.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/Kci.java
@@ -9,13 +9,12 @@
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
/**
* Wrapper for KCI test.
- *
- * Note that should work with Linear, Gaussian variables but is general.
*
* @author josephramsey
*/
@@ -27,14 +26,21 @@
@General
public class Kci implements IndependenceWrapper {
+ @Serial
private static final long serialVersionUID = 23L;
-
+ /**
+ * Returns a KCI test.
+ *
+ * @param dataSet The data set to test independence against.
+ * @param parameters The parameters of the test.
+ * @return A KCI test.
+ */
@Override
public IndependenceTest getTest(DataModel dataSet, Parameters parameters) {
edu.cmu.tetrad.search.test.Kci kci = new edu.cmu.tetrad.search.test.Kci(SimpleDataLoader.getContinuousDataSet(dataSet),
parameters.getDouble(Params.ALPHA));
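+ // KCI options: the 'approximate' flag uses a gamma approximation to the null
+ // distribution instead of the bootstrap otherwise used, and widthMultiplier
+ // scales the kernel bandwidth (per the KCI test of Zhang et al., 2011).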
- kci.setApproximate(parameters.getBoolean(Params.KCI_USE_APPROMATION));
+ kci.setApproximate(parameters.getBoolean(Params.KCI_USE_APPROXIMATION));
kci.setWidthMultiplier(parameters.getDouble(Params.KERNEL_MULTIPLIER));
kci.setNumBootstraps(parameters.getInt(Params.KCI_NUM_BOOTSTRAPS));
kci.setThreshold(parameters.getDouble(Params.THRESHOLD_FOR_NUM_EIGENVALUES));
@@ -42,20 +48,36 @@ public IndependenceTest getTest(DataModel dataSet, Parameters parameters) {
return kci;
}
+ /**
+ * Returns the name of the test.
+ *
+ * @return The name of the test.
+ */
@Override
public String getDescription() {
return "KCI";
}
+ /**
+ * Returns the data type of the test, which is continuous.
+ *
+ * @return The data type of the test, which is continuous.
+ * @see DataType
+ */
@Override
public DataType getDataType() {
return DataType.Continuous;
}
+ /**
+ * Returns the parameters of the test.
+ *
+ * @return The parameters of the test.
+ */
@Override
public List<String> getParameters() {
List<String> params = new ArrayList<>();
- params.add(Params.KCI_USE_APPROMATION);
+ params.add(Params.KCI_USE_APPROXIMATION);
params.add(Params.ALPHA);
params.add(Params.KERNEL_MULTIPLIER);
params.add(Params.KCI_NUM_BOOTSTRAPS);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/PoissonScoreTest.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/PoissonScoreTest.java
index 7133f58047..011e9ac1c4 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/PoissonScoreTest.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/PoissonScoreTest.java
@@ -12,6 +12,7 @@
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -28,6 +29,7 @@
@LinearGaussian
public class PoissonScoreTest implements IndependenceWrapper {
+ @Serial
private static final long serialVersionUID = 23L;
@Override
@@ -39,7 +41,9 @@ public IndependenceTest getTest(DataModel dataSet, Parameters parameters) {
} else {
score = new PoissonPriorScore((DataSet) dataSet, true);
}
+
score.setLambda(parameters.getDouble(Params.POISSON_LAMBDA));
+ score.setUsePseudoInverse(parameters.getBoolean(Params.USE_PSEUDOINVERSE));
return new ScoreIndTest(score, dataSet);
}
@@ -58,6 +62,7 @@ public DataType getDataType() {
public List<String> getParameters() {
List<String> params = new ArrayList<>();
params.add(Params.POISSON_LAMBDA);
+ params.add(Params.USE_PSEUDOINVERSE);
return params;
}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/SemBicTest.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/SemBicTest.java
index 1bd5bd099f..890d77e66f 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/SemBicTest.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/SemBicTest.java
@@ -12,6 +12,7 @@
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -28,6 +29,7 @@
@LinearGaussian
public class SemBicTest implements IndependenceWrapper {
+ @Serial
private static final long serialVersionUID = 23L;
@Override
@@ -43,6 +45,7 @@ public IndependenceTest getTest(DataModel dataSet, Parameters parameters) {
}
score.setPenaltyDiscount(parameters.getDouble(Params.PENALTY_DISCOUNT));
score.setStructurePrior(parameters.getDouble(Params.STRUCTURE_PRIOR));
+ score.setUsePseudoInverse(parameters.getBoolean(Params.USE_PSEUDOINVERSE));
return new ScoreIndTest(score, dataSet);
}
@@ -63,6 +66,7 @@ public List getParameters() {
params.add(Params.PENALTY_DISCOUNT);
params.add(Params.STRUCTURE_PRIOR);
params.add(Params.PRECOMPUTE_COVARIANCES);
+ params.add(Params.USE_PSEUDOINVERSE);
return params;
}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/EbicScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/EbicScore.java
index 3b9e10c94f..83061b8dd8 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/EbicScore.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/EbicScore.java
@@ -10,6 +10,7 @@
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -26,6 +27,7 @@
@LinearGaussian
public class EbicScore implements ScoreWrapper {
+ @Serial
private static final long serialVersionUID = 23L;
private DataModel dataSet;
@@ -44,7 +46,7 @@ public Score getScore(DataModel dataSet, Parameters parameters) {
}
score.setGamma(parameters.getDouble(Params.EBIC_GAMMA));
-// score.setCorrelationThreshold(parameters.getDouble(Params.CORRELATION_THRESHOLD));
+ score.setUsePseudoInverse(parameters.getBoolean(Params.USE_PSEUDOINVERSE));
return score;
}
@@ -62,8 +64,8 @@ public DataType getDataType() {
public List<String> getParameters() {
List<String> parameters = new ArrayList<>();
parameters.add(Params.EBIC_GAMMA);
-// parameters.add(Params.CORRELATION_THRESHOLD);
parameters.add(Params.PRECOMPUTE_COVARIANCES);
+ parameters.add(Params.USE_PSEUDOINVERSE);
return parameters;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/GicScores.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/GicScores.java
index aa169a1ad9..997bf06e53 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/GicScores.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/GicScores.java
@@ -10,6 +10,7 @@
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -19,13 +20,14 @@
* @author josephramsey
*/
@edu.cmu.tetrad.annotation.Score(
- name = "Generalied Information Criterion Scores",
+ name = "Generalized Information Criterion Scores",
command = "gic-scores",
dataType = {DataType.Continuous, DataType.Covariance}
)
@LinearGaussian
public class GicScores implements ScoreWrapper {
+ @Serial
private static final long serialVersionUID = 23L;
private DataModel dataSet;
@@ -45,33 +47,19 @@ public Score getScore(DataModel dataSet, Parameters parameters) {
}
int anInt = parameters.getInt((Params.SEM_GIC_RULE));
- edu.cmu.tetrad.search.score.GicScores.RuleType ruleType;
-
- switch (anInt) {
- case 1:
- ruleType = edu.cmu.tetrad.search.score.GicScores.RuleType.BIC;
- break;
- case 2:
- ruleType = edu.cmu.tetrad.search.score.GicScores.RuleType.GIC2;
- break;
- case 3:
- ruleType = edu.cmu.tetrad.search.score.GicScores.RuleType.RIC;
- break;
- case 4:
- ruleType = edu.cmu.tetrad.search.score.GicScores.RuleType.RICc;
- break;
- case 5:
- ruleType = edu.cmu.tetrad.search.score.GicScores.RuleType.GIC5;
- break;
- case 6:
- ruleType = edu.cmu.tetrad.search.score.GicScores.RuleType.GIC6;
- break;
- default:
- throw new IllegalArgumentException("Unrecognized rule type: " + anInt);
- }
+ edu.cmu.tetrad.search.score.GicScores.RuleType ruleType = switch (anInt) {
+ case 1 -> edu.cmu.tetrad.search.score.GicScores.RuleType.BIC;
+ case 2 -> edu.cmu.tetrad.search.score.GicScores.RuleType.GIC2;
+ case 3 -> edu.cmu.tetrad.search.score.GicScores.RuleType.RIC;
+ case 4 -> edu.cmu.tetrad.search.score.GicScores.RuleType.RICc;
+ case 5 -> edu.cmu.tetrad.search.score.GicScores.RuleType.GIC5;
+ case 6 -> edu.cmu.tetrad.search.score.GicScores.RuleType.GIC6;
+ default -> throw new IllegalArgumentException("Unrecognized rule type: " + anInt);
+ };
score.setRuleType(ruleType);
score.setPenaltyDiscount(parameters.getDouble(Params.PENALTY_DISCOUNT));
+ score.setUsePseudoInverse(parameters.getBoolean(Params.USE_PSEUDOINVERSE));
return score;
}
@@ -92,6 +80,8 @@ public List getParameters() {
parameters.add(Params.SEM_GIC_RULE);
parameters.add(Params.PENALTY_DISCOUNT_ZS);
parameters.add(Params.PRECOMPUTE_COVARIANCES);
+ parameters.add(Params.USE_PSEUDOINVERSE);
+
return parameters;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/PoissonPriorScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/PoissonPriorScore.java
index 124039dd0c..889419e0c3 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/PoissonPriorScore.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/PoissonPriorScore.java
@@ -44,6 +44,7 @@ public Score getScore(DataModel dataSet, Parameters parameters) {
}
score.setLambda(parameters.getDouble(Params.POISSON_LAMBDA));
+ score.setUsePseudoInverse(parameters.getBoolean(Params.USE_PSEUDOINVERSE));
return score;
}
@@ -63,6 +64,7 @@ public List getParameters() {
List parameters = new ArrayList<>();
parameters.add(Params.PRECOMPUTE_COVARIANCES);
parameters.add(Params.POISSON_LAMBDA);
+ parameters.add(Params.USE_PSEUDOINVERSE);
return parameters;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ScoreWrapper.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ScoreWrapper.java
index 686dad09ea..ef145249b4 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ScoreWrapper.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ScoreWrapper.java
@@ -50,7 +50,9 @@ public interface ScoreWrapper extends HasParameters, TetradSerializable {
/**
* Returns the variable with the given name.
+ *
* @param name the name.
+ * @return the variable.
*/
Node getVariable(String name);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/SemBicScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/SemBicScore.java
index be60ded960..ab6d0c6bab 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/SemBicScore.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/SemBicScore.java
@@ -10,6 +10,7 @@
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -26,6 +27,7 @@
@LinearGaussian
public class SemBicScore implements ScoreWrapper {
+ @Serial
private static final long serialVersionUID = 23L;
private DataModel dataSet;
@@ -46,6 +48,7 @@ public Score getScore(DataModel dataSet, Parameters parameters) {
semBicScore.setPenaltyDiscount(parameters.getDouble(Params.PENALTY_DISCOUNT));
semBicScore.setStructurePrior(parameters.getDouble(Params.SEM_BIC_STRUCTURE_PRIOR));
+ semBicScore.setUsePseudoInverse(parameters.getBoolean(Params.USE_PSEUDOINVERSE));
switch (parameters.getInt(Params.SEM_BIC_RULE)) {
case 1:
@@ -78,6 +81,7 @@ public List getParameters() {
parameters.add(Params.SEM_BIC_STRUCTURE_PRIOR);
parameters.add(Params.SEM_BIC_RULE);
parameters.add(Params.PRECOMPUTE_COVARIANCES);
+ parameters.add(Params.USE_PSEUDOINVERSE);
return parameters;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ZhangShenBoundScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ZhangShenBoundScore.java
index 9f6649fa03..dd6f85baf6 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ZhangShenBoundScore.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ZhangShenBoundScore.java
@@ -11,6 +11,7 @@
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -27,6 +28,7 @@
@LinearGaussian
public class ZhangShenBoundScore implements ScoreWrapper {
+ @Serial
private static final long serialVersionUID = 23L;
private DataModel dataSet;
@@ -47,6 +49,7 @@ public Score getScore(DataModel dataSet, Parameters parameters) {
}
score.setRiskBound(parameters.getDouble(Params.ZS_RISK_BOUND));
+ score.setUsePseudoInverse(parameters.getBoolean(Params.USE_PSEUDOINVERSE));
return score;
}
@@ -66,6 +69,7 @@ public List getParameters() {
List parameters = new ArrayList<>();
parameters.add(Params.ZS_RISK_BOUND);
parameters.add(Params.PRECOMPUTE_COVARIANCES);
+ parameters.add(Params.USE_PSEUDOINVERSE);
return parameters;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java
index 8347eceea0..6116af41a6 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java
@@ -84,8 +84,6 @@ public void createData(Parameters parameters, boolean newModel) {
this.ims = new ArrayList<>();
for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (parameters.getBoolean(Params.DIFFERENT_GRAPHS) && i > 0) {
graph = this.randomGraph.createGraph(parameters);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java
index 6f1ec1107d..d06b36e6eb 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java
@@ -108,8 +108,6 @@ public void createData(Parameters parameters, boolean newModel) {
this.graphs = new ArrayList<>();
for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (parameters.getBoolean(Params.DIFFERENT_GRAPHS) && i > 0) {
graph = this.randomGraph.createGraph(parameters);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulation.java
index c33bd2d0ff..dc1b8a97cc 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulation.java
@@ -66,8 +66,6 @@ public void createData(Parameters parameters, boolean newModel) {
this.ims = new ArrayList<>();
for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (parameters.getBoolean(Params.DIFFERENT_GRAPHS) && i > 0) {
graph = this.randomGraph.createGraph(parameters);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulationSpecial1.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulationSpecial1.java
index b2801cc127..5ba3bf423a 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulationSpecial1.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulationSpecial1.java
@@ -44,8 +44,6 @@ public void createData(Parameters parameters, boolean newModel) {
this.graphs = new ArrayList<>();
for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (parameters.getBoolean(Params.DIFFERENT_GRAPHS) && i > 0) {
graph = this.randomGraph.createGraph(parameters);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LeeHastieSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LeeHastieSimulation.java
index d20aaf89c1..08f43f7b43 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LeeHastieSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LeeHastieSimulation.java
@@ -65,8 +65,6 @@ public void createData(Parameters parameters, boolean newModel) {
this.graphs = new ArrayList<>();
for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (parameters.getBoolean(Params.DIFFERENT_GRAPHS) && i > 0) {
graph = this.randomGraph.createGraph(parameters);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java
index f896c17611..a750843f97 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java
@@ -70,8 +70,6 @@ public void createData(Parameters parameters, boolean newModel) {
System.out.println("degree = " + GraphUtils.getDegree(graph));
for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (this.shocks != null && this.shocks.size() > 0) {
parameters.set(Params.NUM_MEASURES, this.shocks.get(0).getVariables().size());
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearSineSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearSineSimulation.java
index b0cf5836d8..38473e0238 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearSineSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearSineSimulation.java
@@ -96,8 +96,6 @@ public void createData(Parameters parameters, boolean newModel) {
this.graphs = new ArrayList<>();
for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (parameters.getBoolean(Params.DIFFERENT_GRAPHS) && i > 0)
graph = this.randomGraph.createGraph(parameters);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/NLSemSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/NLSemSimulation.java
index ae8e1a97f8..e9e945aeff 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/NLSemSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/NLSemSimulation.java
@@ -53,8 +53,6 @@ public void createData(Parameters parameters, boolean newModel) {
int numVars = parameters.getInt(Params.NUM_MEASURES);
for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (parameters.getBoolean(Params.DIFFERENT_GRAPHS) && i > 0) {
graph = this.randomGraph.createGraph(parameters);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java
index 989dfd72f0..e782f6d1b7 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java
@@ -67,8 +67,6 @@ public void createData(Parameters parameters, boolean newModel) {
this.ims = new ArrayList<>();
for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (parameters.getBoolean(Params.DIFFERENT_GRAPHS) && i > 0) {
graph = this.randomGraph.createGraph(parameters);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemThenDiscretize.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemThenDiscretize.java
index e2bf648fbe..6f8a74aed8 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemThenDiscretize.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemThenDiscretize.java
@@ -65,8 +65,6 @@ public void createData(Parameters parameters, boolean newModel) {
this.graphs = new ArrayList<>();
for (int i = 0; i < parameters.getInt(Params.NUM_RUNS); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (parameters.getBoolean(Params.DIFFERENT_GRAPHS) && i > 0) {
graph = this.randomGraph.createGraph(parameters);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SimulationTypes.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SimulationTypes.java
index 7c84ecb1fb..87391af9a4 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SimulationTypes.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SimulationTypes.java
@@ -25,12 +25,12 @@
*/
public final class SimulationTypes {
- public static final String BAYS_NET = "Bayes Net";
- public static final String STRUCTURAL_EQUATION_MODEL = "Structural Equation Model";
- public static final String NON_LINEAR_STRUCTURAL_EQUATION_MODEL = "Non-Linear Structural Equation Model";
+ public static final String BAYS_NET = "Bayes Net (Multinomial)";
+ public static final String STRUCTURAL_EQUATION_MODEL = "Linear Structural Equation Model";
public static final String LINEAR_FISHER_MODEL = "Linear Fisher Model";
- public static final String LEE_AND_HASTIE = "Lee & Hastie";
- public static final String CONDITIONAL_GAUSSIAN = "Conditional Gaussian";
+ public static final String NON_LINEAR_STRUCTURAL_EQUATION_MODEL = "Non-Linear Structural Equation Model";
+ public static final String LEE_AND_HASTIE = "Mixed Lee & Hastie";
+ public static final String CONDITIONAL_GAUSSIAN = "Mixed Conditional Gaussian";
public static final String TIME_SERIES = "Time Series";
public static final String STANDARDIZED_STRUCTURAL_EQUATION_MODEL = "Standardized Structural Equation Model";
public static final String GENERAL_STRUCTURAL_EQUATION_MODEL = "General Structural Equation Model";
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/StandardizedSemSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/StandardizedSemSimulation.java
index dcf2f7f114..cf4ecbed1a 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/StandardizedSemSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/StandardizedSemSimulation.java
@@ -57,8 +57,6 @@ public void createData(Parameters parameters, boolean newModel) {
this.graphs = new ArrayList<>();
for (int i = 0; i < parameters.getInt("numRuns"); i++) {
- System.out.println("Simulating dataset #" + (i + 1));
-
if (parameters.getBoolean("differentGraphs") && i > 0) {
graph = this.randomGraph.createGraph(parameters);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiff.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiff.java
index 83241bb266..64660110c4 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiff.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiff.java
@@ -8,7 +8,7 @@
import static org.apache.commons.math3.util.FastMath.tanh;
/**
- * Difference between the true and estiamted BIC scores.
+ * Difference between the true and estimated BIC scores. The BIC is calculated as 2L - k ln N, so "higher is better."
*
* @author josephramsey
*/
@@ -16,28 +16,57 @@ public class BicDiff implements Statistic {
private static final long serialVersionUID = 23L;
private boolean precomputeCovariances = true;
+ /**
+ * Returns the name of the statistic.
+ *
+ * @return the name of the statistic.
+ */
@Override
public String getAbbreviation() {
return "BicDiff";
}
+ /**
+ * Returns the description of the statistic.
+ *
+ * @return the description of the statistic.
+ */
@Override
public String getDescription() {
return "Difference between the true and estimated BIC scores";
}
+ /**
+ * Returns the value of the statistic.
+ *
+ * @param trueGraph The true graph (DAG, CPDAG, PAG_of_the_true_DAG).
+ * @param estGraph The estimated graph (same type).
+ * @param dataModel The data model.
+ * @return The value of the statistic.
+ */
@Override
public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
- double _true = SemBicScorer.scoreDag(GraphTransforms.dagFromCPDAG(trueGraph, null), dataModel, precomputeCovariances);
- double est = SemBicScorer.scoreDag(GraphTransforms.dagFromCPDAG(estGraph, null), dataModel, precomputeCovariances);
+ double _true = SemBicScorer.scoreDag(GraphTransforms.dagFromCpdag(trueGraph, null), dataModel, precomputeCovariances);
+ double est = SemBicScorer.scoreDag(GraphTransforms.dagFromCpdag(estGraph, null), dataModel, precomputeCovariances);
return (_true - est);
}
+ /**
+ * Returns the normalized value of the statistic.
+ *
+ * @param value The value of the statistic.
+ * @return The normalized value of the statistic.
+ */
@Override
public double getNormValue(double value) {
return tanh(value / 1e6);
}
+ /**
+ * Sets the precompute covariances flag.
+ *
+ * @param precomputeCovariances The precompute covariances flag.
+ */
public void setPrecomputeCovariances(boolean precomputeCovariances) {
this.precomputeCovariances = precomputeCovariances;
}
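
getNormValue squashes an unbounded BIC difference into (-1, 1) with tanh(value / 1e6). A tiny standalone example of the scale that implies:

    import static org.apache.commons.math3.util.FastMath.tanh;

    final class NormValueSketch {
        public static void main(String[] args) {
            System.out.println(tanh(5e5 / 1e6)); // a BIC difference of 500,000 maps to ~0.46
            System.out.println(tanh(5e6 / 1e6)); // 5,000,000 is already saturated near 1.0
        }
    }
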
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiffPerRecord.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiffPerRecord.java
index c98412966e..ca0f412a09 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiffPerRecord.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiffPerRecord.java
@@ -29,20 +29,39 @@ public String getDescription() {
"divided by the sample size";
}
+ /**
+ * Returns the difference between the true and estimated BIC scores, divided by the sample size.
+ *
+ * @param trueGraph The true graph (DAG, CPDAG, PAG_of_the_true_DAG).
+ * @param estGraph The estimated graph (same type).
+ * @param dataModel The data model.
+ * @return The difference between the true and estimated BIC scores, divided by the sample size.
+ */
@Override
public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
- double _true = SemBicScorer.scoreDag(GraphTransforms.dagFromCPDAG(trueGraph, null), dataModel, precomputeCovariances);
- double est = SemBicScorer.scoreDag(GraphTransforms.dagFromCPDAG(estGraph, null), dataModel, precomputeCovariances);
+ double _true = SemBicScorer.scoreDag(GraphTransforms.dagFromCpdag(trueGraph, null), dataModel, precomputeCovariances);
+ double est = SemBicScorer.scoreDag(GraphTransforms.dagFromCpdag(estGraph, null), dataModel, precomputeCovariances);
if (abs(_true) < 0.0001) _true = 0.0;
if (abs(est) < 0.0001) est = 0.0;
return (_true - est) / ((DataSet) dataModel).getNumRows();
}
+ /**
+ * Returns the normalized value of the statistic.
+ *
+ * @param value The value of the statistic.
+ * @return The normalized value of the statistic.
+ */
@Override
public double getNormValue(double value) {
return tanh(value / 1e6);
}
+ /**
+ * Sets whether the covariances should be precomputed.
+ *
+ * @param precomputeCovariances True if the covariances should be precomputed.
+ */
public void setPrecomputeCovariances(boolean precomputeCovariances) {
this.precomputeCovariances = precomputeCovariances;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicEst.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicEst.java
index 2fc7024067..646b2d23ff 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicEst.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicEst.java
@@ -1,51 +1,110 @@
package edu.cmu.tetrad.algcomparison.statistic;
import edu.cmu.tetrad.data.DataModel;
+import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphTransforms;
+import edu.cmu.tetrad.graph.Node;
+import edu.cmu.tetrad.search.score.DiscreteBicScore;
import edu.cmu.tetrad.search.score.SemBicScorer;
+import java.io.Serial;
+import java.util.List;
+
import static org.apache.commons.math3.util.FastMath.tanh;
/**
- * Estimated BIC score.
+ * Estimated BIC score. The BIC is calculated as 2L - k ln N, so "higher is better."
*
* @author josephramsey
*/
public class BicEst implements Statistic {
+ @Serial
private static final long serialVersionUID = 23L;
- private double penaltyDiscount = 1.0;
private boolean precomputeCovariances = true;
+ /**
+ * No-arg constructor. Used for reflection; do not delete.
+ */
public BicEst() {
}
- public BicEst(double penaltyDiscount) {
- this.penaltyDiscount = penaltyDiscount;
- }
-
+ /**
+ * Returns the name of the statistic.
+ *
+ * @return the name of the statistic.
+ */
@Override
public String getAbbreviation() {
return "BicEst";
}
+ /**
+ * Returns the description of the statistic.
+ *
+ * @return the description of the statistic.
+ */
@Override
public String getDescription() {
return "BIC of the estimated CPDAG (depends only on the estimated DAG and the data)";
}
+ /**
+ * Returns the value of the statistic.
+ *
+ * @param trueGraph The true graph (DAG, CPDAG, PAG_of_the_true_DAG).
+ * @param estGraph The estimated graph (same type).
+ * @param dataModel The data model.
+ * @return The value of the statistic.
+ */
@Override
public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
-// double _true = SemBicScorer.scoreDag(SearchGraphUtils.dagFromCPDAG(trueGraph), dataModel);
- return SemBicScorer.scoreDag(GraphTransforms.dagFromCPDAG(estGraph, null), dataModel, precomputeCovariances);
+ if (dataModel.isDiscrete()) {
+ DiscreteBicScore score = new DiscreteBicScore((DataSet) dataModel);
+
+ Graph dag = GraphTransforms.dagFromCpdag(estGraph, null);
+ List<Node> nodes = dag.getNodes();
+
+ double _score = 0.0;
+
+ for (Node node : dag.getNodes()) {
+ score.setPenaltyDiscount(1);
+ int i = nodes.indexOf(node);
+ List<Node> parents = dag.getParents(node);
+ int[] parentIndices = new int[parents.size()];
+
+ for (Node parent : parents) {
+ parentIndices[parents.indexOf(parent)] = nodes.indexOf(parent);
+ }
+
+ _score += score.localScore(i, parentIndices);
+ }
+
+ return _score;
+ } else if (dataModel.isContinuous()) {
+ return SemBicScorer.scoreDag(GraphTransforms.dagFromCpdag(estGraph, null), dataModel, precomputeCovariances);
+ } else {
+ throw new IllegalArgumentException("Data must be either discrete or continuous");
+ }
}
+ /**
+ * Returns the normalized value of the statistic.
+ *
+ * @param value The value of the statistic.
+ * @return The normalized value of the statistic.
+ */
@Override
public double getNormValue(double value) {
return tanh(value / 1e6);
}
+ /**
+ * Sets the precompute covariances flag.
+ *
+ * @param precomputeCovariances The precompute covariances flag.
+ */
public void setPrecomputeCovariances(boolean precomputeCovariances) {
this.precomputeCovariances = precomputeCovariances;
}
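
The updated comments state the BIC convention used by these statistics: 2L - k ln N, so larger values are better. A minimal standalone worked example of that formula (the log-likelihood, parameter count, and sample size are made-up numbers):

    import static org.apache.commons.math3.util.FastMath.log;

    final class BicSketch {
        // BIC in the "higher is better" form used in these comments: 2L - k ln N.
        static double bic(double logLikelihood, int numParameters, int sampleSize) {
            return 2.0 * logLikelihood - numParameters * log(sampleSize);
        }

        public static void main(String[] args) {
            // L = -1234.5, k = 10 free parameters, N = 500 rows  ->  roughly -2531.1
            System.out.println(bic(-1234.5, 10, 500));
        }
    }
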
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicTrue.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicTrue.java
index c41e272e5e..2933466d06 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicTrue.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicTrue.java
@@ -5,38 +5,68 @@
import edu.cmu.tetrad.graph.GraphTransforms;
import edu.cmu.tetrad.search.score.SemBicScorer;
+import java.io.Serial;
+
import static org.apache.commons.math3.util.FastMath.tanh;
/**
- * True BIC score.
+ * True BIC score. The BIC is calculated as 2L - k ln N, so "higher is better."
*
* @author josephramsey
*/
public class BicTrue implements Statistic {
+ @Serial
private static final long serialVersionUID = 23L;
private boolean precomputeCovariances = true;
+ /**
+ * Returns the name of the statistic.
+ *
+ * @return the name of the statistic.
+ */
@Override
public String getAbbreviation() {
return "BicTrue";
}
+ /**
+ * Returns the description of the statistic.
+ *
+ * @return the description of the statistic.
+ */
@Override
public String getDescription() {
return "BIC of the true model";
}
+ /**
+ * Returns the value of the statistic.
+ *
+ * @param trueGraph The true graph (DAG, CPDAG, PAG_of_the_true_DAG).
+ * @param estGraph The estimated graph (same type).
+ * @param dataModel The data model.
+ * @return The value of the statistic.
+ */
@Override
public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
// double est = SemBicScorer.scoreDag(SearchGraphUtils.dagFromCPDAG(estGraph), dataModel);
- return SemBicScorer.scoreDag(GraphTransforms.dagFromCPDAG(trueGraph, null), dataModel, precomputeCovariances);
+ return SemBicScorer.scoreDag(GraphTransforms.dagFromCpdag(trueGraph, null), dataModel, precomputeCovariances);
}
+ /**
+ * Returns the normalized value of the statistic.
+ *
+ * @param value The value of the statistic.
+ * @return The normalized value of the statistic.
+ */
@Override
public double getNormValue(double value) {
return tanh(value);
}
+ /**
+ * Sets whether to precompute covariances.
+ *
+ * @param precomputeCovariances whether to precompute covariances.
+ */
public void setPrecomputeCovariances(boolean precomputeCovariances) {
this.precomputeCovariances = precomputeCovariances;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FBetaAdj.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FBetaAdj.java
new file mode 100644
index 0000000000..059d4675d6
--- /dev/null
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/FBetaAdj.java
@@ -0,0 +1,56 @@
+package edu.cmu.tetrad.algcomparison.statistic;
+
+import edu.cmu.tetrad.algcomparison.statistic.utils.AdjacencyConfusion;
+import edu.cmu.tetrad.data.DataModel;
+import edu.cmu.tetrad.graph.Graph;
+
+/**
+ * Calculates the F-beta statistic for adjacencies. See
+ *
+ * https://en.wikipedia.org/wiki/F1_score
+ *
+ * With the default beta = 1 this reduces to the "traditional" F1 statistic described on that page.
+ *
+ * @author Joseph Ramsey
+ */
+public class FBetaAdj implements Statistic {
+ private static final long serialVersionUID = 23L;
+
+ private double beta = 1;
+
+ @Override
+ public String getAbbreviation() {
+ return "FBetaAdj";
+ }
+
+ @Override
+ public String getDescription() {
+ return "FBeta statistic for adjacencies";
+ }
+
+ @Override
+ public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
+ AdjacencyConfusion adjConfusion = new AdjacencyConfusion(trueGraph, estGraph);
+ int adjTp = adjConfusion.getTp();
+ int adjFp = adjConfusion.getFp();
+ int adjFn = adjConfusion.getFn();
+ int adjTn = adjConfusion.getTn();
+ double adjPrecision = adjTp / (double) (adjTp + adjFp);
+ double adjRecall = adjTp / (double) (adjTp + adjFn);
+ return (1 + beta * beta) * (adjPrecision * adjRecall)
+ / (beta * beta * adjPrecision + adjRecall);
+ }
+
+ @Override
+ public double getNormValue(double value) {
+ return value;
+ }
+
+ public double getBeta() {
+ return beta;
+ }
+
+ public void setBeta(double beta) {
+ this.beta = beta;
+ }
+}
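
As a cross-check on the expression in getValue above, here is the same F-beta computation applied to concrete confusion counts (standalone sketch; the counts are invented):

    final class FBetaSketch {
        // F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall),
        // matching the expression in FBetaAdj.getValue; beta = 1 gives the usual F1.
        static double fBeta(int tp, int fp, int fn, double beta) {
            double precision = tp / (double) (tp + fp);
            double recall = tp / (double) (tp + fn);
            return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall);
        }

        public static void main(String[] args) {
            // 8 true positives, 2 false positives, 4 false negatives: precision 0.8, recall ~0.667, F1 ~0.727.
            System.out.println(fBeta(8, 2, 4, 1.0));
        }
    }

With beta = 1 this reduces to the usual harmonic mean of precision and recall.
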
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java
deleted file mode 100644
index 5bafefbb19..0000000000
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MarkovAdequacyScore.java
+++ /dev/null
@@ -1,46 +0,0 @@
-package edu.cmu.tetrad.algcomparison.statistic;
-
-import edu.cmu.tetrad.data.DataModel;
-import edu.cmu.tetrad.data.DataSet;
-import edu.cmu.tetrad.graph.Graph;
-import edu.cmu.tetrad.search.ConditioningSetType;
-import edu.cmu.tetrad.search.MarkovCheck;
-import edu.cmu.tetrad.search.test.IndTestFisherZ;
-
-/**
- * Tests whether the p-values under the null distribution are distributed as Uniform, and if so, returns the proportion
- * of judgements of dependence under the Alternative Hypothesis. If the p-values are not distributed as Uniform, zero is
- * returned.
- *
- * @author josephramsey
- */
-public class MarkovAdequacyScore implements Statistic {
- private static final long serialVersionUID = 23L;
- private double alpha = 0.05;
-
- @Override
- public String getAbbreviation() {
- return "MAS";
- }
-
- @Override
- public String getDescription() {
- return "Markov Adequacy Score (depends only on the estimated DAG and the data)";
- }
-
- @Override
- public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
- MarkovCheck markovCheck = new MarkovCheck(estGraph, new IndTestFisherZ((DataSet) dataModel, 0.01), ConditioningSetType.LOCAL_MARKOV);
- markovCheck.generateResults();
- return markovCheck.getMarkovAdequacyScore(alpha);
- }
-
- @Override
- public double getNormValue(double value) {
- return value;
- }
-
- public void setAlpha(double alpha) {
- this.alpha = alpha;
- }
-}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumParametersEst.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumParametersEst.java
new file mode 100644
index 0000000000..a09152bf0f
--- /dev/null
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumParametersEst.java
@@ -0,0 +1,74 @@
+package edu.cmu.tetrad.algcomparison.statistic;
+
+import edu.cmu.tetrad.data.DataModel;
+import edu.cmu.tetrad.data.DataSet;
+import edu.cmu.tetrad.graph.Graph;
+import edu.cmu.tetrad.graph.GraphTransforms;
+import edu.cmu.tetrad.graph.Node;
+import edu.cmu.tetrad.search.score.DiscreteBicScore;
+
+import java.io.Serial;
+import java.util.List;
+
+import static org.apache.commons.math3.util.FastMath.tanh;
+
+/**
+ * Number of parameters for the estimated model: a discrete Bayes model for discrete data, or a linear SEM for continuous data.
+ *
+ * @author josephramsey
+ */
+public class NumParametersEst implements Statistic {
+
+ @Serial
+ private static final long serialVersionUID = 23L;
+
+ public NumParametersEst() {
+ }
+
+ @Override
+ public String getAbbreviation() {
+ return "NumParams";
+ }
+
+ @Override
+ public String getDescription() {
+ return "Number of parameters for the estimated graph for a Bayes or SEM model";
+ }
+
+ @Override
+ public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
+ if (dataModel.isDiscrete()) {
+ DiscreteBicScore score = new DiscreteBicScore((DataSet) dataModel);
+
+ Graph dag = GraphTransforms.dagFromCpdag(estGraph, null);
+ List<Node> nodes = dag.getNodes();
+
+ double params = 0.0;
+
+ for (Node node : dag.getNodes()) {
+ score.setPenaltyDiscount(1);
+ int i = nodes.indexOf(node);
+ List<Node> parents = dag.getParents(node);
+ int[] parentIndices = new int[parents.size()];
+
+ for (Node parent : parents) {
+ parentIndices[parents.indexOf(parent)] = nodes.indexOf(parent);
+ }
+
+ params += score.numParameters(i, parentIndices);
+ }
+
+ return params;
+ } else if (dataModel.isContinuous()) {
+ return estGraph.getNumEdges();
+ } else {
+ throw new IllegalArgumentException("Data must be discrete");
+ }
+ }
+
+ @Override
+ public double getNormValue(double value) {
+ return tanh(value / 1e6);
+ }
+}
+
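
In the discrete branch above, DiscreteBicScore.numParameters is summed over the nodes of the DAG. Assuming it counts the standard (c - 1) times (number of parent configurations) free parameters per node, a standalone sketch of that count:

    final class ParamCountSketch {
        // Free parameters for one discrete node: (c - 1) multiplied by the number of parent configurations.
        static int numParamsForNode(int numCategories, int[] parentCategoryCounts) {
            int parentConfigs = 1;
            for (int c : parentCategoryCounts) parentConfigs *= c;
            return (numCategories - 1) * parentConfigs;
        }

        public static void main(String[] args) {
            // A 3-category node with two binary parents: (3 - 1) * (2 * 2) = 8 free parameters.
            System.out.println(numParamsForNode(3, new int[]{2, 2}));
        }
    }
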
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/BidirectedConfusion.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/BidirectedConfusion.java
index 452b4c6b01..7a78f8d24b 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/BidirectedConfusion.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/BidirectedConfusion.java
@@ -18,6 +18,12 @@ public class BidirectedConfusion {
private int fp;
private int fn;
+ /**
+ * Constructs a new confusion matrix for bidirected edges.
+ *
+ * @param truth The true graph.
+ * @param est The estimated graph.
+ */
public BidirectedConfusion(Graph truth, Graph est) {
this.tp = 0;
this.fp = 0;
@@ -56,18 +62,38 @@ public BidirectedConfusion(Graph truth, Graph est) {
this.tn = all - this.fn - this.fp - this.fn;
}
+ /**
+ * Returns the number of true positives.
+ *
+ * @return The number of true positives.
+ */
public int getTp() {
return this.tp;
}
+ /**
+ * Returns the number of false positives.
+ *
+ * @return The number of false positives.
+ */
public int getFp() {
return this.fp;
}
+ /**
+ * Returns the number of false negatives.
+ *
+ * @return The number of false negatives.
+ */
public int getFn() {
return this.fn;
}
+ /**
+ * Returns the number of true negatives.
+ *
+ * @return The number of true negatives.
+ */
public int getTn() {
return this.tn;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/utils/TakesIndependenceWrapper.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/utils/TakesIndependenceWrapper.java
index 5c35194606..de9911ec94 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/utils/TakesIndependenceWrapper.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/utils/TakesIndependenceWrapper.java
@@ -3,12 +3,24 @@
import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper;
/**
- * Author : Jeremy Espino MD Created 7/13/17 2:25 PM
+ * Tags an algorithm as using an independence wrapper.
+ *
+ * @author Jeremy Espino MD Created 7/13/17 2:25 PM
*/
public interface TakesIndependenceWrapper {
+ /**
+ * Returns the independence wrapper.
+ *
+ * @return the independence wrapper.
+ */
IndependenceWrapper getIndependenceWrapper();
+ /**
+ * Sets the independence wrapper.
+ *
+ * @param independenceWrapper the independence wrapper.
+ */
void setIndependenceWrapper(IndependenceWrapper independenceWrapper);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/utils/UsesScoreWrapper.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/utils/UsesScoreWrapper.java
index d03016493f..be853c9dc2 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/utils/UsesScoreWrapper.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/utils/UsesScoreWrapper.java
@@ -3,11 +3,22 @@
import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
/**
+ * Tags an algorithm as using a score wrapper.
+ *
* Author : Jeremy Espino MD Created 7/6/17 2:19 PM
*/
public interface UsesScoreWrapper {
+ /**
+ * Returns the score wrapper.
+ * @return the score wrapper.
+ */
ScoreWrapper getScoreWrapper();
+ /**
+ * Sets the score wrapper.
+ *
+ * @param score the score wrapper.
+ */
void setScoreWrapper(ScoreWrapper score);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AbstractAnnotations.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AbstractAnnotations.java
index 2ff0a40347..0a4fd09d2c 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AbstractAnnotations.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AbstractAnnotations.java
@@ -75,7 +75,7 @@ public List> filterByAnnotation(List> annoCl
}
List> list = annoClasses.stream()
- .filter(e -> e.getClazz().isAnnotationPresent(type))
+ .filter(e -> e.clazz().isAnnotationPresent(type))
.collect(Collectors.toList());
return Collections.unmodifiableList(list);
@@ -93,7 +93,7 @@ public List> filterOutByAnnotation(List> ann
}
List> list = annoClasses.stream()
- .filter(e -> !e.getClazz().isAnnotationPresent(type))
+ .filter(e -> !e.clazz().isAnnotationPresent(type))
.collect(Collectors.toList());
return Collections.unmodifiableList(list);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AnnotatedClass.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AnnotatedClass.java
index a56a602075..d5ea201e6f 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AnnotatedClass.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AnnotatedClass.java
@@ -18,6 +18,7 @@
*/
package edu.cmu.tetrad.annotation;
+import java.io.Serial;
import java.io.Serializable;
import java.lang.annotation.Annotation;
@@ -27,37 +28,37 @@
* @param annotation
* @author Kevin V. Bui (kvb2@pitt.edu)
*/
-public class AnnotatedClass implements Serializable {
+public record AnnotatedClass<T extends Annotation>(Class<?> clazz, T annotation) implements Serializable {
+ @Serial
private static final long serialVersionUID = 5060798016477163171L;
- private final Class clazz;
-
- private final T annotation;
-
/**
* Creates an annotated class.
- * @param clazz class
+ *
+ * @param clazz class
* @param annotation annotation
*/
- public AnnotatedClass(Class clazz, T annotation) {
- this.clazz = clazz;
- this.annotation = annotation;
+ public AnnotatedClass {
}
/**
* Gets the class.
+ *
* @return class
*/
- public Class getClazz() {
+ @Override
+ public Class<?> clazz() {
return this.clazz;
}
/**
* Gets the annotation.
+ *
* @return annotation
*/
- public T getAnnotation() {
+ @Override
+ public T annotation() {
return this.annotation;
}
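
Converting AnnotatedClass to a record is what lets callers switch from getClazz() to clazz(). A standalone sketch of the same pattern (not Tetrad's classes), mirroring the stream filters in AbstractAnnotations and AnnotatedClassUtils:

    import java.lang.annotation.Annotation;
    import java.util.List;
    import java.util.stream.Collectors;

    final class RecordFilterSketch {
        // Record components generate accessors named after the component, e.g. clazz(), not getClazz().
        record Annotated<T extends Annotation>(Class<?> clazz, T annotation) {}

        static <T extends Annotation> List<Annotated<T>> keepDeprecated(List<Annotated<T>> in) {
            return in.stream()
                    .filter(e -> e.clazz().isAnnotationPresent(Deprecated.class))
                    .collect(Collectors.toList());
        }
    }
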
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AnnotatedClassUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AnnotatedClassUtils.java
index ca4a8180d2..da14acc66a 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AnnotatedClassUtils.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AnnotatedClassUtils.java
@@ -64,7 +64,7 @@ public static List> filterByAnnotations
if (annotatedClasses != null && !annotatedClasses.isEmpty()) {
annotatedClasses.stream()
- .filter(e -> e.getClazz().isAnnotationPresent(annotation))
+ .filter(e -> e.clazz().isAnnotationPresent(annotation))
.collect(Collectors.toCollection(() -> list));
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesProperties.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesProperties.java
index 37a811b5ff..539b5d1033 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesProperties.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesProperties.java
@@ -41,7 +41,6 @@
public final class BayesProperties {
private final DataSet dataSet;
private final List variables;
- private final int[][] data;
private final int sampleSize;
private final int[] numCategories;
private double chisq;
@@ -49,6 +48,10 @@ public final class BayesProperties {
private double bic;
private double likelihood;
+ /**
+ * Constructs a new BayesProperties object for the given data set.
+ * @param dataSet The data set.
+ */
public BayesProperties(DataSet dataSet) {
if (dataSet == null) {
throw new NullPointerException();
@@ -56,6 +59,7 @@ public BayesProperties(DataSet dataSet) {
this.dataSet = dataSet;
+ int[][] data;
if (dataSet instanceof BoxDataSet) {
DataBox dataBox = ((BoxDataSet) dataSet).getDataBox();
@@ -63,19 +67,18 @@ public BayesProperties(DataSet dataSet) {
VerticalIntDataBox box = new VerticalIntDataBox(dataBox);
- this.data = box.getVariableVectors();
+ box.getVariableVectors();
} else {
- this.data = new int[dataSet.getNumColumns()][];
+ data = new int[dataSet.getNumColumns()][];
this.variables = dataSet.getVariables();
for (int j = 0; j < dataSet.getNumColumns(); j++) {
- this.data[j] = new int[dataSet.getNumRows()];
+ data[j] = new int[dataSet.getNumRows()];
for (int i = 0; i < dataSet.getNumRows(); i++) {
- this.data[j][i] = dataSet.getInt(i, j);
+ data[j][i] = dataSet.getInt(i, j);
}
}
-
}
this.sampleSize = dataSet.getNumRows();
@@ -103,6 +106,7 @@ private static int getRowIndex(int[] dim, int[] values) {
* Calculates the p-value of the graph with respect to the given data.
*
* @param graph The graph.
+ * @return The p-value.
*/
public LikelihoodRet getLikelihoodRatioP(Graph graph) {
@@ -282,90 +286,11 @@ private int getDof2(Graph graph) {
}
private Ret getLikelihoodNode(int node, int[] parents) {
-
DiscreteBicScore bic = new DiscreteBicScore(dataSet);
double lik = bic.localScore(node, parents);
int dof = (numCategories[node] - 1) * parents.length;
-// int d = ;
-//
-// for (int childValue = 0; childValue < c; childValue++) {
-// d++;
-// }
-
-
-// // Number of categories for node.
-// int c = this.numCategories[node];
-//
-// // Numbers of categories of parents.
-// int[] dims = new int[parents.length];
-//
-// for (int p = 0; p < parents.length; p++) {
-// dims[p] = this.numCategories[parents[p]];
-// }
-//
-// // Number of parent states.
-// int r = 1;
-//
-// for (int p = 0; p < parents.length; p++) {
-// r *= dims[p];
-// }
-//
-// // Conditional cell coefs of data for node given parents(node).
-// int[][] n_jk = new int[r][c];
-// int[] n_j = new int[r];
-//
-// int[] parentValues = new int[parents.length];
-//
-// int[][] myParents = new int[parents.length][];
-// for (int i = 0; i < parents.length; i++) {
-// myParents[i] = this.data[parents[i]];
-// }
-//
-// int[] myChild = this.data[node];
-//
-// for (int i = 0; i < this.sampleSize; i++) {
-// for (int p = 0; p < parents.length; p++) {
-// parentValues[p] = myParents[p][i];
-// }
-//
-// int childValue = myChild[i];
-//
-// if (childValue == -99) {
-// throw new IllegalStateException("Please remove or impute missing " +
-// "values (record " + i + " column " + i + ")");
-// }
-//
-// int rowIndex = BayesProperties.getRowIndex(dims, parentValues);
-//
-// n_jk[rowIndex][childValue]++;
-// n_j[rowIndex]++;
-// }
-
-// //Finally, compute the score
-// double lik = 0.0;
-// int dof = 0;
-//
-// for (int rowIndex = 0; rowIndex < r; rowIndex++) {
-// if (rowIndex == 0) continue;
-//
-// if (Thread.interrupted()) break;
-//
-// int d = 0;
-//
-// for (int childValue = 0; childValue < c; childValue++) {
-// int cellCount = n_jk[rowIndex][childValue];
-// int rowCount = n_j[rowIndex];
-//
-// if (cellCount == 0) continue;
-// lik += cellCount * FastMath.log(cellCount / (double) rowCount);
-// d++;
-// }
-//
-// if (d > 0) dof += c - 1;
-// }
-
return new Ret(lik, dof);
}
@@ -391,10 +316,16 @@ private double getDofNode(int node, int[] parents) {
return r * c;
}
+ /**
+ * Returns the sample size of the data set.
+ */
public int getSampleSize() {
return this.sampleSize;
}
+ /**
+ * Returns the variable with the given name (assumed the target).
+ */
public Node getVariable(String targetName) {
for (Node node : this.variables) {
if (node.getName().equals(targetName)) {
@@ -413,28 +344,63 @@ private DiscreteVariable getVariable(int i) {
}
}
+ /**
+ * Stores a likelihood value together with its degrees of freedom.
+ */
private static class Ret {
private final double lik;
private final int dof;
+ /**
+ * Constructs a new Ret object.
+ * @param lik The likelihood.
+ * @param dof The degrees of freedom.
+ */
public Ret(double lik, int dof) {
this.lik = lik;
this.dof = dof;
}
+ /**
+ * Returns the likelihood.
+ * @return The likelihood.
+ */
public double getLik() {
return this.lik;
}
+ /**
+ * Returns the degrees of freedom.
+ * @return The degrees of freedom.
+ */
public int getDof() {
return this.dof;
}
}
- public class LikelihoodRet {
+ /**
+ * Stores the results of the likelihood ratio test: the p-value, BIC, chi-squared statistic, and degrees of freedom.
+ */
+ public static class LikelihoodRet {
+
+ /**
+ * The p-value.
+ */
public double p;
+
+ /**
+ * The BIC.
+ */
public double bic;
+
+ /**
+ * The chi-squared statistic.
+ */
public double chiSq;
+
+ /**
+ * The degrees of freedom.
+ */
public double dof;
}
}
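
getLikelihoodRatioP now returns this LikelihoodRet holder (p, bic, chiSq, dof). As a generic illustration of how a likelihood-ratio statistic is converted to a p-value (not necessarily the exact computation BayesProperties performs), using Commons Math:

    import org.apache.commons.math3.distribution.ChiSquaredDistribution;

    final class LikelihoodRatioSketch {
        // Upper-tail chi-squared probability for a likelihood-ratio statistic with the given dof.
        static double pValue(double chiSq, double dof) {
            return 1.0 - new ChiSquaredDistribution(dof).cumulativeProbability(chiSq);
        }

        public static void main(String[] args) {
            System.out.println(pValue(12.3, 5)); // ~0.03
        }
    }
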
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesUpdater.java
index 4dbfa13c2f..3796d0a127 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesUpdater.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesUpdater.java
@@ -28,7 +28,7 @@
* evidence), where evidence takes the form of a Proposition over the variables in the Bayes net, possibly with
* additional information about which variables in the Bayes net have been manipulated. Some updaters may be able to
* calculate joint marginals as well--that is, P(AND_i{Xi = xi'} | evidence). Also, not all updaters can take
- * manipulation information into account. See implementations for details.
+ * manipulation information into account. See implementations for details.)
*
* @author josephramsey
* @see Evidence
@@ -36,6 +36,9 @@
* @see Manipulation
*/
public interface BayesUpdater extends TetradSerializable {
+ /**
+ * Serial version ID for serialization.
+ */
long serialVersionUID = 23L;
/**
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesXmlParser.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesXmlParser.java
index 39fd3c60c1..8ec2e7fc44 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesXmlParser.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesXmlParser.java
@@ -106,6 +106,12 @@ private static BayesIm makeBayesIm(BayesPm bayesPm, Element element2) {
return bayesIm;
}
+ /**
+ * Returns the BayesIm object represented by the given element.
+ *
+ * @param element the element
+ * @return the BayesIm object
+ */
public BayesIm getBayesIm(Element element) {
if (!"bayesNet".equals(element.getQualifiedName())) {
throw new IllegalArgumentException("Expecting 'bayesNet' element.");
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesXmlRenderer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesXmlRenderer.java
index 2644e376eb..8a2b7b09fa 100755
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesXmlRenderer.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesXmlRenderer.java
@@ -36,6 +36,12 @@
*/
public final class BayesXmlRenderer {
+ /**
+ * Renders the given Bayes net as an XML element.
+ *
+ * @param bayesIm the Bayes net
+ * @return the XML element
+ */
public static Element getElement(BayesIm bayesIm) {
if (bayesIm == null) {
throw new NullPointerException();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BdeMetricCache.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BdeMetricCache.java
index 4f1541073d..e016038a43 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BdeMetricCache.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BdeMetricCache.java
@@ -34,32 +34,36 @@
/**
* Provides a method for computing the score of a model, called the BDe
* metric (Bayesian Dirchlet likelihood equivalence), given a dataset (assumes no missing values) and a Bayes
- * parameterized network (assumes no latent variables).> 0
This version has a method that computes the score for a
- * given factor of a model, where a factor is determined by a node and its parents. It stores scores in a map whose
- * argument is an ordered pair consisting of 1) a node and 2) set of parents. The score for the entire model is the
- * product of the scores of its factors. Since the log of the gamma function is used here the sum of the logs is
- * computed as the score. Compare this with the score method in the BdeMetric class which computes the score for the
- * entire model in one pass. The advantage of the approach in this class is that it is more efficient in the context of
- * a search algorithm where different models are scored but where many of them will have the same factors. This class
- * stores the score (relative to the dataset) for any [node, set of parents] pair and thus avoids the expensive log
- * gamma function calls. Instead it looks in the map scores to see if it has already computed the score and, if so,
- * returns the previously computed value.> 0
See "Learning Bayesian Networks: The Combination of Knowledge and
- * Statistical Data" by David Heckerman, Dan Geiger, and David M. Chickering. Microsoft Technical Report
- * MSR-TR-94-09.> 0
+ * parameterized network (assumes no latent variables).> 0
+ *
+ * This version has a method that computes the score for a given factor of a model, where a factor is determined by a
+ * node and its parents. It stores scores in a map whose argument is an ordered pair consisting of 1) a node and 2) set
+ * of parents. The score for the entire model is the product of the scores of its factors. Since the log of the gamma
+ * function is used here the sum of the logs is computed as the score. Compare this with the score method in the
+ * BdeMetric class which computes the score for the entire model in one pass. The advantage of the approach in this
+ * class is that it is more efficient in the context of a search algorithm where different models are scored but where
+ * many of them will have the same factors. This class stores the score (relative to the dataset) for any [node, set of
+ * parents] pair and thus avoids the expensive log gamma function calls. Instead, it looks in the map scores to see if
+ * it has already computed the score and, if so, returns the previously computed value.
+ *
+ * See "Learning Bayesian
+ * Networks: The Combination of Knowledge and Statistical Data" by David Heckerman, Dan Geiger, and David M.
+ * Chickering. Microsoft Technical Report MSR-TR-94-09.> 0
*
* @author Frank Wimberly
*/
public final class BdeMetricCache {
private final DataSet dataSet;
private final List variables;
-
private final BayesPm bayesPm; //Determines the list of variables (nodes)
-
private final Map scores;
private final Map scoreCounts;
-
private double[][] observedCounts;
+ /**
+ * Constructs a BdeMetricCache object for a given dataset and BayesPm.
+ *
+ * @param dataSet The dataset for which the BDe metric is to be computed.
+ * @param bayesPm The BayesPm that determines the list of variables (nodes) and the structure of the graph.
+ */
public BdeMetricCache(DataSet dataSet, BayesPm bayesPm) {
this.bayesPm = bayesPm;
this.dataSet = dataSet;
@@ -71,6 +75,12 @@ public BdeMetricCache(DataSet dataSet, BayesPm bayesPm) {
/**
* Computes the BDe score, using the logarithm of the gamma function, relative to the data, of the factor determined
* by a node and its parents.
+ *
+ * @param node The node of the factor.
+ * @param parents The parents of the node.
+ * @param bayesPmMod The BayesPm that determines the list of variables (nodes) and the structure of the graph.
+ * @param bayesIm The BayesIm that determines the observed counts.
+ * @return The score of the factor.
*/
public double scoreLnGam(Node node, Set parents, BayesPm bayesPmMod,
BayesIm bayesIm) {
@@ -441,6 +451,11 @@ private int getVarIndex(String name) {
/**
* This method is used in testing and debugging and not in the BDe metric calculations.
+ *
+ * @param node The node for which the observed counts are to be returned.
+ * @param bayesPm The BayesPm that determines the list of variables (nodes) and the structure of the graph.
+ * @param bayesIm The BayesIm that determines the observed counts.
+ * @return The observed counts for the given node.
*/
public double[][] getObservedCounts(Node node, BayesPm bayesPm,
BayesIm bayesIm) {
@@ -461,6 +476,10 @@ public double[][] getObservedCounts(Node node, BayesPm bayesPm,
/**
* This is just for testing the operation of the inner class and the map from nodes and parent sets to scores.
+ *
+ * @param node The node of the factor.
+ * @param parents The parents of the node.
+ * @return The score of the factor.
*/
public int getScoreCount(Node node, Set parents) {
NodeParentsPair nodeParents = new NodeParentsPair(node, parents);
@@ -488,16 +507,29 @@ private static final class NodeParentsPair {
private final Node node;
private final Set parents;
+
+ /**
+ * Constructs a NodeParentsPair object for a given node and set of parents.
+ *
+ * @param node The node of the pair.
+ * @param parents The parents of the node.
+ */
public NodeParentsPair(Node node, Set parents) {
this.node = node;
this.parents = parents;
}
+ /**
+ * @return The number of elements in the set of parents plus 1.
+ */
public int calcCount() {
return this.parents.size() + 1;
}
+ /**
+ * @return A hash code computed from the node and its parents.
+ */
public int hashCode() {
int hash = 91;
hash = 43 * hash + this.node.hashCode();
@@ -506,6 +538,12 @@ public int hashCode() {
return hash;
}
+ /**
+ * Equals method for NodeParentsPair.
+ *
+ * @param other The other object to compare to.
+ * @return True if the other object is a NodeParentsPair and has the same node and parents as this one.
+ */
public boolean equals(Object other) {
if (other == this) {
return true;
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java
index 77003ff05c..e58a394afb 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java
@@ -31,6 +31,7 @@
import java.io.IOException;
import java.io.ObjectInputStream;
+import java.io.Serial;
import java.text.NumberFormat;
import java.util.*;
@@ -76,8 +77,10 @@ public final class MlBayesIm implements BayesIm {
* Indicates that new rows in this BayesIm should be initialized randomly.
*/
public static final int RANDOM = 1;
+ @Serial
private static final long serialVersionUID = 23L;
private static final double ALLOWABLE_DIFFERENCE = 1.0e-3;
+ private static final Random random = new Random();
/**
* The associated Bayes PM model.
@@ -212,66 +215,19 @@ public static List getParameterNames() {
return new ArrayList<>();
}
- private static double[] getRandomWeights2(int size, double[] biases) {
- assert size >= 0;
-
- double[] row = new double[size];
- double sum = 0.0;
-
- for (int i = 0; i < size; i++) {
-// row[i] = RandomUtil.getInstance().nextDouble() + biases[i];
- double v = RandomUtil.getInstance().nextUniform(0, biases[i]);
- row[i] = v > 0.5 ? 2 * v : v;
- sum += row[i];
- }
-
- for (int i = 0; i < size; i++) {
- row[i] /= sum;
- }
-
- return row;
- }
-
- private static double[] getRandomWeights3(int size) {
- assert size >= 0;
-
- double[] row = new double[size];
- double sum = 0.0;
-
- for (int i = 0; i < size; i++) {
- row[i] = RandomUtil.getInstance().nextBeta(size / 4d, size);
- sum += row[i];
- }
-
- for (int i = 0; i < size; i++) {
- row[i] /= sum;
- }
-
- return row;
- }
-
- /**
- * This method chooses random probabilities for a row which add up to 1.0. Random doubles are drawn from a random
- * distribution, and the final row is then normalized.
- *
- * @param size the length of the row.
- * @return an array with randomly distributed probabilities of this length.
- * @see #randomizeRow
- */
private static double[] getRandomWeights(int size) {
- assert size >= 0;
+ assert size > 0;
double[] row = new double[size];
double sum = 0.0;
- // Renders rows more deterministic.
- final double bias = 0;
+ int strong = (int) Math.floor(random.nextDouble() * size);
for (int i = 0; i < size; i++) {
- row[i] = RandomUtil.getInstance().nextDouble();
-
- if (row[i] > 0.5) {
- row[i] += bias;
+ if (i == strong) {
+ row[i] = 1.0;
+ } else {
+ row[i] = RandomUtil.getInstance().nextDouble() * 0.1;
}
sum += row[i];
@@ -598,12 +554,7 @@ public void clearRow(int nodeIndex, int rowIndex) {
*/
public void randomizeRow(int nodeIndex, int rowIndex) {
int size = getNumColumns(nodeIndex);
- this.probs[nodeIndex][rowIndex] = MlBayesIm.getRandomWeights3(size);
- }
-
- private void randomizeRow2(int nodeIndex, int rowIndex, double[] biases) {
- int size = getNumColumns(nodeIndex);
- this.probs[nodeIndex][rowIndex] = MlBayesIm.getRandomWeights2(size, biases);
+ this.probs[nodeIndex][rowIndex] = MlBayesIm.getRandomWeights(size);
}
/**
@@ -628,39 +579,6 @@ public void randomizeTable(int nodeIndex) {
for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) {
randomizeRow(nodeIndex, rowIndex);
}
-// randomizeTable4(nodeIndex);
- }
-
- private void randomizeTable4(int nodeIndex) {
- for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) {
- randomizeRow(nodeIndex, rowIndex);
- }
-
- double[][] saved = new double[getNumRows(nodeIndex)][getNumColumns(nodeIndex)];
-
- double max = Double.NEGATIVE_INFINITY;
-
- for (int i = 0; i < 10; i++) {
- for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) {
-// randomizeRow(nodeIndex, rowIndex);
- randomizeRow2(nodeIndex, rowIndex, this.probs[nodeIndex][rowIndex]);
- }
-
- int score = score(nodeIndex);
-
- if (score > max) {
- max = score;
- copy(this.probs[nodeIndex], saved);
- }
-
- if (score == getNumParents(nodeIndex)) {
- break;
- }
- }
-
- for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) {
- copy(saved, this.probs[nodeIndex]);
- }
}
private int score(int nodeIndex) {
@@ -1103,12 +1021,10 @@ public boolean equals(Object o) {
return true;
}
- if (!(o instanceof BayesIm)) {
+ if (!(o instanceof BayesIm otherIm)) {
return false;
}
- BayesIm otherIm = (BayesIm) o;
-
if (getNumNodes() != otherIm.getNumNodes()) {
return false;
}
@@ -1445,6 +1361,7 @@ private void copyValuesFromOldToNew(int oldNodeIndex, int oldRowIndex,
* class that didn't include it. (That's what the "s.defaultReadObject();" is for. See J. Bloch, Effective Java, for
* help.
*/
+ @Serial
private void readObject(ObjectInputStream s)
throws IOException, ClassNotFoundException {
s.defaultReadObject();
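
The rewritten getRandomWeights picks one randomly chosen "strong" category per row, keeps the remaining weights small, and normalizes, so randomly generated conditional probability tables sit further from uniform than before. A standalone sketch of the same scheme (using java.util.Random instead of Tetrad's RandomUtil):

    import java.util.Arrays;
    import java.util.Random;

    final class RandomRowSketch {
        private static final Random random = new Random();

        // One randomly chosen "strong" category gets weight 1.0, the rest get small weights
        // in [0, 0.1), and the row is normalized to sum to 1.
        static double[] randomRow(int size) {
            double[] row = new double[size];
            int strong = random.nextInt(size);
            double sum = 0.0;
            for (int i = 0; i < size; i++) {
                row[i] = (i == strong) ? 1.0 : random.nextDouble() * 0.1;
                sum += row[i];
            }
            for (int i = 0; i < size; i++) row[i] /= sum;
            return row;
        }

        public static void main(String[] args) {
            System.out.println(Arrays.toString(randomRow(4))); // one entry dominates, the rest are small
        }
    }
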
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/BoxDataSet.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/BoxDataSet.java
index 450149b03c..dd2c7268ad 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/BoxDataSet.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/BoxDataSet.java
@@ -124,6 +124,12 @@ public final class BoxDataSet implements DataSet {
*/
private char outputDelimiter = '\t';
+ /**
+ * Constructs a new data set backed by the given data box, with the given list of variables.
+ *
+ * @param dataBox The data box.
+ * @param variables The variables.
+ */
public BoxDataSet(DataBox dataBox, List variables) {
this.dataBox = dataBox;
this.variables = new ArrayList<>(variables);
@@ -135,6 +141,8 @@ public BoxDataSet(DataBox dataBox, List variables) {
/**
* Makes of copy of the given data set.
+ *
+ * @param dataSet The data set to copy.
*/
public BoxDataSet(BoxDataSet dataSet) {
this.name = dataSet.name;
@@ -147,6 +155,8 @@ public BoxDataSet(BoxDataSet dataSet) {
/**
* Generates a simple exemplar of this class to test serialization.
+ *
+ * @return A simple exemplar of this class.
*/
public static BoxDataSet serializableInstance() {
List vars = new ArrayList<>();
@@ -1333,6 +1343,11 @@ public NumberFormat getNumberFormat() {
return this.nf;
}
+ /**
+ * Sets the number format to be used when printing out the data set.
+ *
+ * @param nf The number format to be used when printing out the data set.
+ */
public void setNumberFormat(NumberFormat nf) {
if (nf == null) {
throw new NullPointerException();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/ByteDataBox.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/ByteDataBox.java
index 3e777c5990..58ae752447 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/ByteDataBox.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/ByteDataBox.java
@@ -23,6 +23,7 @@
import edu.cmu.tetrad.graph.Node;
+import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
@@ -30,6 +31,7 @@
* Stores a 2D array of byte data. Note that the missing value marker for this box is -99.
*/
public class ByteDataBox implements DataBox {
+ @Serial
private static final long serialVersionUID = 23L;
/**
@@ -50,6 +52,9 @@ public class ByteDataBox implements DataBox {
/**
* Constructs an 2D byte array consisting entirely of missing values (-99).
+ *
+ * @param rows the number of rows.
+ * @param cols the number of columns.
*/
public ByteDataBox(int rows, int cols) {
this.data = new byte[rows][cols];
@@ -66,6 +71,8 @@ public ByteDataBox(int rows, int cols) {
/**
* Constructs a new data box using the given 2D byte data array as data.
+ *
+ * @param data the data to use.
*/
public ByteDataBox(byte[][] data) {
int length = data[0].length;
@@ -84,6 +91,8 @@ public ByteDataBox(byte[][] data) {
/**
* Generates a simple exemplar of this class to test serialization.
+ *
+ * @return a simple exemplar of this class to test serialization.
*/
public static BoxDataSet serializableInstance() {
List vars = new ArrayList<>();
@@ -106,7 +115,12 @@ public int numCols() {
}
/**
- * Sets the value at the given row/column to the given Number value. The value used is number.byteValue().
+ * Sets the value at the given row/column to the given Number value. The value used is number.byteValue(). If the
+ * value is null, the missing value marker (-99) is used.
+ *
+ * @param row the row index.
+ * @param col the column index.
+ * @param value the value to store.
*/
public void set(int row, int col, Number value) {
if (value == null) {
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/CellTable.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/CellTable.java
index 58d59bdd6f..472e204ad3 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/CellTable.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/CellTable.java
@@ -23,7 +23,9 @@
import edu.cmu.tetrad.util.MultiDimIntTable;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
/**
@@ -35,17 +37,12 @@
*/
public final class CellTable {
-
-// /**
-// * Stores a copy of coordinates for temporary use. (Reused.)
-// */
-// private int[] coordCopy;
-
private final MultiDimIntTable table;
/**
* The value used in the data for missing values.
*/
private int missingValue = -99;
+ private List rows;
/**
* Constructs a new cell table using the given array for dimensions, initializing all cells in the table to zero.
@@ -57,11 +54,22 @@ public CellTable(int[] dims) {
}
public void addToTable(DataSet dataSet, int[] indices) {
+ if (rows == null) {
+ rows = new ArrayList<>();
+ for (int i = 0; i < dataSet.getNumRows(); i++) {
+ rows.add(i);
+ }
+ } else {
+ for (int i = 0; i < rows.size(); i++) {
+ if (rows.get(i) >= dataSet.getNumRows())
+ throw new IllegalArgumentException("Row " + i + " is too large.");
+ }
+ }
+
int[] dims = new int[indices.length];
for (int i = 0; i < indices.length; i++) {
- DiscreteVariable variable =
- (DiscreteVariable) dataSet.getVariable(indices[i]);
+ DiscreteVariable variable = (DiscreteVariable) dataSet.getVariable(indices[i]);
dims[i] = variable.getNumCategories();
}
@@ -70,12 +78,11 @@ public void addToTable(DataSet dataSet, int[] indices) {
int[] coords = new int[indices.length];
points:
- for (int i = 0; i < dataSet.getNumRows(); i++) {
+ for (int i : rows) {
for (int j = 0; j < indices.length; j++) {
try {
coords[j] = dataSet.getInt(i, indices[j]);
} catch (Exception e) {
- e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates.
coords[j] = dataSet.getInt(i, j);
}
@@ -108,7 +115,7 @@ public int getNumValues(int varIndex) {
public long calcMargin(int[] coords) {
int[] coordCopy = internalCoordCopy(coords);
- int sum = 0;
+ long sum = 0;
int i = -1;
while (++i < coordCopy.length) {
@@ -129,8 +136,8 @@ public long calcMargin(int[] coords) {
/**
* An alternative way to specify a marginal calculation. In this case, coords specifies a particular cell in the
* table, and varIndices is an array containing the indices of the variables over which the margin sum should be
- * calculated. The sum is over the cell specified by 'coord' and all of the cells which differ from that cell in any
- * of the specified coordinates.
+ * calculated. The sum is over the cell specified by 'coord' and all the cells which differ from that cell in any of
+ * the specified coordinates.
*
* @param coords an int[] value
* @param marginVars an int[] value
@@ -150,16 +157,7 @@ public long calcMargin(int[] coords, int[] marginVars) {
* Makes a copy of the coordinate array so that the original is not messed up.
*/
private int[] internalCoordCopy(int[] coords) {
- int[] coordCopy = Arrays.copyOf(coords, coords.length);
-
-// if ((this.coordCopy == null) ||
-// (this.coordCopy.length != coords.length)) {
-// this.coordCopy = new int[coords.length];
-// }
-//
-// System.arraycopy(coords, 0, this.coordCopy, 0, coords.length);
-
- return coordCopy;
+ return Arrays.copyOf(coords, coords.length);
}
private int getMissingValue() {
@@ -173,6 +171,19 @@ public void setMissingValue(int missingValue) {
public long getValue(int[] testCell) {
return this.table.getValue(testCell);
}
+
+ public void setRows(List rows) {
+ if (rows == null) {
+ this.rows = null;
+ } else {
+ for (int i = 0; i < rows.size(); i++) {
+ if (rows.get(i) == null) throw new NullPointerException("Row " + i + " is null.");
+ if (rows.get(i) < 0) throw new IllegalArgumentException("Row " + i + " is negative.");
+ }
+
+ this.rows = rows;
+ }
+ }
}
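For reference, a minimal usage sketch (not from this patch) of how the new setRows method can restrict counting to a subset of rows; the dataSet, its discrete variables, and the dimension array are assumed.

    // Illustrative only: 'dataSet' is assumed to be a discrete DataSet with at least six rows,
    // whose variables at columns 0 and 1 have 2 and 3 categories, respectively.
    CellTable cellTable = new CellTable(new int[]{2, 3});
    cellTable.setRows(java.util.List.of(0, 2, 5));     // tally only rows 0, 2, and 5
    cellTable.addToTable(dataSet, new int[]{0, 1});    // count cells over columns 0 and 1
    long count = cellTable.getValue(new int[]{1, 2});  // cell (1, 2) of the contingency table
    cellTable.setRows(null);                           // null restores counting over all rows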
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataBox.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataBox.java
index ffb1ca67c7..28f117f7a8 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataBox.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataBox.java
@@ -46,12 +46,17 @@ public interface DataBox extends TetradSerializable {
* Sets the value at the given row and column to the given Number. This number may be interpreted differently
* depending on how values are stored. A value of null is interpreted as a missing value.
*
+ * @param row the row index.
+ * @param col the column index.
+ * @param value the value to store.
* @throws IllegalArgumentException if the given value cannot be stored (because it's out of range or cannot be
* converted or whatever).
*/
void set(int row, int col, Number value) throws IllegalArgumentException;
/**
+ * @param row the row index.
+ * @param col the column index.
* @return the value at the given row and column as a Number. If the value is missing, null is uniformly returned.
*/
Number get(int row, int col);
@@ -62,10 +67,17 @@ public interface DataBox extends TetradSerializable {
DataBox copy();
/**
+ * @param rows the row indices.
+ * @param cols the column indices.
* @return this data box, restricted to the given rows and columns.
*/
DataBox viewSelection(int[] rows, int[] cols);
+ /**
+ * Returns a data box of the same dimensions as this one, without setting any values.
+ *
+ * @return a new data box.
+ */
DataBox like();
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataFilter.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataFilter.java
index 6d15b59482..c8a342b506 100755
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataFilter.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataFilter.java
@@ -32,6 +32,9 @@ public interface DataFilter {
/**
* Interpolates the given data set, producing a data set with no missing values.
+ *
+ * @param dataSet the data set to interpolate.
+ * @return the interpolated data set.
*/
DataSet filter(DataSet dataSet);
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataSet.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataSet.java
index 0a5e5e0efa..bc7d3dc3d1 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataSet.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataSet.java
@@ -40,12 +40,16 @@ public interface DataSet extends DataModel {
/**
* Adds the given variable to the data set.
*
+ * @param variable The variable to add.
* @throws IllegalArgumentException if the variable is neither continuous nor discrete.
*/
void addVariable(Node variable);
/**
* Adds the given variable at the given index.
+ *
+ * @param index The index at which to add the variable.
+ * @param variable The variable to add.
*/
void addVariable(int index, Node variable);
@@ -53,6 +57,8 @@ public interface DataSet extends DataModel {
* Changes the variable for the given column from 'from' to 'to'. Supported currently only for discrete variables.
*
+ * @param from The variable to change.
+ * @param to The variable to change to.
* @throws IllegalArgumentException if the given change is not supported.
*/
void changeVariable(Node from, Node to);
@@ -64,22 +70,28 @@ public interface DataSet extends DataModel {
/**
* Ensures that the dataset has at least 'columns' columns. Used for pasting data into the dataset. When
- * creating new columns, names in the excludedVarialbeNames list may not be used. The purpose of this
+ * creating new columns, names in the excludedVariableNames list may not be used. The purpose of this
* is to allow these names to be set later by the calling class, without incurring conflicts.
+ *
+ * @param columns The number of columns to ensure.
+ * @param excludedVariableNames The names of variables that should not be used for new columns.
*/
void ensureColumns(int columns, List excludedVariableNames);
/**
- * Returns true if and only if this data set contains at least one missing value.
+ * @return true if and only if this data set contains at least one missing value.
*/
boolean existsMissingValue();
/**
* Ensures that the dataset has at least 'rows' rows. Used for pasting data into the dataset.
+ *
+ * @param rows The number of rows to ensure.
*/
void ensureRows(int rows);
/**
+ * @param variable The variable to check.
* @return the column index of the given variable.
*/
int getColumn(Node variable);
@@ -99,6 +111,8 @@ public interface DataSet extends DataModel {
Matrix getCovarianceMatrix();
/**
+ * @param row The index of the case.
+ * @param column The index of the variable.
* @return the value at the given row and column as a double. For discrete data, returns the integer value cast to a
* double.
*/
@@ -147,11 +161,13 @@ public interface DataSet extends DataModel {
int[] getSelectedIndices();
/**
+ * @param column The index of the variable.
* @return the variable at the given column.
*/
Node getVariable(int column);
/**
+ * @param name The name of the variable.
* @return the variable with the given name.
*/
Node getVariable(String name);
@@ -167,45 +183,54 @@ public interface DataSet extends DataModel {
List getVariables();
/**
- * @return true if this is a continuous data set--that is, if it contains at least one column and all of the columns
+ * @return true if this is a continuous data set--that is, if it contains at least one column and all the columns
* are continuous.
*/
boolean isContinuous();
/**
- * @return true if this is a discrete data set--that is, if it contains at least one column and all of the columns
- * are discrete.
+ * @return true if this is a discrete data set--that is, if it contains at least one column and all the columns are
+ * discrete.
*/
boolean isDiscrete();
/**
* @return true if this is a mixed data set--that is, if it contains at least one continuous column and one
- * discrete columnn.
+ * discrete column.
*/
boolean isMixed();
/**
+ * @param variable The variable to check.
* @return true iff the given column has been marked as selected.
*/
boolean isSelected(Node variable);
/**
* Removes the variable (and data) at the given index.
+ *
+ * @param index The index of the variable to remove.
*/
void removeColumn(int index);
/**
* Removes the given variable, along with all of its data.
+ *
+ * @param variable The variable to remove.
*/
void removeColumn(Node variable);
/**
* Removes the given columns from the data set.
+ *
+ * @param selectedCols The indices of the columns to remove.
*/
void removeCols(int[] selectedCols);
/**
* Removes the given rows from the data set.
+ *
+ * @param selectedRows The indices of the rows to remove.
*/
void removeRows(int[] selectedRows);
@@ -215,6 +240,7 @@ public interface DataSet extends DataModel {
*
* @param row The index of the case.
* @param column The index of the variable.
+ * @param value The value to set.
*/
void setDouble(int row, int column, double value);
@@ -222,8 +248,9 @@ public interface DataSet extends DataModel {
* Sets the value at the given (row, column) to the given int value, assuming the variable for the column is
* discrete.
*
- * @param row The index of the case.
- * @param col The index of the variable.
+ * @param row The index of the case.
+ * @param col The index of the variable.
+ * @param value The value to set.
*/
void setInt(int row, int col, int value);
@@ -246,17 +273,21 @@ public interface DataSet extends DataModel {
* Creates and returns a dataset consisting of those variables in the list vars. Vars must be a subset of the
* variables of this DataSet. The ordering of the elements of vars will be the same as in the list of variables in
* this DataSet.
+ *
+ * @return a new data set consisting of the variables in the list vars.
*/
DataSet subsetColumns(List vars);
/**
- * @return a new data set in which the the column at indices[i] is placed at index i, for i = 0 to indices.length -
- * 1. (View instead?)
+ * @param columns The indices of the columns to include in the new data set.
+ * @return a new data set in which the column at indices[i] is placed at index i, for i = 0 to indices.length - 1.
+ * (View instead?)
*/
DataSet subsetColumns(int[] columns);
/**
- * @return a new data set in which the the row at indices[i] is placed at index i, for i = 0 to indices.length - 1.
+ * @param rows The indices of the rows to include in the new data set.
+ * @return a new data set in which the row at indices[i] is placed at index i, for i = 0 to indices.length - 1.
* (View instead?)
*/
DataSet subsetRows(int[] rows);
@@ -268,16 +299,22 @@ public interface DataSet extends DataModel {
/**
* The number format of the dataset.
+ *
+ * @return The number format of the dataset.
*/
NumberFormat getNumberFormat();
/**
* The number formatter used to print out continuous values.
+ *
+ * @param nf The number formatter used to print out continuous values.
*/
void setNumberFormat(NumberFormat nf);
/**
- * The character used a delimiter when the dataset is output.
+ * The character used as a delimiter when the dataset is output.
+ *
+ * @param character The character used as a delimiter when the dataset is output.
*/
void setOutputDelimiter(Character character);
@@ -286,12 +323,33 @@ public interface DataSet extends DataModel {
*/
void permuteRows();
+ /**
+ * Returns the map of column names to tooltips.
+ *
+ * @return The map of column names to tooltips.
+ */
Map getColumnToTooltip();
+ /**
+ * Checks if the given object is equal to this dataset.
+ *
+ * @param o The object to check.
+ * @return True if the given object is equal to this dataset.
+ */
boolean equals(Object o);
+ /**
+ * Returns a copy of this dataset.
+ *
+ * @return A copy of this dataset.
+ */
DataSet copy();
+ /**
+ * Returns a dataset with the same dimensions as this dataset, but with no data.
+ *
+ * @return a dataset with the same dimensions as this dataset, but with no data.
+ */
DataSet like();
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/GeneralAndersonDarlingTest.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/GeneralAndersonDarlingTest.java
index a380fc86e4..76b101fa10 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/GeneralAndersonDarlingTest.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/GeneralAndersonDarlingTest.java
@@ -21,12 +21,17 @@
package edu.cmu.tetrad.data;
+import edu.cmu.tetrad.util.RandomUtil;
import org.apache.commons.math3.distribution.RealDistribution;
+import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.util.FastMath;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import static org.apache.commons.math3.util.FastMath.*;
+
/**
* Implements the Anderson-Darling test against the given CDF, with P values calculated as in R's ad.test method (in
* package nortest).
@@ -118,23 +123,79 @@ private void runTest() {
}
double a = -n - (1.0 / numSummed) * h;
- double aa = (1 + 0.75 / numSummed + 2.25 / FastMath.pow(numSummed, 2)) * a;
+ double aa = (1 + 0.75 / numSummed + 2.25 / pow(numSummed, 2)) * a;
double p;
if (aa < 0.2) {
- p = 1 - FastMath.exp(-13.436 + 101.14 * aa - 223.73 * aa * aa);
+ p = 1 - exp(-13.436 + 101.14 * aa - 223.73 * aa * aa);
} else if (aa < 0.34) {
- p = 1 - FastMath.exp(-8.318 + 42.796 * aa - 59.938 * aa * aa);
+ p = 1 - exp(-8.318 + 42.796 * aa - 59.938 * aa * aa);
} else if (aa < 0.6) {
- p = FastMath.exp(0.9177 - 4.279 * aa - 1.38 * aa * aa);
+ p = exp(0.9177 - 4.279 * aa - 1.38 * aa * aa);
} else {
- p = FastMath.exp(1.2937 - 5.709 * aa + 0.0186 * aa * aa);
+ p = exp(1.2937 - 5.709 * aa + 0.0186 * aa * aa);
}
this.aSquared = a;
this.aSquaredStar = aa;
this.p = p;
}
+
+ private double c(double n) {
+ return .01265 + .1757 / n;
+ }
+
+ private double g1(double x) {
+ return sqrt(x) * (1 - x) * (49 * x - 102);
+ }
+
+ private double g2(double x) {
+ return -.00022633 + (6.54034 - (14.6538 - (14.458 - (8.259 - 1.91864 * x) * x) * x) * x) * x;
+ }
+
+ private double g3(double x) {
+ return -130.2137 + (745.2337 - (1705.091 - (1950.646 - (1116.360 - 255.7844 * x) * x) * x) * x) * x;
+ }
+
+ private double errfix(double n, double x) {
+ if (x < c(n)) {
+ return (.0037 / pow(n, 3) + .00078 / pow(n, 2) + .00006 / n) * g1(x / c(n));
+ } else if (x < .8) {
+ return (.04213 / n + .01365 / pow(n , 2)) * g2((x - c(n)) / (.8 - c(n)));
+ } else {
+ return g3(x) / n;
+ }
+ }
+
+ private double adinf(double z) {
+ if (0 < z && z < 2) {
+ return pow(z, -0.5) * exp(-1.2337141 / z) * (2.00012 + (0.247105 - (.0649821 - (.0347962 - (.0116720 - .00168691 * z) * z) * z) * z) * z);
+ } else if (z >= 2) {
+ return exp( -exp(1.0776 - (2.30695 - (.43424 - (.082433 - (.008056 - .0003146 * z) * z) * z) * z) * z));
+ } else {
+ return 0;
+ }
+ }
+
+ public double getProbTail(double n, double z) {
+ return adinf(z) + errfix(n, adinf(z));
+ }
+
+ public static void main(String[] args) {
+ List data = new ArrayList<>();
+
+ for (int i = 0; i < 500; i++) {
+// data.add(RandomUtil.getInstance().nextUniform(0, 1));
+ data.add(RandomUtil.getInstance().nextBeta(2, 5));
+ }
+
+ GeneralAndersonDarlingTest test = new GeneralAndersonDarlingTest(data, new UniformRealDistribution(0, 1));
+
+ System.out.println(test.getASquared());
+ System.out.println(test.getASquaredStar());
+ System.out.println(test.getP());
+ System.out.println(test.getProbTail(data.size(), test.getASquaredStar()));
+ }
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/IDataReader.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/IDataReader.java
deleted file mode 100644
index 38f66cc916..0000000000
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/IDataReader.java
+++ /dev/null
@@ -1,30 +0,0 @@
-package edu.cmu.tetrad.data;
-
-/**
- * Identifies a class that can read data from a file.
- *
- * @author josephramsey
- */
-public interface IDataReader {
-
- /**
- * The delimiter between entries in a line, one of DelimiterType.WHITESPACE, DelimiterType.TAB, DelimiterType.COMMA,
- * DelimiterType.COLON
- */
- void setDelimiter(DelimiterType delimiterType);
-
- /**
- * True if case IDs are provided in the first column of the data.
- *
- * @deprecated
- */
- void setIdsSupplied(boolean caseIdsPresent);
-
- /**
- * The String identifier of the case ID column.
- *
- * @deprecated
- */
- void setIdLabel(String caseIdsLabel);
-
-}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java
index 374b74cac7..ed6f1ba5c6 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java
@@ -317,7 +317,7 @@ public final String toString() {
}
public final int hashCode() {
- return 1;
+ return node1.hashCode() + node2.hashCode();
}
/**
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java
index cca46a120e..040379594c 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java
@@ -20,6 +20,8 @@
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.graph;
+import edu.cmu.tetrad.data.DataBox;
+
import java.beans.PropertyChangeListener;
import java.beans.PropertyChangeSupport;
import java.io.IOException;
@@ -75,6 +77,7 @@ public class EdgeListGraph implements Graph, TripleClassifier {
private Set underLineTriples = new HashSet<>();
private Set dottedUnderLineTriples = new HashSet<>();
private Set ambiguousTriples = new HashSet<>();
+ private Map> parentsHash = new HashMap<>();
//==============================CONSTUCTORS===========================//
@@ -130,6 +133,7 @@ public EdgeListGraph(EdgeListGraph graph) throws IllegalArgumentException {
}
this.edgesSet = new HashSet<>(graph.edgesSet);
this.namesHash = new HashMap<>(graph.namesHash);
+ this.parentsHash = new HashMap<>(graph.parentsHash);
// this.paths = new Paths(this);
this.underLineTriples = graph.getUnderLines();
@@ -358,25 +362,29 @@ public Edge getDirectedEdge(Node node1, Node node2) {
*/
@Override
public List getParents(Node node) {
- List parents = new ArrayList<>();
- Set edges = this.edgeLists.get(node);
+ if (!parentsHash.containsKey(node)) {
+ List parents = new ArrayList<>();
+ Set edges = this.edgeLists.get(node);
- if (edges == null) {
- throw new IllegalArgumentException("Node " + node + " is not in the graph.");
- }
+ if (edges == null) {
+ throw new IllegalArgumentException("Node " + node + " is not in the graph.");
+ }
- for (Edge edge : edges) {
- if (edge == null) continue;
+ for (Edge edge : edges) {
+ if (edge == null) continue;
- Endpoint endpoint1 = edge.getDistalEndpoint(node);
- Endpoint endpoint2 = edge.getProximalEndpoint(node);
+ Endpoint endpoint1 = edge.getDistalEndpoint(node);
+ Endpoint endpoint2 = edge.getProximalEndpoint(node);
- if (endpoint1 == Endpoint.TAIL && endpoint2 == Endpoint.ARROW) {
- parents.add(edge.getDistalNode(node));
+ if (endpoint1 == Endpoint.TAIL && endpoint2 == Endpoint.ARROW) {
+ parents.add(edge.getDistalNode(node));
+ }
}
+
+ parentsHash.put(node, parents);
}
- return parents;
+ return parentsHash.get(node);
}
/**
@@ -578,6 +586,9 @@ public boolean removeEdge(Node node1, Node node2) {
removeTriplesNotInGraph();
+ parentsHash.remove(node1);
+ parentsHash.remove(node2);
+
return removeEdges(edges);
}
@@ -680,6 +691,9 @@ public boolean addEdge(Edge edge) {
this.edgeLists.get(edge.getNode1()).add(edge);
this.edgeLists.get(edge.getNode2()).add(edge);
this.edgesSet.add(edge);
+
+ this.parentsHash.remove(edge.getNode1());
+ this.parentsHash.remove(edge.getNode2());
}
if (Edges.isDirectedEdge(edge)) {
@@ -810,6 +824,7 @@ public boolean equals(Object o) {
public void fullyConnect(Endpoint endpoint) {
this.edgesSet.clear();
this.edgeLists.clear();
+ this.parentsHash.clear();
for (Node node : this.nodes) {
this.edgeLists.put(node, new HashSet<>());
@@ -940,6 +955,9 @@ public boolean removeEdge(Edge edge) {
this.edgeLists.put(edge.getNode1(), edgeList1);
this.edgeLists.put(edge.getNode2(), edgeList2);
+ this.parentsHash.remove(edge.getNode1());
+ this.parentsHash.remove(edge.getNode2());
+
getPcs().firePropertyChange("edgeRemoved", edge, null);
return true;
}
@@ -997,6 +1015,8 @@ public boolean removeNode(Node node) {
Set edgeList2 = this.edgeLists.get(node2);
edgeList2.remove(edge);
this.edgesSet.remove(edge);
+ this.parentsHash.remove(edge.getNode1());
+ this.parentsHash.remove(edge.getNode2());
changed = true;
}
@@ -1006,6 +1026,7 @@ public boolean removeNode(Node node) {
this.edgeLists.remove(node);
this.nodes.remove(node);
+ this.parentsHash.remove(node);
this.namesHash.remove(node.getName());
removeTriplesNotInGraph();
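For context (not from this patch): getParents now returns a per-node cached list that is invalidated whenever an incident edge is added or removed, so callers that intend to modify the result should copy it first. A sketch of the intended contract, with hypothetical nodes:

    Graph g = new EdgeListGraph();
    Node a = new GraphNode("A");
    Node b = new GraphNode("B");
    g.addNode(a);
    g.addNode(b);
    g.addDirectedEdge(a, b);
    List<Node> parentsBefore = g.getParents(b);   // computed once and cached: [A]
    g.removeEdge(a, b);                           // drops the cached entries for A and B
    List<Node> parentsAfter = g.getParents(b);    // recomputed: []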
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphTransforms.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphTransforms.java
index 012a9b4663..50e3929e83 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphTransforms.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphTransforms.java
@@ -3,6 +3,7 @@
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.search.utils.DagInCpcagIterator;
import edu.cmu.tetrad.search.utils.DagToPag;
+import edu.cmu.tetrad.search.utils.GraphSearchUtils;
import edu.cmu.tetrad.search.utils.MeekRules;
import edu.cmu.tetrad.util.CombinationGenerator;
import org.jetbrains.annotations.NotNull;
@@ -16,19 +17,19 @@
* @author josephramsey
*/
public class GraphTransforms {
- public static Graph dagFromCPDAG(Graph graph) {
- return dagFromCPDAG(graph, null);
+ public static Graph dagFromCpdag(Graph graph) {
+ return dagFromCpdag(graph, null);
}
- public static Graph dagFromCPDAG(Graph graph, Knowledge knowledge) {
+ /**
+ * Returns a DAG from the given CPDAG. If the given CPDAG is not a PDAG, returns null.
+ * @param graph the CPDAG
+ * @param knowledge the knowledge
+ * @return a DAG from the given CPDAG. If the given CPDAG is not a PDAG, returns null.
+ */
+ public static Graph dagFromCpdag(Graph graph, Knowledge knowledge) {
Graph dag = new EdgeListGraph(graph);
- for (Edge edge : dag.getEdges()) {
- if (Edges.isBidirectedEdge(edge)) {
- throw new IllegalArgumentException("That 'cpdag' contains a bidirected edge.");
- }
- }
-
MeekRules rules = new MeekRules();
if (knowledge != null) {
@@ -213,7 +214,6 @@ public static Graph dagToPag(Graph trueGraph) {
return new DagToPag(trueGraph).convert();
}
-
private static void direct(Node a, Node c, Graph graph) {
Edge before = graph.getEdge(a, c);
Edge after = Edges.directedEdge(a, c);
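Note that callers of the old dagFromCPDAG methods must switch to the renamed dagFromCpdag. A minimal usage sketch (not from this patch), where 'cpdag' is assumed to come from a CPDAG search such as BOSS or FGES and 'knowledge' is an optional Knowledge object:

    Graph dag = GraphTransforms.dagFromCpdag(cpdag);                         // no background knowledge
    Graph dagWithKnowledge = GraphTransforms.dagFromCpdag(cpdag, knowledge); // respects required/forbidden edges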
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java
index dbb476ca16..8f07aa92bb 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LayoutUtil.java
@@ -182,7 +182,10 @@ private static List> getTiers(Graph graph) {
List thisTier = new LinkedList<>();
for (Node node : notFound) {
- if (found.containsAll(graph.getParents(node))) {
+ List nodesInTo = graph.getNodesInTo(node, Endpoint.ARROW);
+ nodesInTo.removeAll(graph.getNodesOutTo(node, Endpoint.ARROW));
+
+ if (found.containsAll(nodesInTo)) {
thisTier.add(node);
}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java
index 112506f4aa..393dedc072 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java
@@ -4,6 +4,7 @@
import edu.cmu.tetrad.search.utils.SepsetMap;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TaskManager;
+import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.TetradSerializable;
import java.util.*;
@@ -1588,7 +1589,10 @@ private boolean visibleEdgeHelperVisit(Node c, Node a, Node b, LinkedList
public boolean existsDirectedCycle() {
for (Node node : graph.getNodes()) {
- if (existsDirectedPathFromTo(node, node)) return true;
+ if (existsDirectedPathFromTo(node, node)) {
+ TetradLogger.getInstance().forceLogMessage("Cycle found at node " + node.getName() + ".");
+ return true;
+ }
}
return false;
}
@@ -1706,6 +1710,10 @@ public boolean isMSeparatedFrom(Node node1, Node node2, Set z) {
return !isMConnectedTo(node1, node2, z);
}
+ public boolean isMSeparatedFrom(Node node1, Node node2, Set z, Map> ancestors) {
+ return !isMConnectedTo(node1, node2, z, ancestors);
+ }
+
/**
* @return true iff there is a semi-directed path from node1 to node2
*/
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BFci.java
index e001129de8..adcd9d6cbd 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BFci.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BFci.java
@@ -29,6 +29,7 @@
import edu.cmu.tetrad.search.utils.FciOrient;
import edu.cmu.tetrad.search.utils.SepsetProducer;
import edu.cmu.tetrad.search.utils.SepsetsGreedy;
+import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.List;
@@ -36,21 +37,19 @@
import static edu.cmu.tetrad.graph.GraphUtils.gfciExtraEdgeRemovalStep;
/**
- * Uses BOSS in place of FGES for the initial step in the GFCI algorithm.
- * This tends to produce a accurate PAG than GFCI as a result, for the latent variables case. This is a simple
- * substitution; the reference for GFCI is here:
- *
- * J.M. Ogarrio and P. Spirtes and J. Ramsey, "A Hybrid Causal Search Algorithm
- * for Latent Variable Models," JMLR 2016. Here, BOSS has been substituted for FGES.
- *
- * BOSS is a an algorithm that is currently being written up for publication,
- * so we don't yet have a reference for it.
- *
- * For BOSS only a score is needed, but there are steps in GFCI that require
- * a test, so for this method, both a test and a score need to be given.
- *
- * This class is configured to respect knowledge of forbidden and required
- * edges, including knowledge of temporal tiers.
+ * Uses BOSS in place of FGES for the initial step in the GFCI algorithm. This tends to produce a more accurate PAG than GFCI
+ * as a result, for the latent variables case. This is a simple substitution; the reference for GFCI is here:
+ *
+ * J.M. Ogarrio and P. Spirtes and J. Ramsey, "A Hybrid Causal Search Algorithm for Latent Variable Models," JMLR 2016.
+ * Here, BOSS has been substituted for FGES.
+ *
+ * BOSS is an algorithm that is currently being written up for publication, so we don't yet have a reference for it.
+ *
+ * For BOSS only a score is needed, but there are steps in GFCI that require a test, so for this method, both a test and
+ * a score need to be given.
+ *
+ * This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
+ * tiers.
*
* @author josephramsey
* @author bryan andrews
@@ -83,6 +82,7 @@ public final class BFci implements IGraphSearch {
private int depth = -1;
private boolean doDiscriminatingPathRule = true;
private boolean bossUseBes = false;
+ private long seed = -1;
/**
@@ -109,6 +109,10 @@ public BFci(IndependenceTest test, Score score) {
* @return The discovered graph.
*/
public Graph search() {
+ if (seed != -1) {
+ RandomUtil.getInstance().setSeed(seed);
+ }
+
List nodes = getIndependenceTest().getVariables();
this.logger.log("info", "Starting FCI algorithm.");
@@ -195,20 +199,43 @@ public IndependenceTest getIndependenceTest() {
return this.independenceTest;
}
+ /**
+ * Sets the number of times to restart the search.
+ *
+ * @param numStarts The number of times to restart the search.
+ */
public void setNumStarts(int numStarts) {
this.numStarts = numStarts;
}
+ /**
+ * Sets the depth of the search (for the constraint-based step).
+ *
+ * @param depth The depth of the search.
+ */
public void setDepth(int depth) {
this.depth = depth;
}
+ /**
+ * Sets whether the discriminating path rule should be used.
+ *
+ * @param doDiscriminatingPathRule True if the discriminating path rule should be used, false otherwise.
+ */
public void setDoDiscriminatingPathRule(boolean doDiscriminatingPathRule) {
this.doDiscriminatingPathRule = doDiscriminatingPathRule;
}
+ /**
+ * Sets whether the BES should be used.
+ *
+ * @param useBes True if the BES should be used, false otherwise.
+ */
public void setBossUseBes(boolean useBes) {
this.bossUseBes = useBes;
}
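+ /**
+ * Sets the seed for the random number generator used by the search, for reproducibility. The default is -1,
+ * which leaves the random number generator unchanged.
+ *
+ * @param seed The seed to use.
+ */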
+ public void setSeed(long seed) {
+ this.seed = seed;
+ }
}
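For reference, a sketch (not from this patch) of making a BFci run reproducible with the new seed parameter; the test and score objects are assumed to be built elsewhere.

    BFci bfci = new BFci(test, score);   // test: IndependenceTest, score: Score (assumed inputs)
    bfci.setSeed(42L);                   // -1, the default, leaves the random generator untouched
    bfci.setNumStarts(2);
    bfci.setDepth(-1);
    Graph pag = bfci.search();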
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java
index d2ca4feeb0..d9c08c567a 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java
@@ -8,35 +8,35 @@
import edu.cmu.tetrad.search.utils.GrowShrinkTree;
import java.util.*;
-import java.util.concurrent.*;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ForkJoinPool;
import static edu.cmu.tetrad.util.RandomUtil.shuffle;
/**
- * Implements Best Order Score Search (BOSS). The following references are relevant:
- *
- * Lam, W. Y., Andrews, B., & Ramsey, J. (2022, August). Greedy relaxations of the sparsest permutation algorithm.
- * In Uncertainty in Artificial Intelligence (pp. 1052-1062). PMLR.
- *
- * Teyssier, M., & Koller, D. (2012). Ordering-based search: A simple and effective algorithm for learning Bayesian
- * networks. arXiv preprint arXiv:1207.1429.
- *
- * Solus, L., Wang, Y., & Uhler, C. (2021). Consistency guarantees for greedy permutation-based causal inference
- * algorithms. Biometrika, 108(4), 795-814.
- *
- * The BOSS algorithm is based on the idea that implied DAGs for permutations are most optimal in their BIC scores
- * when the variables in the permutations are ordered causally--that is, so that that causes in the models come before
- * effects in a topological order.
- *
- * This algorithm is implemented as a "plugin-in" algorithm to a PermutationSearch object (see), which deals with
- * certain details of knowledge handling that are common to different permutation searches.
- *
- * BOSS, like GRaSP (see), is characterized by high adjacency and orientation precision (especially) and recall for
+ * Implements Best Order Score Search (BOSS). The following references are relevant:
+ *
+ * Lam, W. Y., Andrews, B., & Ramsey, J. (2022, August). Greedy relaxations of the sparsest permutation algorithm.
+ * In Uncertainty in Artificial Intelligence (pp. 1052-1062). PMLR.
+ *
+ * Teyssier, M., & Koller, D. (2012). Ordering-based search: A simple and effective algorithm for learning Bayesian
+ * networks. arXiv preprint arXiv:1207.1429.
+ *
+ * Solus, L., Wang, Y., & Uhler, C. (2021). Consistency guarantees for greedy permutation-based causal inference
+ * algorithms. Biometrika, 108(4), 795-814.
+ *
+ * The BOSS algorithm is based on the idea that implied DAGs for permutations are most optimal in their BIC scores when
+ * the variables in the permutations are ordered causally--that is, so that causes in the models come before
+ * effects in a topological order.
+ *
+ * This algorithm is implemented as a "plug-in" algorithm to a PermutationSearch object (see), which deals with
+ * certain details of knowledge handling that are common to different permutation searches.
+ *
+ * BOSS, like GRaSP (see), is characterized by high adjacency and orientation precision (especially) and recall for
* moderate sample sizes. BOSS scales up currently further than GRaSP to larger variable sets and denser graphs and so
- * is currently preferable from a practical standpoint, though performance is essentially identical.
- *
- * The algorithm works as follows:
- *
+ * is currently preferable from a practical standpoint, though performance is essentially identical.
+ *
+ * The algorithm works as follows:
*
* - Start with an arbitrary ordering.
* - Run the permutation search to find a better ordering.
@@ -44,24 +44,24 @@
* - Optionally, Run BES this CPDAG.
*
- Return this CPDAG.
*
- *
- * The optional BES step is needed for correctness, though with large
+ *
+ * The optional BES step is needed for correctness, though with large
* models it has very little effect on the output, since nearly all edges
- * are already oriented, so a parameter is included to turn that step off.
- *
- * Knowledge can be used with this search. If tiered knowledge is used,
+ * are already oriented, so a parameter is included to turn that step off.
+ *
+ * Knowledge can be used with this search. If tiered knowledge is used,
* then the procedure is carried out for each tier separately, given the
* variables preceding that tier, which allows the Boss algorithm to address
* tiered (e.g., time series) problems with larger numbers of variables.
* However, knowledge of required and forbidden edges is correctly implemented
- * for arbitrary such knowledge.
- *
- * A parameter is included to restart the search a certain number of time.
+ * for arbitrary such knowledge.
+ *
+ * A parameter is included to restart the search a certain number of times.
* The idea is that the goal is to optimize a BIC score, so if several runs
* are done of the algorithm for the same data, the model with the highest
- * BIC score should be returned and the others ignored.
- *
- * This class is meant to be used in the context of the PermutationSearch
+ * BIC score should be returned and the others ignored.
+ *
+ * This class is meant to be used in the context of the PermutationSearch
* class (see).
*
* @author bryanandrews
@@ -71,21 +71,40 @@
* @see Knowledge
*/
public class Boss implements SuborderSearch {
+ // The score.
private final Score score;
+ // The variables.
private final List variables;
+ // The parents.
private final Map> parents;
+ // The grow-shrink trees.
private Map gsts;
+ // The set of all variables.
private Set all;
+ // The pool for parallelism.
private ForkJoinPool pool;
+ // The knowledge.
private Knowledge knowledge = new Knowledge();
+ // The BES algorithm.
private BesPermutation bes = null;
+ // The number of random starts to use.
private int numStarts = 1;
+ // True if the order of the variables in the data should be used for an initial best-order search, false if a random
+ // permutation should be used. (Subsequent automatic best order runs will use random permutations.) This is
+ // included so that the algorithm will be capable of outputting the same results with the same data without any
+ // randomness.
private boolean useDataOrder = true;
+ // True if the grow-shrink trees should be reset after each best-mutation step.
private boolean resetAfterBM = false;
+ // True if the grow-shrink trees should be reset after each restart.
private boolean resetAfterRS = true;
+ // The number of threads to use.
private int numThreads = 1;
+ // The BIC scores from the searches.
private List bics;
+ // The times of the searches.
private List times;
+ // True if verbose output should be printed.
private boolean verbose = false;
@@ -103,6 +122,15 @@ public Boss(Score score) {
}
}
+ /**
+ * Searches a suborder of the variables. The prefix is the set of variables that must precede the suborder. The
+ * suborder is the set of variables to be ordered. The gsts is a map from variables to GrowShrinkTrees, which are
+ * used to cache scores for the variables. The searchSuborder method will update the suborder to be the best
+ * ordering found.
+ * @param prefix The prefix of the suborder.
+ * @param suborder The suborder.
+ * @param gsts The GrowShrinkTree being used to do caching of scores.
+ */
@Override
public void searchSuborder(List prefix, List suborder, Map gsts) {
assert this.numStarts > 0;
@@ -129,8 +157,8 @@ public void searchSuborder(List prefix, List suborder, Map 0 && this.resetAfterRS) {
- for (Node root: suborder) {
- this.gsts.get(root).reset();
+ for (Node root : suborder) {
+ this.gsts.get(root).reset();
}
}
@@ -196,6 +224,10 @@ public void setUseBes(boolean use) {
}
}
+ /**
+ * Sets the knowledge to be used for the search.
+ * @param knowledge This knowledge. If null, no knowledge will be used.
+ */
@Override
public void setKnowledge(Knowledge knowledge) {
this.knowledge = knowledge;
@@ -214,12 +246,20 @@ public void setNumStarts(int numStarts) {
this.numStarts = numStarts;
}
+ /**
+ * Sets whether the grow-shrink trees should be reset after each best-mutation step.
+ * @param reset True if so.
+ */
public void setResetAfterBM(boolean reset) {
this.resetAfterBM = reset;
}
+ /**
+ * Sets whether the grow-shrink trees should be reset after each restart.
+ * @param reset True if so.
+ */
public void setResetAfterRS(boolean reset) {
- this.resetAfterRS = reset;
+ this.resetAfterRS = reset;
}
public void setVerbose(boolean verbose) {
@@ -253,7 +293,7 @@ public List getTimes() {
return this.times;
}
- /**
+ /**
* True if the order of the variables in the data should be used for an initial best-order search, false if a random
* permutation should be used. (Subsequent automatic best order runs will use random permutations.) This is
* included so that the algorithm will be capable of outputting the same results with the same data without any
@@ -297,14 +337,14 @@ private boolean betterMutationAsync(List prefix, List suborder, Node
if (this.resetAfterBM) this.gsts.get(x).reset();
double runningScore = 0;
- for (i = with.length - 1 ; i >= 0 ; i--) {
+ for (i = with.length - 1; i >= 0; i--) {
runningScore += with[i];
scores[i] += runningScore;
}
runningScore = 0;
- for (i = 0 ; i < without.length ; i++) {
+ for (i = 0; i < without.length; i++) {
runningScore += without[i];
scores[i + 1] += runningScore;
}
@@ -323,30 +363,6 @@ private boolean betterMutationAsync(List prefix, List suborder, Node
return true;
}
- private static class Trace implements Callable {
- private final GrowShrinkTree gst;
- private final Set all;
- private final Set prefix;
- private final double[] scores;
- private final int index;
-
- Trace(GrowShrinkTree gst, Set all, Set prefix, double[] scores, int index) {
- this.gst = gst;
- this.all = all;
- this.prefix = new HashSet<>(prefix);
- this.scores = scores;
- this.index = index;
- }
-
- @Override
- public Void call() {
- double score = gst.trace(this.prefix, this.all);
- this.scores[index] = score;
-
- return null;
- }
- }
-
private boolean betterMutation(List prefix, List suborder, Node x) {
ListIterator itr = suborder.listIterator();
double[] scores = new double[suborder.size() + 1];
@@ -357,6 +373,8 @@ private boolean betterMutation(List prefix, List suborder, Node x) {
int curr = 0;
while (itr.hasNext()) {
+ if (Thread.currentThread().isInterrupted()) return false;
+
Node z = itr.next();
if (this.knowledge.isRequired(x.getName(), z.getName())) {
@@ -423,10 +441,6 @@ private double update(List prefix, List suborder) {
return score;
}
-
- // alter this code so that it roughly obeys tiers.
-
-
private void makeValidKnowledgeOrder(List order) {
if (this.knowledge.isEmpty()) return;
@@ -462,4 +476,31 @@ private void makeValidKnowledgeOrder(List order) {
}
}
}
+
+
+ // alter this code so that it roughly obeys tiers.
+
+ private static class Trace implements Callable {
+ private final GrowShrinkTree gst;
+ private final Set all;
+ private final Set prefix;
+ private final double[] scores;
+ private final int index;
+
+ Trace(GrowShrinkTree gst, Set all, Set prefix, double[] scores, int index) {
+ this.gst = gst;
+ this.all = all;
+ this.prefix = new HashSet<>(prefix);
+ this.scores = scores;
+ this.index = index;
+ }
+
+ @Override
+ public Void call() {
+ double score = gst.trace(this.prefix, this.all);
+ this.scores[index] = score;
+
+ return null;
+ }
+ }
}
\ No newline at end of file
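For reference, a sketch (not from this patch) of running BOSS through PermutationSearch, assuming PermutationSearch accepts the Boss instance in its constructor and that 'score' is a Score built elsewhere:

    Boss boss = new Boss(score);
    boss.setUseBes(false);
    boss.setNumStarts(1);
    PermutationSearch permutationSearch = new PermutationSearch(boss);
    Graph cpdag = permutationSearch.search();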
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossLingam.java
index ed9758a42a..047a6684e6 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossLingam.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossLingam.java
@@ -39,35 +39,33 @@
import java.util.List;
/**
- * Implements the BOSS-LiNGAM algorithm which first finds a CPDAG for the variables
- * and then uses a non-Gaussian orientation method to orient the undirected edges. The reference is as follows:
- *
- *
>Hoyer et al., "Causal discovery of linear acyclic models with arbitrary
- * distributions," UAI 2008.
- *
- * The test for normality used for residuals is Anderson-Darling, following 'ad.test'
- * in the nortest package of R. The default alpha level is 0.05--that is, p values from AD below 0.05 are taken to
- * indicate nongaussianity.
- *
- * It is assumed that the CPDAG is the result of a CPDAG search such as PC or GES. In any
- * case, it is important that the residuals be independent for ICA to work.
- *
- * This may be replaced by a more general algorithm that allows alternatives for the
- * CPDAG search and for the the non-Gaussian orientation method.
- *
- * This class is not configured to respect knowledge of forbidden and required
- * edges.
+ * Implements the BOSS-LiNGAM algorithm which first finds a CPDAG for the variables and then uses a non-Gaussian
+ * orientation method to orient the undirected edges. The reference is as follows:
+ *
+ * Hoyer et al., "Causal discovery of linear acyclic models with arbitrary distributions," UAI 2008.
+ *
+ * The test for normality used for residuals is Anderson-Darling, following 'ad.test' in the nortest package of R. The
+ * default alpha level is 0.05--that is, p values from AD below 0.05 are taken to indicate nongaussianity.
+ *
+ * It is assumed that the CPDAG is the result of a CPDAG search such as PC or GES. In any case, it is important that the
+ * residuals be independent for ICA to work.
+ *
+ * This may be replaced by a more general algorithm that allows alternatives for the CPDAG search and for the
+ * non-Gaussian orientation method.
+ *
+ * This class is not configured to respect knowledge of forbidden and required edges.
*
* @author peterspirtes
* @author patrickhoyer
* @author josephramsey
*/
public class BossLingam {
+ // The CPDAG whose unoriented edges are to be oriented.
private final Graph cpdag;
+ // The dataset to use.
private final DataSet dataSet;
+ // The p-values of the search.
private double[] pValues;
- private double alpha = 0.05;
-
/**
* Constructor.
@@ -113,7 +111,7 @@ public Graph search() {
int i = nodes.indexOf(X);
int j = nodes.indexOf(Y);
- double lr = Fask.faskLeftRightV2(_data[i], _data[j], true, 0);
+ double lr = Fask.faskLeftRightV2(_data[i], _data[j], true, 0);
if (lr > 0.0) {
toOrient.removeEdge(edge);
@@ -148,7 +146,7 @@ public void setAlpha(double alpha) {
throw new IllegalArgumentException("Alpha is in range [0, 1]");
}
- this.alpha = alpha;
+ // The alpha level for the search.
}
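For reference, a sketch (not from this patch) of the BOSS-LiNGAM step, assuming the constructor takes the CPDAG to orient followed by the continuous dataset, and that 'cpdag' comes from a prior CPDAG search:

    BossLingam bossLingam = new BossLingam(cpdag, dataSet);   // argument order assumed
    Graph oriented = bossLingam.search();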
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bpc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bpc.java
index bca013898f..b14c9c7a04 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bpc.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bpc.java
@@ -34,39 +34,37 @@
/**
- *
Implements the Build Pure Clusters (BPC) algorithm, which allows one to identify
- * clusters of measured variables in a dataset that are explained by a single latent. The algorithm outputs these
- * clusters, which can then be used for further analysis, such as inferring structure over the latents. For the latter,
- * see for instance the MimBuild algorithm.
- *
- * The reference for BPC is this:
- *
- * Silva, R., Scheines, R., Glymour, C., Spirtes, P., & Chickering, D. M. (2006).
- * Learning the Structure of Linear Latent Variable Models. Journal of Machine Learning Research, 7(2).
- *
- * For a more detailed description of the algorithm, see the paper above. The
- * algorithm is based on the idea of finding cliques in the graph of the covariance matrix. The algorithm is initialized
- * by finding all maximal cliques in the graph of the covariance matrix. Then, the algorithm iterates over the cliques,
- * and for each clique, it tests whether the clique is explained by a single latent. If so, the clique is added to the
- * set of clusters. If not, the clique is partitioned into smaller cliques, and the process is repeated for each of the
- * smaller cliques. The algorithm stops when all cliques have been tested.
- *
- * Some more References:
- *
- * Silva, R.; Scheines, R.; Spirtes, P.; Glymour, C. (2003). "Learning measurement models".
- * Technical report CMU-CALD-03-100, Center for Automated Learning and Discovery, Carnegie Mellon University.
- *
- * Bollen, K. (1990). "Outlier screening and distribution-free test for vanishing tetrads."
- * Sociological Methods and Research 19, 80-92.
- *
- * Wishart, J. (1928). "Sampling errors in the theory of two factors". British Journal of
- * Psychology 19, 180-187.
- *
- * Bron, C. and Kerbosch, J. (1973) "Algorithm 457: Finding all cliques of an undirected graph".
- * Communications of ACM 16, 575-577.
- *
- * This class is not configured to respect knowledge of forbidden and required
- * edges.
+ * Implements the Build Pure Clusters (BPC) algorithm, which allows one to identify clusters of measured variables in a
+ * dataset that are explained by a single latent. The algorithm outputs these clusters, which can then be used for
+ * further analysis, such as inferring structure over the latents. For the latter, see for instance the MimBuild
+ * algorithm.
+ *
+ * The reference for BPC is this:
+ *
+ * Silva, R., Scheines, R., Glymour, C., Spirtes, P., & Chickering, D. M. (2006). Learning the Structure of Linear
+ * Latent Variable Models. Journal of Machine Learning Research, 7(2).
+ *
+ * For a more detailed description of the algorithm, see the paper above. The algorithm is based on the idea of finding
+ * cliques in the graph of the covariance matrix. The algorithm is initialized by finding all maximal cliques in the
+ * graph of the covariance matrix. Then, the algorithm iterates over the cliques, and for each clique, it tests whether
+ * the clique is explained by a single latent. If so, the clique is added to the set of clusters. If not, the clique is
+ * partitioned into smaller cliques, and the process is repeated for each of the smaller cliques. The algorithm stops
+ * when all cliques have been tested.
+ *
+ * Some more References:
+ *
+ * Silva, R.; Scheines, R.; Spirtes, P.; Glymour, C. (2003). "Learning measurement models". Technical report
+ * CMU-CALD-03-100, Center for Automated Learning and Discovery, Carnegie Mellon University.
+ *
+ * Bollen, K. (1990). "Outlier screening and distribution-free test for vanishing tetrads." Sociological Methods and
+ * Research 19, 80-92.
+ *
+ * Wishart, J. (1928). "Sampling errors in the theory of two factors". British Journal of Psychology 19, 180-187.
+ *
+ * Bron, C. and Kerbosch, J. (1973) "Algorithm 457: Finding all cliques of an undirected graph". Communications of ACM
+ * 16, 575-577.
+ *
+ * This class is not configured to respect knowledge of forbidden and required edges.
*
* @author Ricardo Silva
* @see Fofc
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Ccd.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Ccd.java
index 4c98673ad6..5792fc2c45 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Ccd.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Ccd.java
@@ -32,29 +32,28 @@
import java.util.*;
/**
- *
Implemented the Cyclic Causal Discovery (CCD) algorithm by Thomas Richardson.
- * A reference for this is here:
- *
- * Richardson, T. S. (2013). A discovery algorithm for directed cyclic graphs. arXiv
- * preprint arXiv:1302.3599.
- *
- * See also Chapter 7 of:
- *
- * Glymour, C. N., & Cooper, G. F. (Eds.). (1999). Computation, causation, and
- * discovery. Aaai Press.
- *
- * The graph takes continuous data from a cyclic model as input and returns a cyclic
- * PAG graphs, with various types of underlining, that represents a Markov equivalence of the true DAG.
- *
- * This class is not configured to respect knowledge of forbidden and required
- * edges.
+ * Implements the Cyclic Causal Discovery (CCD) algorithm by Thomas Richardson. A reference for this is here:
+ *
+ * Richardson, T. S. (2013). A discovery algorithm for directed cyclic graphs. arXiv preprint arXiv:1302.3599.
+ *
+ * See also Chapter 7 of:
+ *
+ * Glymour, C. N., & Cooper, G. F. (Eds.). (1999). Computation, causation, and discovery. Aaai Press.
+ *
+ * The search takes continuous data from a cyclic model as input and returns a cyclic PAG, with various types of
+ * underlining, that represents the Markov equivalence class of the true DAG.
+ *
+ * This class is not configured to respect knowledge of forbidden and required edges.
*
* @author Frank C. Wimberly
* @author josephramsey
*/
public final class Ccd implements IGraphSearch {
+ // The independence test to be used.
private final IndependenceTest independenceTest;
+ // The nodes in the graph.
private final List nodes;
+ // Whether the R1 rule should be applied.
private boolean applyR1;
/**
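For reference, a sketch (not from this patch) of running CCD, assuming the usual IGraphSearch pattern in which the constructor takes the independence test; the setApplyR1 setter name is assumed from the applyR1 field:

    Ccd ccd = new Ccd(test);        // test: IndependenceTest (constructor signature assumed)
    ccd.setApplyR1(true);           // setter name assumed
    Graph pag = ccd.search();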
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cfci.java
index 69e575f82f..5476d18a1b 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cfci.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cfci.java
@@ -34,12 +34,11 @@
/**
- * Adjusts FCI (see) to use conservative orientation as in CPC (see). Because the
- * collider orientation is conservative, there may be ambiguous triples; these may be retrieved using that accessor
- * method.
- *
- * This class is configured to respect knowledge of forbidden and required
- * edges, including knowledge of temporal tiers.
+ * Adjusts FCI (see) to use conservative orientation as in CPC (see). Because the collider orientation is conservative,
+ * there may be ambiguous triples; these may be retrieved using that accessor method.
+ *
+ * This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
+ * tiers.
*
* @author josephramsey
* @see Fci
@@ -49,55 +48,33 @@
*/
public final class Cfci implements IGraphSearch {
- /**
- * The SepsetMap being constructed.
- */
+ // The SepsetMap being constructed.
private final SepsetMap sepsets = new SepsetMap();
- /**
- * The variables to search over (optional)
- */
+ // The variables to search over (optional)
private final List variables = new ArrayList<>();
- /**
- * The independence test.
- */
+ // The independence test.
private final IndependenceTest independenceTest;
- /**
- * The logger to use.
- */
+ // The logger to use.
private final TetradLogger logger = TetradLogger.getInstance();
- /**
- * The PAG being constructed.
- */
+ // The PAG being constructed.
private Graph graph;
- /**
- * The background knowledge.
- */
+ // The background knowledge.
private Knowledge knowledge = new Knowledge();
- /**
- * Flag for complete rule set, true if you should use complete rule set, false otherwise.
- */
+ // Flag for complete rule set, true if you should use complete rule set, false otherwise.
private boolean completeRuleSetUsed = true;
- /**
- * True iff the possible msep search is done.
- */
+ // True iff the possible msep search is done.
private boolean possibleMsepSearchDone = true;
- /**
- * The maximum length for any discriminating path. -1 if unlimited; otherwise, a positive integer.
- */
+ // The maximum length for any discriminating path. -1 if unlimited; otherwise, a positive integer.
private int maxReachablePathLength = -1;
- /**
- * Set of ambiguous unshielded triples.
- */
+ // Set of ambiguous unshielded triples.
private Set ambiguousTriples;
- /**
- * The depth for the fast adjacency search.
- */
+ // The depth for the fast adjacency search.
private int depth = -1;
- /**
- * Elapsed time of last search.
- */
+ // Elapsed time of last search.
private long elapsedTime;
+ // Whether verbose output (about independencies) is output.
private boolean verbose;
+ // Whether to do the discriminating path rule.
private boolean doDiscriminatingPathRule;
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/CheckKnowledge.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/CheckKnowledge.java
new file mode 100644
index 0000000000..8889df82dd
--- /dev/null
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/CheckKnowledge.java
@@ -0,0 +1,81 @@
+package edu.cmu.tetrad.search;
+
+import edu.cmu.tetrad.data.Knowledge;
+import edu.cmu.tetrad.data.KnowledgeEdge;
+import edu.cmu.tetrad.graph.Edge;
+import edu.cmu.tetrad.graph.Edges;
+import edu.cmu.tetrad.graph.Graph;
+import edu.cmu.tetrad.graph.Node;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Identifies violations of knowledge for a given graph. Both forbidden and required knowledge are checked by separate
+ * methods. Sorted lists of edges violating knowledge are returned.
+ *
+ * @author josephramsey
+ */
+public class CheckKnowledge {
+
+ /**
+ * Private constructor to prevent instantiation.
+ */
+ private CheckKnowledge() {
+ }
+
+ /**
+ * Returns a sorted list of edges that violate the given knowledge.
+ *
+ * @param graph the graph.
+ * @param knowledge the knowledge.
+ * @return a sorted list of edges that violate the given knowledge.
+ */
+ public static List<Edge> forbiddenViolations(Graph graph, Knowledge knowledge) {
+ List<Edge> forbiddenViolations = new ArrayList<>();
+
+ for (Edge edge : graph.getEdges()) {
+ if (edge.isDirected()) {
+ Node x = Edges.getDirectedEdgeTail(edge);
+ Node y = Edges.getDirectedEdgeHead(edge);
+
+ if (knowledge.isForbidden(x.getName(), y.getName())) {
+ forbiddenViolations.add(edge);
+ }
+ }
+ }
+
+ Collections.sort(forbiddenViolations);
+
+ return forbiddenViolations;
+ }
+
+ /**
+ * Returns a sorted list of edges that are required by knowledge but which do not appear in the graph.
+ *
+ * @param graph the graph.
+ * @param knowledge the knowledge.
+ * @return a sorted list of edges that are required by knowledge but which do not appear in the graph.
+ */
+ public static List<Edge> requiredViolations(Graph graph, Knowledge knowledge) {
+ List<Edge> requiredViolations = new ArrayList<>();
+
+ Iterator<KnowledgeEdge> knowledgeEdgeIterator = knowledge.requiredEdgesIterator();
+
+ while (knowledgeEdgeIterator.hasNext()) {
+ KnowledgeEdge edge = knowledgeEdgeIterator.next();
+ Node x = graph.getNode(edge.getFrom());
+ Node y = graph.getNode(edge.getTo());
+
+ if (!graph.containsEdge(Edges.directedEdge(x, y))) {
+ requiredViolations.add(Edges.directedEdge(x, y));
+ }
+ }
+
+ Collections.sort(requiredViolations);
+
+ return requiredViolations;
+ }
+}
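A short usage sketch for the new CheckKnowledge utility. Its two static methods are defined in full above; the Knowledge-building calls (setForbidden, setRequired) and the variable names are assumptions not shown in this diff.

import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.search.CheckKnowledge;

import java.util.List;

public class CheckKnowledgeExample {
    // Prints any knowledge violations in a graph returned by some search.
    public static void report(Graph graph) {
        Knowledge knowledge = new Knowledge();
        knowledge.setForbidden("X3", "X1"); // assumed API: forbid X3 --> X1
        knowledge.setRequired("X1", "X2");  // assumed API: require X1 --> X2

        List<Edge> forbidden = CheckKnowledge.forbiddenViolations(graph, knowledge);
        List<Edge> missingRequired = CheckKnowledge.requiredViolations(graph, knowledge);

        System.out.println("Directed edges violating forbidden knowledge: " + forbidden);
        System.out.println("Required edges missing from the graph: " + missingRequired);
    }
}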
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/CompositeIndependenceTest.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/CompositeIndependenceTest.java
new file mode 100644
index 0000000000..5f18b940cb
--- /dev/null
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/CompositeIndependenceTest.java
@@ -0,0 +1,41 @@
+package edu.cmu.tetrad.search;
+
+import edu.cmu.tetrad.data.DataModel;
+import edu.cmu.tetrad.graph.Node;
+import edu.cmu.tetrad.search.test.IndependenceResult;
+
+import java.util.List;
+import java.util.Set;
+
+public class CompositeIndependenceTest implements IndependenceTest {
+ private final IndependenceTest[] independenceTests;
+
+ public CompositeIndependenceTest(IndependenceTest[] independenceTests) {
+ this.independenceTests = independenceTests;
+ }
+
+ @Override
+ public IndependenceResult checkIndependence(Node x, Node y, Set<Node> z) {
+ return null;
+ }
+
+ @Override
+ public List<Node> getVariables() {
+ return null;
+ }
+
+ @Override
+ public DataModel getData() {
+ return null;
+ }
+
+ @Override
+ public boolean isVerbose() {
+ return false;
+ }
+
+ @Override
+ public void setVerbose(boolean verbose) {
+
+ }
+}
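The new CompositeIndependenceTest above is a stub (its methods return null or false). Purely as an illustration of the evident intent, here is a hedged sketch of a delegation rule that declares independence only when every wrapped test does; the isIndependent() accessor on IndependenceResult is an assumption, and this code is not part of the diff.

import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndependenceTest;

import java.util.Set;

final class CompositeTestSketch {
    // Unanimous-vote rule: x _||_ y | z only if every wrapped test agrees.
    static boolean allJudgeIndependent(IndependenceTest[] tests, Node x, Node y, Set<Node> z) {
        for (IndependenceTest test : tests) {
            if (!test.checkIndependence(x, y, z).isIndependent()) { // isIndependent() assumed
                return false;
            }
        }
        return true;
    }
}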
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cstar.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cstar.java
index df608c89e7..f2076b5e9b 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cstar.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cstar.java
@@ -16,33 +16,32 @@
import java.util.concurrent.*;
/**
- * Implements the CStaR algorithm (Stekhoven et al., 2012), which finds a CPDAG of that
- * data and then tries all orientations of the undirected edges about a variable in the CPDAG to estimate a minimum
- * bound on the effect for a given edge. Some references include the following:
- *
- * Stekhoven, D. J., Moraes, I., Sveinbjörnsson, G., Hennig, L., Maathuis, M. H., and
- * Bühlmann, P. (2012). Causal stability ranking. Bioinformatics, 28(21), 2819-2823.
- *
- * Meinshausen, N., and Bühlmann, P. (2010). Stability selection. Journal of the Royal
- * Statistical Society: Series B (Statistical Methodology), 72(4), 417-473.
- *
- * Colombo, D., and Maathuis, M. H. (2014). Order-independent constraint-based causal
- * structure learning. The Journal of Machine Learning Research, 15(1), 3741-3782.
- *
- * This class is not configured to respect knowledge of forbidden and required
- * edges.
+ * Implements the CStaR algorithm (Stekhoven et al., 2012), which finds a CPDAG of the data and then tries all
+ * orientations of the undirected edges about a variable in the CPDAG to estimate a minimum bound on the effect for a
+ * given edge. Some references include the following:
+ *
+ * Stekhoven, D. J., Moraes, I., Sveinbjörnsson, G., Hennig, L., Maathuis, M. H., and Bühlmann, P. (2012). Causal
+ * stability ranking. Bioinformatics, 28(21), 2819-2823.
+ *
+ * Meinshausen, N., and Bühlmann, P. (2010). Stability selection. Journal of the Royal Statistical Society: Series B
+ * (Statistical Methodology), 72(4), 417-473.
+ *
+ * Colombo, D., and Maathuis, M. H. (2014). Order-independent constraint-based causal structure learning. The Journal of
+ * Machine Learning Research, 15(1), 3741-3782.
+ *
+ * This class is not configured to respect knowledge of forbidden and required edges.
*
* @author josephramsey
* @see Ida
*/
public class Cstar {
+ private final IndependenceWrapper test;
+ private final ScoreWrapper score;
+ private final Parameters parameters;
private boolean parallelized = false;
private int numSubsamples = 30;
private int topBracket = 5;
private double selectionAlpha = 0.0;
- private final IndependenceWrapper test;
- private final ScoreWrapper score;
- private final Parameters parameters;
private CpdagAlgorithm cpdagAlgorithm = CpdagAlgorithm.PC_STABLE;
private SampleStyle sampleStyle = SampleStyle.SUBSAMPLE;
private boolean verbose;
@@ -600,6 +599,7 @@ private Graph getPatternFges(DataSet sample) {
private Graph getPatternBoss(DataSet sample) {
Score score = this.score.getScore(sample, parameters);
PermutationSearch boss = new PermutationSearch(new Boss(score));
+ boss.setSeed(parameters.getLong(Params.SEED));
return boss.search();
}
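The setSeed call added to getPatternBoss above makes the BOSS-based CPDAG step reproducible across CStaR subsamples. A standalone sketch using only calls that appear in this hunk (PermutationSearch, Boss, setSeed, search); the SemBicScore constructor is an assumption.

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.search.Boss;
import edu.cmu.tetrad.search.PermutationSearch;
import edu.cmu.tetrad.search.score.SemBicScore;

public class SeededBossExample {
    // With a fixed seed, two runs on the same sample should return the same CPDAG.
    public static boolean isReproducible(DataSet sample, long seed) {
        return runBoss(sample, seed).equals(runBoss(sample, seed));
    }

    private static Graph runBoss(DataSet sample, long seed) {
        SemBicScore score = new SemBicScore(sample, true); // assumed constructor: (DataSet, precomputeCovariances)
        PermutationSearch boss = new PermutationSearch(new Boss(score));
        boss.setSeed(seed);
        return boss.search();
    }
}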
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Dagma.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Dagma.java
index dc717c05c0..56be90398c 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Dagma.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Dagma.java
@@ -26,34 +26,24 @@
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.utils.MeekRules;
-import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.LUDecomposition;
+import org.apache.commons.math3.linear.RealMatrix;
-import java.util.*;
+import java.util.List;
import static org.apache.commons.math3.linear.MatrixUtils.*;
import static org.apache.commons.math3.util.FastMath.*;
/**
- *
- * Implements the DAGMA algorithm. The reference is here:
- *
- * NEEDS DOCUMENTATION
+ * Implements the DAGMA algorithm. The reference is here:
+ *
+ * NEEDS DOCUMENTATION
*
* @author bryanandrews
*/
public class Dagma {
- private RealMatrix cov;
- private List variables;
- private RealMatrix I;
- private int d;
-
- private double lambda1;
- private double wThreshold;
- private boolean cpdag;
-
private final double[] T;
-
private final double muInit;
private final double muFactor;
private final int warmIter;
@@ -63,6 +53,13 @@ public class Dagma {
private final double b1;
private final double b2;
private final double tol;
+ private final RealMatrix cov;
+ private final List<Node> variables;
+ private final RealMatrix I;
+ private final int d;
+ private double lambda1;
+ private double wThreshold;
+ private boolean cpdag;
/**
@@ -80,7 +77,7 @@ public Dagma(DataSet dataset) {
this.cpdag = true;
// M-matrix s values
- this.T = new double[] {1.0, .9, .8, .7};
+ this.T = new double[]{1.0, .9, .8, .7};
// central path coefficient and decay factor
this.muInit = 1.0;
@@ -91,8 +88,8 @@ public Dagma(DataSet dataset) {
this.maxIter = 70000;
this.lr = 3e-4;
this.checkpoint = 1000;
- this.b1=0.99;
- this.b2=0.999;
+ this.b1 = 0.99;
+ this.b2 = 0.999;
this.tol = 1e-6;
}
@@ -266,10 +263,10 @@ private RealMatrix getMMatrix(RealMatrix W, double s) {
RealMatrix M = this.I.scalarMultiply(s);
for (int i = 0; i < this.d; i++) {
- for (int j = 0; j < this.d; j++) {
- M.addToEntry(i, j, -W.getEntry(i, j) * W.getEntry(i, j));
- }
+ for (int j = 0; j < this.d; j++) {
+ M.addToEntry(i, j, -W.getEntry(i, j) * W.getEntry(i, j));
}
+ }
return M;
}
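For reference, the loop reindented above computes M = s*I - W o W (the elementwise square of W subtracted from a scaled identity). Below is a standalone sketch of that computation plus the DAGMA log-det acyclicity value h(W) = -log det(M) + d*log(s); the h(W) formula is taken from the DAGMA paper, not from this diff, and the sketch uses only commons-math3 classes this file already imports.

import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;

import static org.apache.commons.math3.util.FastMath.log;

public class DagmaMMatrixExample {
    // Returns h(W); for suitable s it is zero exactly when the weighted graph W is acyclic.
    public static double logDetAcyclicity(RealMatrix W, double s) {
        int d = W.getRowDimension();
        RealMatrix M = MatrixUtils.createRealIdentityMatrix(d).scalarMultiply(s);

        for (int i = 0; i < d; i++) {
            for (int j = 0; j < d; j++) {
                M.addToEntry(i, j, -W.getEntry(i, j) * W.getEntry(i, j));
            }
        }

        double det = new LUDecomposition(M).getDeterminant();
        return -log(det) + d * log(s);
    }
}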
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java
index b3854ed823..f8acaa5af5 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Demixer.java
@@ -16,19 +16,34 @@
* @author Madelyn Glymour
*/
public class Demixer {
-
+ // number of variables in the data set
private final int numVars;
+ // number of cases in the data set
private final int numCases;
- private final int numClusters; // number of clusters
+ // number of clusters
+ private final int numClusters;
+ // the data set
private final DataSet data;
- private final double[][] dataArray; // v-by-n data matrix
+ // the data set as a double array
+ private final double[][] dataArray;
+ // the covariance matrix for each model
private final Matrix[] variances;
- private final double[][] meansArray; // k-by-v matrix representing means for each variable for each of k models
- private final Matrix[] variancesArray; // k-by-v-by-v matrix representing covariance matrix for each of k models
- private final double[] weightsArray; // array of length k representing weights for each model
- private final double[][] gammaArray; // k-by-n matrix representing gamma for each data case in each model
+ // the means of each variable for each model
+ private final double[][] meansArray;
+ // the variances of each variable for each model
+ private final Matrix[] variancesArray;
+ // the weights of each model
+ private final double[] weightsArray;
+ // the gamma values for each case in each model
+ private final double[][] gammaArray;
+ // whether the algorithm has been run
private boolean demixed = false;
+ /**
+ * Constructor. Initializes the means, weights, and covariance matrices for each model.
+ * @param data the data set
+ * @param k the number of models
+ */
public Demixer(DataSet data, int k) {
this.numClusters = k;
this.data = data;
@@ -61,6 +76,22 @@ public Demixer(DataSet data, int k) {
}
}
+ static double getVar(int i, int v, int v2, int numCases, double[][] gammaArray, double[][] dataArray, double[][] meansArray) {
+ double varNumerator;
+ double varDivisor;
+ double var;
+ varNumerator = 0;
+ varDivisor = 0;
+
+ for (int j = 0; j < numCases; j++) {
+ varNumerator += gammaArray[i][j] * (dataArray[j][v] - meansArray[i][v]) * (dataArray[j][v2] - meansArray[i][v2]);
+ varDivisor += gammaArray[i][j];
+ }
+
+ var = varNumerator / varDivisor;
+ return var;
+ }
+
/*
* Runs the E-M algorithm iteratively until the weights array converges. Returns a MixtureModel object containing
* the final values of the means, covariance matrices, weights, and gammas arrays.
@@ -203,22 +234,6 @@ private void maximization() {
}
- static double getVar(int i, int v, int v2, int numCases, double[][] gammaArray, double[][] dataArray, double[][] meansArray) {
- double varNumerator;
- double varDivisor;
- double var;
- varNumerator = 0;
- varDivisor = 0;
-
- for (int j = 0; j < numCases; j++) {
- varNumerator += gammaArray[i][j] * (dataArray[j][v] - meansArray[i][v]) * (dataArray[j][v2] - meansArray[i][v2]);
- varDivisor += gammaArray[i][j];
- }
-
- var = varNumerator / varDivisor;
- return var;
- }
-
/*
* For an input case and model, returns the value of the model's normal PDF for that case, using the current
* estimations of the means and covariance matrix
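The relocated getVar above computes the responsibility-weighted covariance between variables v and v2 under cluster i, an M-step quantity of the E-M loop described in the comments. A tiny worked call follows; the method is package-private, so this sketch assumes it is invoked from a (hypothetical) class in edu.cmu.tetrad.search.

package edu.cmu.tetrad.search;

public class DemixerGetVarExample {
    public static void main(String[] args) {
        // One cluster (i = 0), two cases, two variables.
        double[][] gamma = {{0.5, 1.0}};            // responsibilities of cluster 0 for cases 0 and 1
        double[][] data = {{1.0, 2.0}, {3.0, 6.0}}; // case-by-variable data
        double[][] means = {{2.0, 4.0}};            // cluster-0 means for the two variables

        // (0.5*(-1)*(-2) + 1.0*(1)*(2)) / (0.5 + 1.0) = 3.0 / 1.5 = 2.0
        double cov = Demixer.getVar(0, 0, 1, 2, gamma, data, means);
        System.out.println(cov); // prints 2.0
    }
}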
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DirectLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DirectLingam.java
index 70675a83d2..5155f8d529 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DirectLingam.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DirectLingam.java
@@ -34,27 +34,30 @@
import static org.apache.commons.math3.util.FastMath.*;
/**
- * Implements the Direct-LiNGAM algorithm. The reference is here:
- *
- * S. Shimizu, T. Inazumi, Y. Sogawa, A. Hyvärinen, Y. Kawahara, T. Washio, P. O. Hoyer and K. Bollen.
- * DirectLiNGAM: A direct method for learning a linear non-Gaussian structural equation model. Journal of Machine
- * Learning Research, 12(Apr): 1225–1248, 2011.
- *
- * A. Hyvärinen and S. M. Smith. Pairwise likelihood ratios for estimation of non-Gaussian
- * structural evaluation models. Journal of Machine Learning Research 14:111-152, 2013.
- *
- * NEEDS DOCUMENTATION
+ * Implements the Direct-LiNGAM algorithm. The reference is here:
+ *
+ * S. Shimizu, T. Inazumi, Y. Sogawa, A. Hyvärinen, Y. Kawahara, T. Washio, P. O. Hoyer and K. Bollen. DirectLiNGAM: A
+ * direct method for learning a linear non-Gaussian structural equation model. Journal of Machine Learning Research,
+ * 12(Apr): 1225–1248, 2011.
+ *
+ * A. Hyvärinen and S. M. Smith. Pairwise likelihood ratios for estimation of non-Gaussian structural evaluation models.
+ * Journal of Machine Learning Research 14:111-152, 2013.
*
* @author bryanandrews
*/
public class DirectLingam {
-
+ // the data set
private final DataSet dataset;
+ // the variables
private final List<Node> variables;
+ // the grow-shrink trees
private final Map gsts;
/**
* Constructor.
+ *
+ * @param dataset the data set
+ * @param score the score
*/
public DirectLingam(DataSet dataset, Score score) {
this.dataset = dataset;
@@ -70,7 +73,9 @@ public DirectLingam(DataSet dataset, Score score) {
}
/**
- * NEEDS DOCUMENTATION
+ * Performs the search. Returns a graph.
+ *
+ * @return a graph
*/
public Graph search() {
List<Node> U = new ArrayList<>(this.variables);
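A minimal usage sketch for DirectLingam: both constructor arguments and search() are shown in this diff; the SemBicScore constructor used here is an assumption, and any other Score implementation would do.

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.search.DirectLingam;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.score.SemBicScore;

public class DirectLingamExample {
    // Runs Direct-LiNGAM on continuous, non-Gaussian data and returns the estimated DAG.
    public static Graph run(DataSet dataSet) {
        Score score = new SemBicScore(dataSet, true);           // assumed constructor
        DirectLingam lingam = new DirectLingam(dataSet, score); // constructor shown above
        return lingam.search();                                 // search() shown above
    }
}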
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FactorAnalysis.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FactorAnalysis.java
index 0fa17ed559..73986adb59 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FactorAnalysis.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FactorAnalysis.java
@@ -33,28 +33,28 @@
import static org.apache.commons.math3.util.FastMath.abs;
/**
- * Implements the classical Factor Analysis algorithm. Some references include:
- * Horst, P. (1965). Factor analysis of data matrices. Holt, Rinehart and Winston.
- * This work has good specifications and explanations of factor analysis algorithm and methods of communality
- * estimation.
- *
- * Rummel, R. J. (1988). Applied factor analysis. Northwestern University Press. This
- * book is a good companion to the book listed above. While it doesn't specify any actual algorithm, it has a great
- * introduction to the subject that gives the reader a good appreciation of the philosophy and the mathematics behind
- * factor analysis.
- *
- * This class is not configured to respect knowledge of forbidden and required
- * edges.
+ * Implements the classical Factor Analysis algorithm. Some references include: Horst, P. (1965). Factor analysis of
+ * data matrices. Holt, Rinehart and Winston. This work has good specifications and explanations of the factor analysis
+ * algorithm and methods of communality estimation.
+ *
+ * Rummel, R. J. (1988). Applied factor analysis. Northwestern University Press. This book is a good companion to the
+ * book listed above. While it doesn't specify any actual algorithm, it has a great introduction to the subject that
+ * gives the reader a good appreciation of the philosophy and the mathematics behind factor analysis.
+ *
+ * This class is not configured to respect knowledge of forbidden and required edges.
*
* @author Mike Freenor
*/
public class FactorAnalysis {
+ // the covariance matrix
private final CovarianceMatrix covariance;
-
// method-specific fields that get used
private LinkedList<Matrix> factorLoadingVectors;
+ // the threshold for the algorithm
private double threshold = 0.001;
+ // the number of factors to find
private int numFactors = 2;
+ // the residual matrix
private Matrix residual;
/**
@@ -75,80 +75,33 @@ public FactorAnalysis(DataSet dataSet) {
this.covariance = new CovarianceMatrix(dataSet);
}
-
- //designed for normalizing a vector.
- //as usual, vectors are treated as matrices to simplify operations elsewhere
- private static Matrix normalizeRows(Matrix matrix) {
- LinkedList normalizedRows = new LinkedList<>();
- for (int i = 0; i < matrix.getNumRows(); i++) {
- Vector vector = matrix.getRow(i);
- Matrix colVector = new Matrix(matrix.getNumColumns(), 1);
- for (int j = 0; j < matrix.getNumColumns(); j++)
- colVector.set(j, 0, vector.get(j));
-
- normalizedRows.add(FactorAnalysis.normalizeVector(colVector));
- }
-
- Matrix result = new Matrix(matrix.getNumRows(), matrix.getNumColumns());
- for (int i = 0; i < matrix.getNumRows(); i++) {
- Matrix normalizedRow = normalizedRows.get(i);
- for (int j = 0; j < matrix.getNumColumns(); j++) {
- result.set(i, j, normalizedRow.get(j, 0));
- }
- }
-
- return result;
- }
-
- private static Matrix normalizeVector(Matrix vector) {
- double scalar = FastMath.sqrt(vector.transpose().times(vector).get(0, 0));
- return vector.scalarMult(1.0 / scalar);
- }
-
- private static Matrix matrixExp(Matrix matrix, double exponent) {
- Matrix result = new Matrix(matrix.getNumRows(), matrix.getNumColumns());
- for (int i = 0; i < matrix.getNumRows(); i++) {
- for (int j = 0; j < matrix.getNumColumns(); j++) {
- result.set(i, j, FastMath.pow(matrix.get(i, j), exponent));
- }
- }
- return result;
- }
-
/**
* Successive method with residual matrix.
*
- * This algorithm makes use of a helper algorithm.
- * Together, they solve for an unrotated factor loading matrix.
+ * This algorithm makes use of a helper algorithm. Together, they solve for an unrotated factor loading matrix.
*
* This method calls upon its helper to find column vectors, with which it constructs its factor loading matrix.
* Upon receiving each successive column vector from its helper method, it makes sure that we want to keep this
- * vector instead of discarding it.
- * After keeping a vector, a residual matrix is calculated, upon which solving for
+ * vector instead of discarding it. After keeping a vector, a residual matrix is calculated, upon which solving for
* the next column vector is directly dependent.
*
* We stop looking for new vectors either when we've accounted for close to all the variance in the original
* correlation matrix, or when the "d scalar" for a new vector is less than 1 (the d-scalar is the corresponding
* diagonal for the factor loading matrix -- thus, when it's less than 1, the vector we've solved for barely
- * accounts for any more variance).
- * This means we've already "pulled out" all the variance we can from the
- * residual matrix, and we should stop as further factors don't explain much more
- * (and serve to complicate the
- * model).
+ * accounts for any more variance). This means we've already "pulled out" all the variance we can from the residual
+ * matrix, and we should stop as further factors don't explain much more (and serve to complicate the model).
*
* PSEUDO-CODE:
*
* 0th Residual Matrix = Original Correlation Matrix Ask helper for the 1st factor (first column vector in our
- * factor loading vector) Add 1st factor's d-scalar
- * (for i'th factor, call its d-scalar the i'th d-scalar) to a list
+ * factor loading vector) Add 1st factor's d-scalar (for i'th factor, call its d-scalar the i'th d-scalar) to a list
* of d-scalars.
*
* While the ratio of the sum of d-scalars to the trace of the original correlation matrix is less than .99 (in
* other words, while we haven't accounted for practically all the variance):
*
* i'th residual matrix = (i - 1)'th residual matrix SUBTRACT the major product moment of (i - 1)'th factor loading
- * vector Ask helper for i'th factor If i'th factor's d-value is less than 1, throw it out and end loop.
- * Otherwise,
+ * vector Ask helper for i'th factor If i'th factor's d-value is less than 1, throw it out and end loop. Otherwise,
* add it to the factor loading matrix and continue loop.
*
* END PSEUDO-CODE
@@ -286,8 +239,6 @@ public Matrix successiveFactorVarimax(Matrix factorLoadingMatrix) {
return result;
}
- // ------------------Private methods-------------------//
-
/**
* Sets the threshold.
*
@@ -315,21 +266,56 @@ public Matrix getResidual() {
return this.residual;
}
+
+ //designed for normalizing a vector.
+ //as usual, vectors are treated as matrices to simplify operations elsewhere
+ private static Matrix normalizeRows(Matrix matrix) {
+ LinkedList<Matrix> normalizedRows = new LinkedList<>();
+ for (int i = 0; i < matrix.getNumRows(); i++) {
+ Vector vector = matrix.getRow(i);
+ Matrix colVector = new Matrix(matrix.getNumColumns(), 1);
+ for (int j = 0; j < matrix.getNumColumns(); j++)
+ colVector.set(j, 0, vector.get(j));
+
+ normalizedRows.add(FactorAnalysis.normalizeVector(colVector));
+ }
+
+ Matrix result = new Matrix(matrix.getNumRows(), matrix.getNumColumns());
+ for (int i = 0; i < matrix.getNumRows(); i++) {
+ Matrix normalizedRow = normalizedRows.get(i);
+ for (int j = 0; j < matrix.getNumColumns(); j++) {
+ result.set(i, j, normalizedRow.get(j, 0));
+ }
+ }
+
+ return result;
+ }
+
+ private static Matrix normalizeVector(Matrix vector) {
+ double scalar = FastMath.sqrt(vector.transpose().times(vector).get(0, 0));
+ return vector.scalarMult(1.0 / scalar);
+ }
+
+ private static Matrix matrixExp(Matrix matrix, double exponent) {
+ Matrix result = new Matrix(matrix.getNumRows(), matrix.getNumColumns());
+ for (int i = 0; i < matrix.getNumRows(); i++) {
+ for (int j = 0; j < matrix.getNumColumns(); j++) {
+ result.set(i, j, FastMath.pow(matrix.get(i, j), exponent));
+ }
+ }
+ return result;
+ }
+
/**
- * Helper method for the basic structure successive factor method above.
- * Takes a residual matrix and an approximation
- * vector, and finds both the factor loading vector and the "d scalar"
- * which is used to determine the amount of
- * total variance accounted for so far.
+ * Helper method for the basic structure successive factor method above. Takes a residual matrix and an
+ * approximation vector, and finds both the factor loading vector and the "d scalar" which is used to determine the
+ * amount of total variance accounted for so far.
*
- * The helper takes, to begin with, the unit vector as its approximation to the factor column vector.
- * With each
- * iteration, it approximates a bit closer --
- * the d-scalar for each successive step eventually converges to a value
+ * The helper takes, to begin with, the unit vector as its approximation to the factor column vector. With each
+ * iteration, it approximates a bit closer -- the d-scalar for each successive step eventually converges to a value
* (provably).
*
- * Thus, the ratio between the last iteration's d-scalar and this iteration's d-scalar should approach 1.
- * When this
+ * Thus, the ratio between the last iteration's d-scalar and this iteration's d-scalar should approach 1. When this
* ratio gets sufficiently close to 1, the algorithm halts and returns its getModel approximation.
*
* Important to note: the residual matrix stays fixed for this entire algorithm.
@@ -346,8 +332,7 @@ public Matrix getResidual() {
* times yet, a failsafe):
*
* i'th U Vector = residual matrix * (i - 1)'th factor loading i'th L Scalar = transpose((i - 1)'th factor loading)
- * * i'th U Vector i'th D Scalar = square root(i'th L Scalar)
- * i'th factor loading = i'th U Vector / i'th D Scalar
+ * * i'th U Vector i'th D Scalar = square root(i'th L Scalar) i'th factor loading = i'th U Vector / i'th D Scalar
*
* Return the final i'th factor loading as our best approximation.
*/
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fas.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fas.java
index 8942219381..b14950acb5 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fas.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fas.java
@@ -34,27 +34,25 @@
import java.util.*;
/**
- *
- * Implements the Fast Adjacency Search (FAS), which is the adjacency search of the PC algorithm (see). This is a
- * useful algorithm in many contexts, including as the first step of FCI (see).
- *
- * The idea of FAS is that at a given stage of the search, an edge X*-*Y is removed from the
- * graph if X _||_ Y | S, where S is a subset of size d either of adj(X) or of adj(Y), where d is the depth of the
- * search. The fast adjacency search performs this procedure for each pair of adjacent edges in the graph and for each
- * depth d = 0, 1, 2, ..., d1, where d1 is either the maximum depth or else the first such depth at which no edges can
- * be removed. The interpretation of this adjacency search is different for different algorithm, depending on the
- * assumptions of the algorithm. A mapping from {x, y} to S({x, y}) is returned for edges x *-* y that have been
- * removed.
- *
- * FAS may optionally use a heuristic from Causation, Prediction and Search, which (like PC-Stable)
- * renders the output invariant to the order of the input variables.
- *
- * This algorithm was described in the earlier edition of this book:
- *
- * Spirtes, P., Glymour, C. N., Scheines, R., & Heckerman, D. (2000). Causation, prediction, and search. MIT
- * press.
- *
- * This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
- * tiers.
+ * Implements the Fast Adjacency Search (FAS), which is the adjacency search of the PC algorithm (see). This is a useful
+ * algorithm in many contexts, including as the first step of FCI (see).
+ *
+ * The idea of FAS is that at a given stage of the search, an edge X*-*Y is removed from the graph if X _||_ Y | S,
+ * where S is a subset of size d either of adj(X) or of adj(Y), where d is the depth of the search. The fast adjacency
+ * search performs this procedure for each pair of adjacent edges in the graph and for each depth d = 0, 1, 2, ..., d1,
+ * where d1 is either the maximum depth or else the first such depth at which no edges can be removed. The
+ * interpretation of this adjacency search is different for different algorithm, depending on the assumptions of the
+ * algorithm. A mapping from {x, y} to S({x, y}) is returned for edges x *-* y that have been removed.
+ *
+ * FAS may optionally use a heuristic from Causation, Prediction and Search, which (like PC-Stable) renders the output
+ * invariant to the order of the input variables.
+ *
+ * This algorithm was described in the earlier edition of this book:
+ *
+ * Spirtes, P., Glymour, C. N., Scheines, R., & Heckerman, D. (2000). Causation, prediction, and search. MIT press.
+ *
+ * This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
+ * tiers.
*
* @author peterspirtes
* @author clarkglymour
@@ -64,19 +62,29 @@
* @see Knowledge
*/
public class Fas implements IFas {
+ // The test to be used for conditional independence tests.
private final IndependenceTest test;
+ // The logger.
private final TetradLogger logger = TetradLogger.getInstance();
+ // The knowledge.
private Knowledge knowledge = new Knowledge();
+ // The number of independence tests that were done.
private int numIndependenceTests;
+ // The sepsets that were discovered in the search.
private SepsetMap sepset = new SepsetMap();
+ // The heuristic to use.
private PcCommon.PcHeuristicType heuristic = PcCommon.PcHeuristicType.NONE;
+ // The depth of the search.
private int depth = 1000;
+ // Whether the stable adjacency search should be used.
private boolean stable = true;
+ // The elapsed time of the search.
private long elapsedTime = 0L;
+ // The output stream.
private PrintStream out = System.out;
+ // Whether verbose output should be printed.
private boolean verbose = false;
-
/**
* Constructor.
*
@@ -99,10 +107,10 @@ public Graph search() {
/**
* Discovers all adjacencies in data. The procedure is to remove edges in the graph which connect pairs of
- * variables which are independent, conditional on some other set of variables in the graph (the "sepset"). These are
- * removed in tiers. First, edges which are independent conditional on zero other variables are removed, then edges
- * which are independent conditional on one other variable are removed, then two, then three, and so on, until no
- * more edges can be removed from the graph. The edges which remain in the graph after this procedure are the
+ * variables which are independent, conditional on some other set of variables in the graph (the "sepset"). These
+ * are removed in tiers. First, edges which are independent conditional on zero other variables are removed, then
+ * edges which are independent conditional on one other variable are removed, then two, then three, and so on, until
+ * no more edges can be removed from the graph. The edges which remain in the graph after this procedure are the
* adjacencies in the data.
*
* @param nodes A list of nodes to search over.
@@ -175,7 +183,9 @@ public Graph search(List nodes) {
}
for (int d = 0; d <= _depth; d++) {
- System.out.println("Depth: " + d);
+ if (verbose) {
+ System.out.println("Depth: " + d);
+ }
boolean more;
@@ -321,11 +331,11 @@ public void setPcHeuristicType(PcCommon.PcHeuristicType pcHeuristic) {
}
/**
- * Sets whether the stable adjacency search should be used. Default is false. Default is false. See the
- * following reference for this:
- *
- * Colombo, D., & Maathuis, M. H. (2014). Order-independent constraint-based causal structure learning. J. Mach.
- * Learn. Res., 15(1), 3741-3782.
+ * Sets whether the stable adjacency search should be used. Default is false. See the following
+ * reference for this:
+ *
+ * Colombo, D., & Maathuis, M. H. (2014). Order-independent constraint-based causal structure learning. J. Mach.
+ * Learn. Res., 15(1), 3741-3782.
*
* @param stable True iff the case.
*/
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fasd.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fasd.java
index 01056c1fe4..79277d689a 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fasd.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fasd.java
@@ -27,20 +27,20 @@
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.search.utils.SepsetMap;
import edu.cmu.tetrad.util.ChoiceGenerator;
+import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.io.PrintStream;
-import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
/**
- *
- * Adjusts FAS (see) for the deterministic case by refusing to removed edges
- * based on conditional independence tests that are judged to be deterministic. That is, if X _||_ Y | Z, but Z
- * determines X or Y, then the edge X---Y is not removed.
- *
- * This class is configured to respect knowledge of forbidden and required
- * edges, including knowledge of temporal tiers.
+ * Adjusts FAS (see) for the deterministic case by refusing to remove edges based on conditional independence tests
+ * that are judged to be deterministic. That is, if X _||_ Y | Z, but Z determines X or Y, then the edge X---Y is not
+ * removed.
+ *
+ * This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
+ * tiers.
*
* @author peterspirtes
* @author josephramsey.
@@ -49,49 +49,31 @@
*/
public class Fasd implements IFas {
- /**
- * The independence test. This should be appropriate to the types
- */
+ // The independence test. This should be appropriate to the types of the variables being tested.
private final IndependenceTest test;
- /**
- * The logger, by default the empty logger.
- */
+ // The logger, by default the empty logger.
private final TetradLogger logger = TetradLogger.getInstance();
- private final NumberFormat nf = new DecimalFormat("0.00E0");
- /**
- * The search graph. It is assumed going in that all the true adjacencies of x are in this graph for every node
- * x. It is hoped (i.e., true in the large sample limit) that true adjacencies are never removed.
- */
+ // The number formatter.
+ private final NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
+ // The search graph. It is assumed going in that all the true adjacencies of x are in this graph for every node x.
+ // It is hoped (i.e., true in the large sample limit) that true adjacencies are never removed.
private final Graph graph;
- /**
- * Specification of which edges are forbidden or required.
- */
+ // Specification of which edges are forbidden or required.
private Knowledge knowledge = new Knowledge();
- /**
- * The maximum number of variables conditioned on in any conditional independence test. If the depth is -1, it will
- * be taken to be the maximum value, which is 1000. Otherwise, it should be set to a non-negative integer.
- */
+ // The maximum number of variables conditioned on in any conditional independence test. If the depth is -1, it will
+ // be taken to be the maximum value, which is 1000. Otherwise, it should be set to a non-negative integer.
private int depth = 1000;
- /**
- * The number of independence tests.
- */
+ // The number of independence tests.
private int numIndependenceTests;
- /**
- * The sepsets found during the search.
- */
+ // The sepsets found during the search.
private SepsetMap sepset = new SepsetMap();
- /**
- * The depth 0 graph, specified initially.
- */
+ // The depth 0 graph, specified initially.
private Graph externalGraph;
- /**
- * True iff verbose output should be printed.
- */
+ // True iff verbose output should be printed.
private boolean verbose;
-
+ // The output stream.
private PrintStream out = System.out;
-
/**
* Constructs a new FastAdjacencySearch.
*
@@ -104,10 +86,10 @@ public Fasd(IndependenceTest test) {
/**
* Discovers all adjacencies in data. The procedure is to remove edges in the graph which connect pairs of
- * variables which are independent, conditional on some other set of variables in the graph (the "sepset"). These are
- * removed in tiers. First, edges which are independent conditional on zero other variables are removed, then edges
- * which are independent conditional on one other variable are removed, then two, then three, and so on, until no
- * more edges can be removed from the graph. The edges which remain in the graph after this procedure are the
+ * variables which are independent, conditional on some other set of variables in the graph (the "sepset"). These
+ * are removed in tiers. First, edges which are independent conditional on zero other variables are removed, then
+ * edges which are independent conditional on one other variable are removed, then two, then three, and so on, until
+ * no more edges can be removed from the graph. The edges which remain in the graph after this procedure are the
* adjacencies in the data.
*
* @return a graph which indicates which variables are independent conditional on which other variables
@@ -299,7 +281,6 @@ private boolean searchAtDepth0(List nodes, IndependenceTest test, Map
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fask.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fask.java
- * Implements the FASK (Fast Adjacency Skewness) algorithm, which makes decisions for adjacency
- * and orientation using a combination of conditional independence testing, judgments of nonlinear adjacency, and
- * pairwise orientation due to non-Gaussianity. The reference is this:
- *
- * Sanchez-Romero, R., Ramsey, J. D., Zhang, K., Glymour, M. R., Huang, B., and Glymour, C.
- * (2019). Estimating feedforward and feedback effective connections from fMRI time series: Assessments of statistical
- * methods. Network Neuroscience, 3(2), 274-30
- *
- * Some adjustments have been made in some ways from that version, and some additional pairwise options
- * have been added from this reference:
- *
- * Hyvärinen, A., and Smith, S. M. (2013). Pairwise likelihood ratios for estimation of non-Gaussian structural
- * equation models. Journal of Machine Learning Research, 14(Jan), 111-152.
- *
- * This method (and the Hyvarinen and Smith methods) make the assumption that the data are generated by
- * a linear, non-Gaussian causal process and attempts to recover the causal graph for that process. They do not attempt
- * to recover the parametrization of this graph; for this a separate estimation algorithm would be needed, such as
- * linear regression regressing each node onto its parents. A further assumption is made, that there are no latent
- * common causes of the algorithm. This is not a constraint on the pairwise orientation methods, since they orient with
- * respect only to the two variables at the endpoints of an edge and so are happy with all other variables being
- * considered latent with respect to that single edge. However, if the built-in adjacency search is used (FAS-Stable),
- * the existence of latents will throw this method off.
- *
+ * Implements the FASK (Fast Adjacency Skewness) algorithm, which makes decisions for adjacency and orientation using a
+ * combination of conditional independence testing, judgments of nonlinear adjacency, and pairwise orientation due to
+ * non-Gaussianity. The reference is this:
+ *
+ * Sanchez-Romero, R., Ramsey, J. D., Zhang, K., Glymour, M. R., Huang, B., and Glymour, C. (2019). Estimating
+ * feedforward and feedback effective connections from fMRI time series: Assessments of statistical methods. Network
+ * Neuroscience, 3(2), 274-306
+ *
+ * Some adjustments have been made in some ways from that version, and some additional pairwise options have been added
+ * from this reference:
+ *
+ * Hyvärinen, A., and Smith, S. M. (2013). Pairwise likelihood ratios for estimation of non-Gaussian structural equation
+ * models. Journal of Machine Learning Research, 14(Jan), 111-152.
+ *
+ * This method (and the Hyvarinen and Smith methods) makes the assumption that the data are generated by a linear,
+ * non-Gaussian causal process and attempts to recover the causal graph for that process. They do not attempt to recover
+ * the parametrization of this graph; for this a separate estimation algorithm would be needed, such as linear
+ * regression regressing each node onto its parents. A further assumption is made, that there are no latent common
+ * causes of the measured variables. This is not a constraint on the pairwise orientation methods, since they orient with respect
+ * only to the two variables at the endpoints of an edge and so are happy with all other variables being considered
+ * latent with respect to that single edge. However, if the built-in adjacency search is used (FAS-Stable), the
+ * existence of latents will throw this method off.
*
* As was shown in the Hyvarinen and Smith paper above, FASK works quite well even if the graph contains feedback loops
* in most configurations, including 2-cycles. 2-cycles can be detected fairly well if the FASK left-right rule is
@@ -124,10 +123,9 @@
* concat_BOLDfslfilter_60_FullMacaque.txt --prefix Fask_Test_MacaqueFull --algorithm fask --faskAdjacencyMethod 1
* --depth -1 --test sem-bic-test --score sem-bic-score --semBicRule 1 --penaltyDiscount 2 --skewEdgeThreshold 0.3
* --faskLeftRightRule 1 --faskDelta -0.3 --twoCycleScreeningThreshold 0 --orientationAlpha 0.1 -structurePrior 0
- *
- *
- * This class is configured to respect knowledge of forbidden and required
- * edges, including knowledge of temporal tiers.
+ *
+ * This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
+ * tiers.
*
* @author josephramsey
* @author rubensanchez
@@ -176,6 +174,7 @@ public final class Fask implements IGraphSearch {
private LeftRight leftRight = LeftRight.RSKEW;
// The graph resulting from search.
private Graph graph;
+ private long seed = -1;
/**
* Constructor.
@@ -201,40 +200,83 @@ public Fask(DataSet dataSet, Score score, IndependenceTest test) {
this.orientationAlpha = 0.01;
}
- private static double cu(double[] x, double[] y, double[] condition) {
- double exy = 0.0;
- int n = 0;
+ public static double faskLeftRightV2(double[] x, double[] y, boolean empirical, double delta) {
+ double sx = skewness(x);
+ double sy = skewness(y);
+ double r = correlation(x, y);
+ double lr = Fask.correxp(x, y, x) - Fask.correxp(x, y, y);
- for (int k = 0; k < x.length; k++) {
- if (condition[k] > 0) {
- exy += x[k] * y[k];
- n++;
- }
+ if (empirical) {
+ lr *= signum(sx) * signum(sy);
}
- return exy / n;
+ if (r < delta) {
+ lr *= -1;
+ }
+
+ return lr;
}
- // Returns E(XY | Z > 0); Z is typically either X or Y.
- private static double E(double[] x, double[] y, double[] z) {
- double exy = 0.0;
- int n = 0;
+ public static double faskLeftRightV1(double[] x, double[] y, boolean empirical, double delta) {
+ double left = Fask.cu(x, y, x) / (sqrt(Fask.cu(x, x, x) * Fask.cu(y, y, x)));
+ double right = Fask.cu(x, y, y) / (sqrt(Fask.cu(x, x, y) * Fask.cu(y, y, y)));
+ double lr = left - right;
- for (int k = 0; k < x.length; k++) {
- if (z[k] > 0) {
- exy += x[k] * y[k];
- n++;
- }
+ double r = correlation(x, y);
+ double sx = skewness(x);
+ double sy = skewness(y);
+
+ if (empirical) {
+ r *= signum(sx) * signum(sy);
}
- return exy / n;
+ lr *= signum(r);
+ if (r < delta) lr *= -1;
+
+ return lr;
}
+ public static double robustSkew(double[] x, double[] y, boolean empirical) {
- // Returns E(XY | Z > 0) / sqrt(E(XX | Z > 0) * E(YY | Z > 0)). Z is typically either X or Y.
- private static double correxp(double[] x, double[] y, double[] z) {
- return Fask.E(x, y, z) / sqrt(Fask.E(x, x, z) * Fask.E(y, y, z));
+ if (empirical) {
+ x = correctSkewness(x, skewness(x));
+ y = correctSkewness(y, skewness(y));
+ }
+
+ double[] lr = new double[x.length];
+
+ for (int i = 0; i < x.length; i++) {
+ lr[i] = g(x[i]) * y[i] - x[i] * g(y[i]);
+ }
+
+ return correlation(x, y) * mean(lr);
+ }
+
+ public static double skew(double[] x, double[] y, boolean empirical) {
+
+ if (empirical) {
+ x = correctSkewness(x, skewness(x));
+ y = correctSkewness(y, skewness(y));
+ }
+
+ double[] lr = new double[x.length];
+
+ for (int i = 0; i < x.length; i++) {
+ lr[i] = x[i] * x[i] * y[i] - x[i] * y[i] * y[i];
+ }
+
+ return correlation(x, y) * mean(lr);
+ }
+
+ public static double g(double x) {
+ return log(cosh(FastMath.max(x, 0)));
+ }
+
+ public static double[] correctSkewness(double[] data, double sk) {
+ double[] data2 = new double[data.length];
+ for (int i = 0; i < data.length; i++) data2[i] = data[i] * signum(sk);
+ return data2;
}
/**
@@ -272,10 +314,12 @@ public Graph search() {
if (this.adjacencyMethod == AdjacencyMethod.BOSS) {
PermutationSearch fas = new PermutationSearch(new Boss(this.score));
+ fas.setSeed(seed);
fas.setKnowledge(this.knowledge);
G = fas.search();
} else if (this.adjacencyMethod == AdjacencyMethod.GRASP) {
Grasp fas = new Grasp(this.score);
+ fas.setSeed(seed);
fas.setDepth(5);
fas.setNonSingularDepth(1);
fas.setUncoveredDepth(1);
@@ -469,8 +513,8 @@ public double[][] getB() {
}
/**
- * Returns a matrix of left-right scores for the search. If lr = getLrScores(), then lr[i][j]
- * is the left right scores leftRight(data[i], data[j]);
+ * Returns a matrix of left-right scores for the search. If lr = getLrScores(), then lr[i][j] is the left right
+ * scores leftRight(data[i], data[j]);
*
* @return This matrix as a double[][] array.
*/
@@ -631,72 +675,41 @@ public double leftRight(double[] x, double[] y) {
throw new IllegalStateException("Left right rule not configured: " + this.leftRight);
}
- public static double faskLeftRightV2(double[] x, double[] y, boolean empirical, double delta) {
- double sx = skewness(x);
- double sy = skewness(y);
- double r = correlation(x, y);
- double lr = Fask.correxp(x, y, x) - Fask.correxp(x, y, y);
- if (empirical) {
- lr *= signum(sx) * signum(sy);
- }
-
- if (r < delta) {
- lr *= -1;
- }
-
- return lr;
- }
-
- public static double faskLeftRightV1(double[] x, double[] y, boolean empirical, double delta) {
- double left = Fask.cu(x, y, x) / (sqrt(Fask.cu(x, x, x) * Fask.cu(y, y, x)));
- double right = Fask.cu(x, y, y) / (sqrt(Fask.cu(x, x, y) * Fask.cu(y, y, y)));
- double lr = left - right;
+ private static double cu(double[] x, double[] y, double[] condition) {
+ double exy = 0.0;
- double r = correlation(x, y);
- double sx = skewness(x);
- double sy = skewness(y);
+ int n = 0;
- if (empirical) {
- r *= signum(sx) * signum(sy);
+ for (int k = 0; k < x.length; k++) {
+ if (condition[k] > 0) {
+ exy += x[k] * y[k];
+ n++;
+ }
}
- lr *= signum(r);
- if (r < delta) lr *= -1;
-
- return lr;
+ return exy / n;
}
- public static double robustSkew(double[] x, double[] y, boolean empirical) {
-
- if (empirical) {
- x = correctSkewness(x, skewness(x));
- y = correctSkewness(y, skewness(y));
- }
-
- double[] lr = new double[x.length];
+ // Returns E(XY | Z > 0); Z is typically either X or Y.
+ private static double E(double[] x, double[] y, double[] z) {
+ double exy = 0.0;
+ int n = 0;
- for (int i = 0; i < x.length; i++) {
- lr[i] = g(x[i]) * y[i] - x[i] * g(y[i]);
+ for (int k = 0; k < x.length; k++) {
+ if (z[k] > 0) {
+ exy += x[k] * y[k];
+ n++;
+ }
}
- return correlation(x, y) * mean(lr);
+ return exy / n;
}
- public static double skew(double[] x, double[] y, boolean empirical) {
-
- if (empirical) {
- x = correctSkewness(x, skewness(x));
- y = correctSkewness(y, skewness(y));
- }
-
- double[] lr = new double[x.length];
-
- for (int i = 0; i < x.length; i++) {
- lr[i] = x[i] * x[i] * y[i] - x[i] * y[i] * y[i];
- }
- return correlation(x, y) * mean(lr);
+ // Returns E(XY | Z > 0) / sqrt(E(XX | Z > 0) * E(YY | Z > 0)). Z is typically either X or Y.
+ private static double correxp(double[] x, double[] y, double[] z) {
+ return Fask.E(x, y, z) / sqrt(Fask.E(x, x, z) * Fask.E(y, y, z));
}
private double tanh(double[] x, double[] y, boolean empirical) {
@@ -715,10 +728,6 @@ private double tanh(double[] x, double[] y, boolean empirical) {
return correlation(x, y) * mean(lr);
}
- public static double g(double x) {
- return log(cosh(FastMath.max(x, 0)));
- }
-
private boolean knowledgeOrients(Node X, Node Y) {
return this.knowledge.isForbidden(Y.getName(), X.getName()) || this.knowledge.isRequired(X.getName(), Y.getName());
}
@@ -727,12 +736,6 @@ private boolean edgeForbiddenByKnowledge(Node X, Node Y) {
return this.knowledge.isForbidden(Y.getName(), X.getName()) && this.knowledge.isForbidden(X.getName(), Y.getName());
}
- public static double[] correctSkewness(double[] data, double sk) {
- double[] data2 = new double[data.length];
- for (int i = 0; i < data.length; i++) data2[i] = data[i] * signum(sk);
- return data2;
- }
-
private boolean twoCycleTest(int i, int j, double[][] D, Graph G0, List<Node> V) {
Node X = V.get(i);
Node Y = V.get(j);
@@ -853,6 +856,10 @@ private void logTwoCycle(NumberFormat nf, List variables, double[][] d, No
);
}
+ public void setSeed(long seed) {
+ this.seed = seed;
+ }
+
/**
* Enumerates the options left-right rules to use for FASK. Options include the FASK left-right rule and three
* left-right rules from the Hyvarinen and Smith pairwise orientation paper: Robust Skew, Skew, and Tanh. In that
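The relocated pairwise statistics above (faskLeftRightV2, robustSkew, skew) are public static and fully defined in this diff, so they can be exercised directly. In the sketch below the toy samples are hypothetical, and the reading "positive score means X -> Y" follows the usual FASK convention and should be treated as an assumption.

import edu.cmu.tetrad.search.Fask;

public class LeftRightExample {
    public static void main(String[] args) {
        // Hypothetical standardized samples of X and Y; real use would pass data columns.
        double[] x = {1.2, -0.3, 2.5, -0.8, 0.4, 3.1, -1.0, 0.9};
        double[] y = {0.9, -0.1, 1.8, -0.5, 0.2, 2.6, -0.7, 0.8};

        double v2 = Fask.faskLeftRightV2(x, y, true, 0.0); // FASK v2 rule, empirical skew correction, delta = 0
        double rskew = Fask.robustSkew(x, y, true);        // Hyvarinen-Smith robust-skew rule

        System.out.println("faskLeftRightV2: " + v2);
        System.out.println("robustSkew:      " + rskew);
    }
}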
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FastIca.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FastIca.java
index 7692028629..9026f5bf9e 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FastIca.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FastIca.java
@@ -31,21 +31,22 @@
import static org.apache.commons.math3.util.FastMath.*;
/**
- * Translates a version of the FastICA algorithm used in R from Fortran
- * into Java for use in Tetrad. This can be used in various algorithms that assume linearity and non-gaussianity, as for
- * example LiNGAM and LiNG-D. There is one difference from the R, in that in R FastICA can operate over complex numbers,
- * whereas here it is restricted to real numbers. A useful reference is this:
- *
- * Oja, E., & Hyvarinen, A. (2000). Independent component analysis:
- * algorithms and applications. Neural networks, 13(4-5), 411-430.
- *
- * The documentation of the R version is as follows, all of which is true of this
- * translation (so far as I know) except for its being in R and its allowing complex values.
+ * Translates a version of the FastICA algorithm used in R from Fortran into Java for use in Tetrad. This can be used in
+ * various algorithms that assume linearity and non-gaussianity, as for example LiNGAM and LiNG-D. There is one
+ * difference from the R, in that in R FastICA can operate over complex numbers, whereas here it is restricted to real
+ * numbers. A useful reference is this:
+ *
+ * Hyvarinen, A., & Oja, E. (2000). Independent component analysis: algorithms and applications. Neural networks,
+ * 13(4-5), 411-430.
+ *
+ * The documentation of the R version is as follows, all of which is true of this translation (so far as I know) except
+ * for its being in R and its allowing complex values.
*
* Description:
*
* This is an R and C code implementation of the FastICA algorithm of Aapo Hyvarinen et al. (URL:
- * http://www.cis.hut.fi/aapo/) to perform Independent Component Analysis (ICA) and Projection Pursuit.
+ * http://www.cis.hut.fi/aapo/) to perform Independent Component Analysis
+ * (ICA) and Projection Pursuit.
*
* Usage:
*
@@ -98,8 +99,7 @@
* First, the data is centered by subtracting the mean of each column of the data matrix X.
*
* The data matrix is then `whitened' by projecting the data onto its principal component directions, i.e., X -> XK
- * where K is a pre-whitening matrix.
- * The user can specify the number of components.
+ * where K is a pre-whitening matrix. The user can specify the number of components.
*
* The ICA algorithm then estimates a matrix W s.t XKW = S . W is chosen to maximize the neg-entropy approximation under
* the constraints that W is an orthonormal matrix. This constraint ensures that the estimated components are
@@ -120,86 +120,57 @@
*
* A. Hyvarinen and E. Oja (2000) Independent Component Analysis: Algorithms and Applications, _Neural Networks_,
* *13(4-5)*:411-430
- *
*
* @author josephramsey
*/
public class FastIca {
- /**
- * The algorithm type where all components are extracted simultaneously.
- */
+ // The algorithm type where all components are extracted simultaneously.
public static int PARALLEL;
- /**
- * The algorithm type where the components are extracted one at a time.
- */
+ // The algorithm type where the components are extracted one at a time.
public static int DEFLATION = 1;
- /**
- * One of the function types that can be used to approximate negative entropy.
- */
+ // One of the function types that can be used to approximate negative entropy.
public static int LOGCOSH = 2;
- /**
- * The other function type that can be used to approximate negative entropy.
- */
+ // The other function type that can be used to approximate negative entropy.
public static int EXP = 3;
- /**
- * A data matrix with n rows representing observations and p columns representing variables.
- */
+ // A data matrix with n rows representing observations and p columns representing variables.
private final Matrix X;
- /**
- * The number of independent components to be extracted.
- */
+ // The number of independent components to be extracted.
private int numComponents;
- /**
- * If algorithmType == PARALLEL, the components are extracted simultaneously (the default).
- * if algorithmType == DEFLATION, the components are extracted one at a time.
- */
+ // If algorithmType == PARALLEL, the components are extracted simultaneously (the default). If algorithmType ==
+ // DEFLATION, the components are extracted one at a time.
private int algorithmType = FastIca.PARALLEL;
- /**
- * The function type to be used, either LOGCOSH or EXP.
- */
+ // The function type to be used, either LOGCOSH or EXP.
private int function = FastIca.LOGCOSH;
- /**
- * Constant in range [1, 2] used in approximation to neg-entropy when 'fun == "logcosh". Default = 1.0.
- */
+ // Constant in range [1, 2] used in approximation to neg-entropy when 'fun == "logcosh". Default = 1.0.
private double alpha = 1.1;
- /**
- * A logical value indicating whether rows of the data matrix 'X' should be standardized beforehand. Default =
- * false.
- */
+ // A logical value indicating whether rows of the data matrix 'X' should be standardized beforehand. Default =
+ // false.
private boolean rowNorm;
- /**
- * Maximum number of iterations to perform. Default = 200.
- */
+ // Maximum number of iterations to perform. Default = 200.
private int maxIterations = 200;
- /**
- * A positive scalar giving the tolerance at which the un-mixing matrix is considered to have converged. Default =
- * 1e-04.
- */
+ // A positive scalar giving the tolerance at which the un-mixing matrix is considered to have converged. Default =
+ // 1e-04.
private double tolerance = 1e-04;
- /**
- * A logical value indicating the level of output as the algorithm runs. Default = false.
- */
+ // A logical value indicating the level of output as the algorithm runs. Default = false.
private boolean verbose;
- /**
- * Initial un-mixing matrix of dimension (n.comp,n.comp). If null (default), then a matrix of normal r.v.'s is used.
- */
+ // Initial un-mixing matrix of dimension (n.comp,n.comp). If null (default), then a matrix of normal r.v.'s is
+ // used.
private Matrix wInit;
-
/**
* Constructs an instance of the Fast ICA algorithm, taking as arguments the two arguments that cannot be defaulted:
* the data matrix itself and the number of components to be extracted.
@@ -211,10 +182,9 @@ public FastIca(Matrix X, int numComponents) {
this.numComponents = numComponents;
}
-
/**
- * If algorithmType == PARALLEL, the components are extracted simultaneously (the default).
- * if algorithmType == DEFLATION, the components are extracted one at a time.
+ * If algorithmType == PARALLEL, the components are extracted simultaneously (the default). If algorithmType ==
+ * DEFLATION, the components are extracted one at a time.
*
* @param algorithmType This type.
*/
@@ -297,8 +267,8 @@ public void setVerbose(boolean verbose) {
}
/**
- * Sets the initial un-mixing matrix of dimension (n.comp,n.comp).
- * If NULL (default), then a random matrix of normal r.v.'s is used.
+ * Sets the initial un-mixing matrix of dimension (n.comp,n.comp). If NULL (default), then a random matrix of normal
+ * r.v.'s is used.
*
* @param wInit This matrix.
*/
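A hedged usage sketch of FastIca: the constructor and the verbose setter are shown in this diff, the remaining setter names follow the fields annotated above, and the method that runs the algorithm (findComponents) together with its FastIca.IcaResult return type are assumptions.

import edu.cmu.tetrad.search.FastIca;
import edu.cmu.tetrad.util.Matrix;

public class FastIcaExample {
    // Centers and whitens X internally, then estimates three independent components.
    public static void run(Matrix X) {
        FastIca ica = new FastIca(X, 3);        // constructor shown above
        ica.setAlgorithmType(FastIca.PARALLEL); // setter name assumed from the field above
        ica.setVerbose(true);                   // setter shown above
        FastIca.IcaResult result = ica.findComponents(); // method and result type assumed
        System.out.println(result);
    }
}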
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java
index fee0ed8e8a..fc16f74b6b 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java
@@ -38,28 +38,25 @@
import java.util.Set;
/**
- * Implements the Fast Causal Inference (FCI) algorithm due to Peter Spirtes, which addressed
- * the case where latent common causes cannot be assumed not to exist with respect to the data set being analyzed. That
- * is, it is assumed that there may be variables that are not included in the data that nonetheless may be causes of two
- * or more variables that are included in data.
- *
- * Two alternatives are provided for doing the final orientation step, one due to Peter Spirtes,
- * which is arrow complete, and another due to Jiji Zhang, which is arrow and tail complete.
- *
- * This algorithm, with the Spirtes final orientation rules, was given in an earlier version of
- * this book:
- *
- * Spirtes, P., Glymour, C. N., Scheines, R., & Heckerman, D. (2000). Causation,
- * prediction, and search. MIT press.
- *
- * The algorithm with the Zhang final orientation rules was given in this reference:
- *
- * Zhang, J. (2008). On the completeness of orientation rules for causal discovery in the presence
- * of latent confounders and selection bias. Artificial Intelligence, 172(16-17), 1873-1896.
- *
- *
- * This class is configured to respect knowledge of forbidden and required
- * edges, including knowledge of temporal tiers.
+ * Implements the Fast Causal Inference (FCI) algorithm due to Peter Spirtes, which addressed the case where latent
+ * common causes cannot be assumed not to exist with respect to the data set being analyzed. That is, it is assumed that
+ * there may be variables that are not included in the data that nonetheless may be causes of two or more variables that
+ * are included in data.
+ *
+ * Two alternatives are provided for doing the final orientation step, one due to Peter Spirtes, which is arrow
+ * complete, and another due to Jiji Zhang, which is arrow and tail complete.
+ *
+ * This algorithm, with the Spirtes final orientation rules, was given in an earlier version of this book:
+ *
+ * Spirtes, P., Glymour, C. N., Scheines, R., & Heckerman, D. (2000). Causation, prediction, and search. MIT press.
+ *
+ * The algorithm with the Zhang final orientation rules was given in this reference:
+ *
+ * Zhang, J. (2008). On the completeness of orientation rules for causal discovery in the presence of latent confounders
+ * and selection bias. Artificial Intelligence, 172(16-17), 1873-1896.
+ *
+ * This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
+ * tiers.
*
* @author peterspirtes
* @author clarkglymour
@@ -72,22 +69,35 @@
* @see Knowledge
*/
public final class Fci implements IGraphSearch {
+ // The variables to search over.
private final List<Node> variables = new ArrayList<>();
+ // The independence test to use.
private final IndependenceTest independenceTest;
+ // The logger.
private final TetradLogger logger = TetradLogger.getInstance();
+ // The sepsets from FAS.
private SepsetMap sepsets;
+ // The background knowledge.
private Knowledge knowledge = new Knowledge();
+ // Whether the Zhang complete rule set should be used.
private boolean completeRuleSetUsed = true;
+ // Whether the possible msep step should be done.
private boolean possibleMsepSearchDone = true;
+ // The maximum length of any discriminating path.
private int maxPathLength = -1;
+ // The depth of search.
private int depth = -1;
+ // The elapsed time of search.
private long elapsedTime;
+ // Whether verbose output should be printed.
private boolean verbose;
+ // The PC heuristic type to use.
private PcCommon.PcHeuristicType heuristic = PcCommon.PcHeuristicType.NONE;
+ // Whether the stable options should be used.
private boolean stable = true;
+ // Whether the discriminating path rule should be used.
private boolean doDiscriminatingPathRule = true;
-
/**
* Constructor.
*
@@ -136,7 +146,11 @@ public Fci(IndependenceTest independenceTest, List<Node> searchVars) {
this.variables.removeAll(remVars);
}
-
+ /**
+ * Performs the search.
+ *
+ * @return The graph.
+ */
public Graph search() {
long start = MillisecondTimes.timeMillis();
@@ -246,8 +260,8 @@ public void setKnowledge(Knowledge knowledge) {
}
/**
- * Sets whether the Zhang complete rule set should be used; false if only R1-R4 (the rule set of the original
- * FCI) should be used. False by default.
+ * Sets whether the Zhang complete rule set should be used; false if only R1-R4 (the rule set of the original FCI)
+ * should be used. False by default.
*
* @param completeRuleSetUsed True for the complete Zhang rule set.
*/
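Putting the Fci options documented above together, the following is an illustrative sketch. The setters correspond to fields or Javadoc appearing in this diff; the IndTestFisherZ class and its package are assumptions, since the test classes have moved between Tetrad releases.

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.search.Fci;
import edu.cmu.tetrad.search.test.IndTestFisherZ;   // assumed location of the Fisher-Z test

public class FciSketch {
    public static Graph runFci(DataSet dataSet, Knowledge knowledge) {
        Fci fci = new Fci(new IndTestFisherZ(dataSet, 0.01));
        fci.setKnowledge(knowledge);        // forbidden/required edges, including temporal tiers
        fci.setCompleteRuleSetUsed(true);   // Zhang's arrow- and tail-complete orientation rules
        fci.setDepth(-1);                   // no limit on conditioning-set size
        fci.setVerbose(false);
        return fci.search();                // returns a PAG
    }
}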
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java
index 362d9906b7..16fd4af3e8 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java
@@ -29,6 +29,7 @@
import edu.cmu.tetrad.search.utils.SepsetMap;
import edu.cmu.tetrad.search.utils.SepsetsSet;
import edu.cmu.tetrad.util.*;
+import org.jetbrains.annotations.NotNull;
import java.util.ArrayList;
import java.util.List;
@@ -38,25 +39,24 @@
import java.util.concurrent.RecursiveTask;
/**
- * Modifies FCI to do orientation of unshielded colliders (X*-*Y*-*Z with X and Z
- * not adjacent) using the max-P rule (see the PC-Max algorithm). This reference is relevant:
- *
- * Raghu, V. K., Zhao, W., Pu, J., Leader, J. K., Wang, R., Herman, J., ... &
- * Wilson, D. O. (2019). Feasibility of lung cancer prediction from low-dose CT scan and smoking factors using causal
- * models. Thorax, 74(7), 643-649.
- *
- * Max-P triple orientation is a method for orienting unshielded triples
- * X*=-*Y*-*Z as one of the following: (a) Collider, X->Y<-Z, or (b) Noncollider, X-->Y-->Z, or X<-Y<-Z, or X<-Y->Z. One
- * does this by conditioning on subsets of adj(X) or adj(Z). One first checks conditional independence of X and Z
- * conditional on each of these subsets, and lists the p-values for each test. Then, one chooses the conditioning set
- * out of all of these that maximizes the p-value. If this conditioning set contains Y, then the triple is judged to be
- * a noncollider; otherwise, it is judged to be a collider.
- *
- * All unshielded triples in the graph given by FAS are judged as colliders
- * or non-colliders and the colliders oriented. Then the final FCI orientation rules are applied, as in FCI.
- *
- * This class is configured to respect knowledge of forbidden and required
- * edges, including knowledge of temporal tiers.
+ * Modifies FCI to do orientation of unshielded colliders (X*-*Y*-*Z with X and Z not adjacent) using the max-P rule
+ * (see the PC-Max algorithm). This reference is relevant:
+ *
+ * Raghu, V. K., Zhao, W., Pu, J., Leader, J. K., Wang, R., Herman, J., ... & Wilson, D. O. (2019). Feasibility of
+ * lung cancer prediction from low-dose CT scan and smoking factors using causal models. Thorax, 74(7), 643-649.
+ *
+ * Max-P triple orientation is a method for orienting unshielded triples X*-*Y*-*Z as one of the following: (a)
+ * Collider, X->Y<-Z, or (b) Noncollider, X-->Y-->Z, or X<-Y<-Z, or X<-Y->Z. One does this by
+ * conditioning on subsets of adj(X) or adj(Z). One first checks conditional independence of X and Z conditional on each
+ * of these subsets, and lists the p-values for each test. Then, one chooses the conditioning set out of all of these
+ * that maximizes the p-value. If this conditioning set contains Y, then the triple is judged to be a noncollider;
+ * otherwise, it is judged to be a collider.
+ *
+ * All unshielded triples in the graph given by FAS are judged as colliders or non-colliders and the colliders oriented.
+ * Then the final FCI orientation rules are applied, as in FCI.
+ *
+ * This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
+ * tiers.
*
* @author josephramsey
* @see Fci
@@ -65,21 +65,33 @@
* @see Knowledge
*/
public final class FciMax implements IGraphSearch {
+ // The independence test.
private final IndependenceTest independenceTest;
+ // The logger.
private final TetradLogger logger = TetradLogger.getInstance();
+ // The sepsets from the FAS search.
private SepsetMap sepsets;
+ // The background knowledge.
private Knowledge knowledge = new Knowledge();
+ // The elapsed time of search.
private long elapsedTime;
+ // The PC heuristic type used in search.
private PcCommon.PcHeuristicType pcHeuristicType = PcCommon.PcHeuristicType.NONE;
+ // Whether the stable option will be used for search.
private boolean stable = false;
+ // Whether the Zhang complete rule set will be used in search.
private boolean completeRuleSetUsed = true;
+ // Whether the discriminating path rule will be used in search.
private boolean doDiscriminatingPathRule = false;
+ // Whether the possible msep search step will be done.
private boolean possibleMsepSearchDone = true;
+ // The maximum length of any discriminating path, or -1 if unlimited.
private int maxPathLength = -1;
+ // The maximum number of variables conditioned on in any test.
private int depth = -1;
+ // Whether verbose output should be printed.
private boolean verbose = false;
-
/**
* Constructor.
*/
@@ -91,7 +103,6 @@ public FciMax(IndependenceTest independenceTest) {
this.independenceTest = independenceTest;
}
-
/**
* Performs the search and returns the PAG.
*
@@ -129,14 +140,7 @@ public Graph search() {
// Step CI C (Zhang's step F3.)
- FciOrient fciOrient = new FciOrient(new SepsetsSet(this.sepsets, this.independenceTest));
-
- fciOrient.setCompleteRuleSetUsed(this.completeRuleSetUsed);
- fciOrient.setMaxPathLength(this.maxPathLength);
- fciOrient.setDoDiscriminatingPathColliderRule(this.doDiscriminatingPathRule);
- fciOrient.setDoDiscriminatingPathTailRule(this.doDiscriminatingPathRule);
- fciOrient.setVerbose(this.verbose);
- fciOrient.setKnowledge(this.knowledge);
+ FciOrient fciOrient = getFciOrient();
fciOrient.fciOrientbk(this.knowledge, graph, graph.getNodes());
addColliders(graph);
@@ -281,6 +285,19 @@ public void setDoDiscriminatingPathRule(boolean doDiscriminatingPathRule) {
this.doDiscriminatingPathRule = doDiscriminatingPathRule;
}
+ @NotNull
+ private FciOrient getFciOrient() {
+ FciOrient fciOrient = new FciOrient(new SepsetsSet(this.sepsets, this.independenceTest));
+
+ fciOrient.setCompleteRuleSetUsed(this.completeRuleSetUsed);
+ fciOrient.setMaxPathLength(this.maxPathLength);
+ fciOrient.setDoDiscriminatingPathColliderRule(this.doDiscriminatingPathRule);
+ fciOrient.setDoDiscriminatingPathTailRule(this.doDiscriminatingPathRule);
+ fciOrient.setVerbose(this.verbose);
+ fciOrient.setKnowledge(this.knowledge);
+ return fciOrient;
+ }
+
private void addColliders(Graph graph) {
Map<Triple, Double> scores = new ConcurrentHashMap<>();
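The max-P rule described in the FciMax Javadoc above reduces to a small decision procedure. The sketch below illustrates that rule in isolation, using a hypothetical PValue callback in place of a real conditional-independence test; it is not FciMax's internal code. Per the ConcurrentHashMap and RecursiveTask usage visible in the hunks above, the real implementation scores triples in parallel.

import java.util.List;
import java.util.Set;

public class MaxPSketch {

    // Hypothetical stand-in for an independence test: p-value of X _||_ Z | cond.
    interface PValue {
        double of(String x, String z, Set<String> cond);
    }

    // Returns true if the unshielded triple <x, y, z> should be oriented as a collider x -> y <- z.
    static boolean isCollider(String x, String y, String z,
                              List<Set<String>> candidateConditioningSets, PValue test) {
        double bestP = Double.NEGATIVE_INFINITY;
        Set<String> bestCond = null;

        // Score every candidate conditioning set (subsets of adj(x) or adj(z)) by its p-value...
        for (Set<String> cond : candidateConditioningSets) {
            double p = test.of(x, z, cond);
            if (p > bestP) {
                bestP = p;
                bestCond = cond;
            }
        }

        // ...and judge the triple by whether the maximizing set contains y:
        // y absent -> collider; y present -> noncollider.
        return bestCond != null && !bestCond.contains(y);
    }
}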
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java
index 094758eeb9..e3324f03e6 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java
@@ -43,37 +43,33 @@
import static org.apache.commons.math3.util.FastMath.min;
/**
- * Implements the Fast Greedy Equivalence Search (FGES) algorithm. This is
- * an implementation of the Greedy Equivalence Search algorithm, originally due to Chris Meek but developed
- * significantly by Max Chickering. FGES uses with some optimizations that allow it to scale accurately to thousands of
- * variables accurately for the sparse case. The reference for FGES is this:
- *
- * Ramsey, J., Glymour, M., Sanchez-Romero, R., & Glymour, C. (2017).
- * A million variables and more: the fast greedy equivalence search algorithm for learning high-dimensional graphical
- * causal models, with an application to functional magnetic resonance images. International journal of data science and
- * analytics, 3, 121-129.
- *
- * The reference for Chickering's GES is this:
- *
- * Chickering (2002) "Optimal structure identification with greedy search"
- * Journal of Machine Learning Research.
- *
- * FGES works for the continuous case, the discrete case, and the mixed
- * continuous/discrete case, so long as a BIC score is available for the type of data in question.
- *
- * To speed things up, it has been assumed that variables X and Y with zero
- * correlation do not correspond to edges in the graph. This is a restricted form of the heuristic speedup assumption,
- * something GES does not assume. This heuristic speedup assumption needs to be explicitly turned on using
- * setHeuristicSpeedup(true).
- *
- * Also, edges to be added or remove from the graph in the forward or
- * backward phase, respectively are cached, together with the ancillary information needed to do the additions or
- * removals, to reduce rescoring.
- *
- * A number of other optimizations were also. See code for details.
- *
- * This class is configured to respect knowledge of forbidden and required
- * edges, including knowledge of temporal tiers.
+ * Implements the Fast Greedy Equivalence Search (FGES) algorithm. This is an implementation of the Greedy Equivalence
+ * Search algorithm, originally due to Chris Meek but developed significantly by Max Chickering. FGES includes some
+ * optimizations that allow it to scale accurately to thousands of variables in the sparse case. The
+ * reference for FGES is this:
+ *
+ * Ramsey, J., Glymour, M., Sanchez-Romero, R., & Glymour, C. (2017). A million variables and more: the fast greedy
+ * equivalence search algorithm for learning high-dimensional graphical causal models, with an application to functional
+ * magnetic resonance images. International journal of data science and analytics, 3, 121-129.
+ *
+ * The reference for Chickering's GES is this:
+ *
+ * Chickering (2002) "Optimal structure identification with greedy search" Journal of Machine Learning Research.
+ *
+ * FGES works for the continuous case, the discrete case, and the mixed continuous/discrete case, so long as a BIC score
+ * is available for the type of data in question.
+ *
+ * To speed things up, it has been assumed that variables X and Y with zero correlation do not correspond to edges in
+ * the graph. This is a restricted form of the heuristic speedup assumption, something GES does not assume. This
+ * heuristic speedup assumption needs to be explicitly turned on using setHeuristicSpeedup(true).
+ *
+ * Also, edges to be added or removed from the graph in the forward or backward phase, respectively, are cached, together
+ * with the ancillary information needed to do the additions or removals, to reduce rescoring.
+ *
+ * A number of other optimizations were also implemented. See code for details.
+ *
+ * This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
+ * tiers.
*
* @author Ricardo Silva
* @author josephramsey
@@ -83,17 +79,16 @@
* @see Knowledge
*/
public final class Fges implements IGraphSearch, DagScorer {
+ // Used to find semidirected paths for cycle checking.
private final Set<Node> emptySet = new HashSet<>();
+ // Used to find semidirected paths for cycle checking.
private final int[] count = new int[1];
+ // Used to find semidirected paths for cycle checking.
private final int depth = 10000;
- /**
- * The logger for this class. The config needs to be set.
- */
+ // The logger for this class. The config needs to be set.
private final TetradLogger logger = TetradLogger.getInstance();
- /**
- * The top n graphs found by the algorithm, where n is numPatternsToStore.
- */
+ // The top n graphs found by the algorithm, where n is numPatternsToStore.
private final LinkedList<ScoredGraph> topGraphs = new LinkedList<>();
// Potential arrows sorted by bump high to low. The first one is a candidate for adding to the graph.
private final SortedSet<Arrow> sortedArrows = new ConcurrentSkipListSet<>();
@@ -101,63 +96,41 @@ public final class Fges implements IGraphSearch, DagScorer {
private final Map<Edge, ArrowConfig> arrowsMap = new ConcurrentHashMap<>();
// private final Map arrowsMapBackward = new ConcurrentHashMap<>();
private boolean faithfulnessAssumed = false;
- /**
- * Specification of forbidden and required edges.
- */
+ // Specification of forbidden and required edges.
private Knowledge knowledge = new Knowledge();
- /**
- * List of variables in the data set, in order.
- */
+ // List of variables in the data set, in order.
private List<Node> variables;
- /**
- * An initial graph to start from.
- */
+ // An initial graph to start from.
private Graph initialGraph;
- /**
- * If non-null, edges not adjacent in this graph will not be added.
- */
+ // If non-null, edges not adjacent in this graph will not be added.
private Graph boundGraph = null;
- /**
- * Elapsed time of the most recent search.
- */
+ // Elapsed time of the most recent search.
private long elapsedTime;
- /**
- * The totalScore for discrete searches.
- */
+ // The totalScore for discrete searches.
private Score score;
- /**
- * True if verbose output should be printed.
- */
+ // True if verbose output should be printed.
private boolean verbose = false;
private boolean meekVerbose = false;
// Map from variables to their column indices in the data set.
private ConcurrentMap<Node, Integer> hashIndices;
// A graph where X--Y means that X and Y have non-zero total effect on one another.
private Graph effectEdgesGraph;
-
// Where printed output is sent.
private PrintStream out = System.out;
-
// The graph being constructed.
private Graph graph;
-
// Arrows with the same totalScore are stored in this list to distinguish their order in sortedArrows.
// The ordering doesn't matter; it just has to be transitive.
private int arrowIndex = 0;
-
// The score of the model.
private double modelScore;
-
// Internal.
private Mode mode = Mode.heuristicSpeedup;
-
// Bounds the degree of the graph.
private int maxDegree = -1;
-
// True if the first step of adding an edge to an empty graph should be scored in both directions
// for each edge with the maximum score chosen.
private boolean symmetricFirstStep = false;
-
// True, if FGES should run in a single thread, no if parallelized.
private boolean parallelized = false;
@@ -195,8 +168,8 @@ private static Node traverseSemiDirected(Node node, Edge edge) {
}
/**
- * Greedy equivalence search: Start from the empty graph, add edges till the model is significant.
- * Then start deleting edges till a minimum is achieved.
+ * Greedy equivalence search: Start from the empty graph, add edges till the model is significant. Then start
+ * deleting edges till a minimum is achieved.
*
* @return the resulting Pattern.
*/
@@ -240,7 +213,7 @@ public Graph search() {
this.logger.forceLogMessage("Elapsed time = " + (elapsedTime) / 1000. + " s");
}
- this.modelScore = scoreDag(GraphTransforms.dagFromCPDAG(graph, null), true);
+ this.modelScore = scoreDag(GraphTransforms.dagFromCpdag(graph, null), true);
return graph;
}
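For orientation, here is a short sketch of driving Fges with a SEM BIC score and then extracting a DAG from the returned CPDAG via the renamed GraphTransforms.dagFromCpdag seen in the hunk above. The SemBicScore constructor, its package, and the setPenaltyDiscount call are assumptions whose signatures have varied across Tetrad releases; the Fges calls are the ones visible in this diff.

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphTransforms;        // package assumed
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.score.SemBicScore;

public class FgesSketch {
    public static Graph runFges(DataSet dataSet) {
        SemBicScore score = new SemBicScore(dataSet, true);   // assumed (dataSet, precomputeCovariances) constructor
        score.setPenaltyDiscount(2.0);                        // assumed setter; stiffer penalty for sparser graphs

        Fges fges = new Fges(score);
        fges.setVerbose(true);
        Graph cpdag = fges.search();

        // The diff renames dagFromCPDAG to dagFromCpdag; this call matches the new name.
        return GraphTransforms.dagFromCpdag(cpdag, null);
    }
}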
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesMb.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesMb.java
index b78b44f668..cf3940d42c 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesMb.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesMb.java
@@ -43,37 +43,33 @@
import static org.apache.commons.math3.util.FastMath.min;
/**
- * Implements the Fast Greedy Equivalence Search (FGES) algorithm. This is
- * an implementation of the Greedy Equivalence Search algorithm, originally due to Chris Meek but developed
- * significantly by Max Chickering. FGES uses with some optimizations that allow it to scale accurately to thousands of
- * variables accurately for the sparse case. The reference for FGES is this:
- *
- * Ramsey, J., Glymour, M., Sanchez-Romero, R., & Glymour, C. (2017).
- * A million variables and more: the fast greedy equivalence search algorithm for learning high-dimensional graphical
- * causal models, with an application to functional magnetic resonance images. International journal of data science and
- * analytics, 3, 121-129.
- *
- * The reference for Chickering's GES is this:
- *
- * Chickering (2002) "Optimal structure identification with greedy search"
- * Journal of Machine Learning Research.
- *
- * FGES works for the continuous case, the discrete case, and the mixed
- * continuous/discrete case, so long as a BIC score is available for the type of data in question.
- *
- * To speed things up, it has been assumed that variables X and Y with zero
- * correlation do not correspond to edges in the graph. This is a restricted form of the heuristic speedup assumption,
- * something GES does not assume. This heuristic speedup assumption needs to be explicitly turned on using
- * setHeuristicSpeedup(true).
- *
- * Also, edges to be added or remove from the graph in the forward or
- * backward phase, respectively are cached, together with the ancillary information needed to do the additions or
- * removals, to reduce rescoring.
- *
- * A number of other optimizations were also. See code for details.
- *
- * This class is configured to respect knowledge of forbidden and required
- * edges, including knowledge of temporal tiers.
+ * Implements the Fast Greedy Equivalence Search (FGES) algorithm. This is an implementation of the Greedy Equivalence
+ * Search algorithm, originally due to Chris Meek but developed significantly by Max Chickering. FGES includes some
+ * optimizations that allow it to scale accurately to thousands of variables in the sparse case. The
+ * reference for FGES is this:
+ *
+ * Ramsey, J., Glymour, M., Sanchez-Romero, R., & Glymour, C. (2017). A million variables and more: the fast greedy
+ * equivalence search algorithm for learning high-dimensional graphical causal models, with an application to functional
+ * magnetic resonance images. International journal of data science and analytics, 3, 121-129.
+ *
+ * The reference for Chickering's GES is this:
+ *
+ * Chickering (2002) "Optimal structure identification with greedy search" Journal of Machine Learning Research.
+ *
+ * FGES works for the continuous case, the discrete case, and the mixed continuous/discrete case, so long as a BIC score
+ * is available for the type of data in question.
+ *
+ * To speed things up, it has been assumed that variables X and Y with zero correlation do not correspond to edges in
+ * the graph. This is a restricted form of the heuristic speedup assumption, something GES does not assume. This
+ * heuristic speedup assumption needs to be explicitly turned on using setHeuristicSpeedup(true).
+ *
+ * Also, edges to be added or removed from the graph in the forward or backward phase, respectively, are cached, together
+ * with the ancillary information needed to do the additions or removals, to reduce rescoring.
+ *
+ * A number of other optimizations were also implemented. See code for details.
+ *
+ * This class is configured to respect knowledge of forbidden and required edges, including knowledge of temporal
+ * tiers.
*
* @author Ricardo Silva
* @author josephramsey
@@ -83,45 +79,42 @@
* @see Knowledge
*/
public final class FgesMb implements DagScorer {
- public enum TrimmingStyle {
- NONE, ADJACENT_TO_TARGETS, MARKOV_BLANKET_GRAPH, SEMIDIRECTED_PATHS_TO_TARGETS
- }
-
+ //===internal===//
+ private final Set<Node> emptySet = new HashSet<>();
+ private final int[] count = new int[1];
+ private final int depth = 10000;
+ //The top n graphs found by the algorithm, where n is numPatternsToStore.
+ private final LinkedList<ScoredGraph> topGraphs = new LinkedList<>();
+ // Potential arrows sorted by bump high to low. The first one is a candidate for adding to the graph.
+ private final SortedSet<Arrow> sortedArrows = new ConcurrentSkipListSet<>();
+ private final TetradLogger logger = TetradLogger.getInstance();
+ private final Map arrowsMap = new ConcurrentHashMap<>();
+ List<Node> targets = new ArrayList<>();
// The number of times the forward phase is iterated to expand to new adjacencies.
private int numExpansions = 2;
-
// The style of trimming to use.
private int trimmingStyle = 3; // default MB trimming.
-
// Bounds the degree of the graph.
private int maxDegree = -1;
-
// Whether one-edge faithfulness is assumed (less general but faster).
private boolean faithfulnessAssumed = false;
-
// The knowledge to use in the search.
private Knowledge knowledge = new Knowledge();
-
// True, if FGES should run in a single thread, no if parallelized.
private boolean parallelized = false;
-
- //===internal===//
- private final Set