diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/Tetrad.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/Tetrad.java index 2381b03061..c28b289fc6 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/Tetrad.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/Tetrad.java @@ -70,6 +70,12 @@ public final class Tetrad implements PropertyChangeListener { "Tetrad " + Version.currentViewableVersion() .toString(); + + /** + * Skip latest version checking + */ + private static boolean skipLatest; + //==============================CONSTRUCTORS===========================// public Tetrad() { @@ -98,7 +104,7 @@ public void propertyChange(final PropertyChangeEvent e) { * contents and all of the images which Tetrad IV uses, all in their proper * relative directories.

* - * @param argv moves line arguments (none for now). + * @param argv --skip-latest argument will skip checking for latest version. */ public static void main(final String[] argv) { setLookAndFeel(); @@ -106,6 +112,14 @@ public static void main(final String[] argv) { // This is needed to get numbers to be parsed and rendered uniformly, especially in the interface. Locale.setDefault(Locale.US); + // Check if we should skip checking for latest version + if (argv.length > 0 && argv[0] != null && argv[0].compareToIgnoreCase("--skip-latest") == 0) { + skipLatest = true; + } else { + skipLatest = false; + } + + new Tetrad().launchFrame(); } @@ -174,7 +188,7 @@ public Dimension getPreferredSize() { // this.frame.setMinimumSize(Toolkit.getDefaultToolkit().getScreenSize()); // this.frame.setMaximumSize(Toolkit.getDefaultToolkit().getScreenSize()); - SplashScreen.show(getFrame(), "Loading Tetrad...", 1000); + SplashScreen.show(getFrame(), "Loading Tetrad...", 1000, skipLatest); getFrame().setContentPane(getDesktop()); getFrame().pack(); @@ -231,6 +245,7 @@ private JFrame getFrame() { private TetradDesktop getDesktop() { return desktop; } + } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/ConstructTemplateAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/ConstructTemplateAction.java index 0a5979818b..6cdbd288f8 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/ConstructTemplateAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/ConstructTemplateAction.java @@ -424,7 +424,7 @@ private Node addNode(String nodeType, String nodeName, int centerX, return node; } - private void addEdge(String nodeName1, String nodeName2) { + public void addEdge(String nodeName1, String nodeName2) { // Retrieve the nodes from the session wrapper. Node node1 = getSessionWrapper().getNode(nodeName1); @@ -451,6 +451,8 @@ private void addEdge(String nodeName1, String nodeName2) { // Add the edge. getSessionWrapper().addEdge(edge); + getSessionWorkbench().revalidate(); + getSessionWorkbench().repaint(); } private static SessionNodeWrapper getNewModelNode(String nextButtonType, diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorNode.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorNode.java index 2b582cc2e3..9b840fc7f3 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorNode.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorNode.java @@ -21,8 +21,7 @@ package edu.cmu.tetradapp.app; -import edu.cmu.tetrad.graph.Edge; -import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.session.*; import edu.cmu.tetrad.util.*; import edu.cmu.tetradapp.editor.EditorWindow; @@ -38,6 +37,8 @@ import javax.swing.border.TitledBorder; import javax.swing.event.InternalFrameAdapter; import javax.swing.event.InternalFrameEvent; +import javax.swing.event.MenuDragMouseEvent; +import javax.swing.event.MenuDragMouseListener; import java.awt.*; import java.awt.Point; import java.awt.event.*; @@ -105,8 +106,6 @@ public SessionEditorNode(SessionNodeWrapper modelNode, SimulationStudy simulatio this.simulationStudy = simulationStudy; displayComp.setName(modelNode.getSessionName()); - - System.out.println("modelNode.getSessionName(): " + modelNode.getButtonType()); if (displayComp instanceof NoteDisplayComp) { createParamObjects(this); @@ -127,7 +126,6 @@ public void mousePressed(MouseEvent e) { } }); } else { - // Any node type except notes setDisplayComp(displayComp); setLayout(new BorderLayout()); add((JComponent) getSessionDisplayComp(), BorderLayout.CENTER); @@ -266,8 +264,6 @@ public void internalFrameClosing(InternalFrameEvent e) { DesktopController.getInstance().addEditorWindow(editorWindow, JLayeredPane.PALETTE_LAYER); editorWindow.pack(); - editorWindow.setResizable(true); - editorWindow.setMaximum(true); editorWindow.setVisible(true); spawnedEditor = editorWindow; @@ -391,6 +387,22 @@ public void mousePressed(MouseEvent e) { } }); +// sessionEditorNode.addMouseMotionListener(new MouseMotionAdapter() { +// public void mouseMoved(MouseEvent e) { +// Point p = e.getPoint(); +// if (p.getX() > 40 && p.getY() > 40) { +// ToolTipManager toolTipManager = +// ToolTipManager.sharedInstance(); +// toolTipManager.setInitialDelay(750); +// JPopupMenu popup = sessionEditorNode.getPopup(); +// +// if (!popup.isShowing()) { +// popup.show(sessionEditorNode, e.getX(), e.getY()); +// } +// } +// } +// }); + sessionEditorNode.addComponentListener(new ComponentAdapter() { public void componentMoved(ComponentEvent e) { sessionEditorNode.getSimulationStudy().getSession().setSessionChanged(true); @@ -564,11 +576,17 @@ public void actionPerformed(ActionEvent e) { return popup; } + JPopupMenu popup = null; + /** * Creates the popup for the node. */ private JPopupMenu getPopup() { - JPopupMenu popup = new JPopupMenu(); + if (popup != null && popup.isShowing()) { + return popup; + } + + popup = new JPopupMenu(); JMenuItem createModel = new JMenuItem("Create Model"); createModel.setToolTipText("Creates a new model for this node" + @@ -739,52 +757,63 @@ public void actionPerformed(ActionEvent e) { } }); -// JMenuItem editSimulationParameters = -// new JMenuItem("Edit Parameters..."); -// editSimulationParameters.setToolTipText(""); + JMenuItem editSimulationParameters = + new JMenuItem("Edit Parameters..."); + editSimulationParameters.setToolTipText(""); -// editSimulationParameters.addActionListener(new ActionListener() { -// public void actionPerformed(ActionEvent e) { -// SessionModel model = getSessionNode().getModel(); -// Class modelClass; -// -// if (model == null) { -// modelClass = determineTheModelClass(getSessionNode()); -// } else { -// modelClass = model.getClass(); -// } + editSimulationParameters.addActionListener(new ActionListener() { + public void actionPerformed(ActionEvent e) { + SessionModel model = getSessionNode().getModel(); + Class modelClass; + + if (model == null) { + modelClass = determineTheModelClass(getSessionNode()); + } else { + modelClass = model.getClass(); + } + + if (!getSessionNode().existsParameterizedConstructor( + modelClass)) { + JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), + "There is no parameterization for this model."); + return; + } + + Parameters param = getSessionNode().getParam(modelClass); + Object[] arguments = + getSessionNode().getModelConstructorArguments( + modelClass); + + if (param != null) { + try { + editParameters(modelClass, param, arguments); + int ret = JOptionPane.showConfirmDialog(JOptionUtils.centeringComp(), + "Should I overwrite the contents of this box and all delete the contents\n" + + "of all boxes downstream?", + "Double check...", JOptionPane.YES_NO_OPTION); + if (ret == JOptionPane.YES_OPTION) { + getSessionNode().destroyModel(); + getSessionNode().createModel(modelClass, true); + } + } catch (Exception e1) { + e1.printStackTrace(); + } + } + } + }); + +// final SessionNode thisNode = getSessionNode(); // -// if (!getSessionNode().existsParameterizedConstructor( -// modelClass)) { -// JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), -// "There is no parameterization for this model."); -// return; -// } +// // popup.add(getConsistentParentMenuItems(getConsistentParentBoxTypes(thisNode))); +//// popup.add(getConsistentChildBoxMenus(getConsistentChildBoxTypes(thisNode, null))); // -// Parameters param = getSessionNode().getParam(modelClass); -// Object[] arguments = -// getSessionNode().getModelConstructorArguments( -// modelClass); +// addConsistentParentMenuItems(popup, getConsistentParentBoxTypes(thisNode)); +// addConsistentChildBoxMenus(popup, getConsistentChildBoxTypes(thisNode, null)); // -// if (param != null) { -// try { -// editParameters(modelClass, param, arguments); -// int ret = JOptionPane.showConfirmDialog(JOptionUtils.centeringComp(), -// "Should I overwrite the contents of this box and all delete the contents\n" + -// "of all boxes downstream?", -// "Double check...", JOptionPane.YES_NO_OPTION); -// if (ret == JOptionPane.YES_OPTION) { -// getSessionNode().destroyModel(); -// getSessionNode().createModel(modelClass, true); -// } -// } catch (Exception e1) { -// e1.printStackTrace(); -// } -// } -// } -// }); +// popup.addSeparator(); popup.add(createModel); + popup.add(editSimulationParameters); popup.add(editModel); popup.add(destroyModel); @@ -796,85 +825,111 @@ public void actionPerformed(ActionEvent e) { popup.addSeparator(); -// final SessionNode thisNode = getSessionNode(); -// -// popup.add(getConsistentParentMenuItems(getConsistentParentBoxTypes(thisNode))); -// popup.add(getConsistentChildBoxMenus(getConsistentChildBoxTypes(thisNode, null))); - -// popup.addSeparator(); - addEditLoggerSettings(popup); popup.add(propagateDownstream); return popup; } - private JMenu getConsistentChildBoxMenus(List consistentChildBoxes) { - JMenu newChildren = new JMenu("New Child Box"); - - for (String _type : consistentChildBoxes) { - final JMenuItem menuItem = new JMenuItem(_type); - - menuItem.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - new ConstructTemplateAction("Test").addChild(SessionEditorNode.this, menuItem.getText()); - } - }); - - newChildren.add(menuItem); - } - return newChildren; - } - - private JMenu getConsistentParentMenuItems(List consistentParentBoxes) { - final JMenu newParents = new JMenu("New Parent Box"); - - for (String _type : consistentParentBoxes) { - final JMenuItem menuItem = new JMenuItem(_type); - - menuItem.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - new ConstructTemplateAction("Test").addParent(SessionEditorNode.this, menuItem.getText()); - } - }); - - newParents.add(menuItem); - } - return newParents; - } - - private List getConsistentChildBoxTypes(SessionNode thisNode, SessionModel model) { - List consistentChildBoxes = new ArrayList<>(); - - for (String type : TetradApplicationConfig.getInstance().getConfigs().keySet()) { - SessionNodeConfig config = TetradApplicationConfig.getInstance().getSessionNodeConfig(type); - Class[] modelClasses = config.getModels(); - - SessionNode newNode = new SessionNode(modelClasses); +// private void addConsistentChildBoxMenus(JPopupMenu menu, List consistentChildBoxes) { +// for (String _type : consistentChildBoxes) { +// final JMenuItem menuItem = new JMenuItem(_type); +// +// menuItem.addActionListener(new ActionListener() { +// @Override +// public void actionPerformed(ActionEvent e) { +// String text = menuItem.getText(); +// String[] tokens = text.split(" "); +// String type = tokens[1]; +// new ConstructTemplateAction("Test").addChild(SessionEditorNode.this, type); +// } +// }); +// +// menu.add(menuItem); +// } +// } - if (thisNode.isConsistentParent(newNode)) { - consistentChildBoxes.add(type); - } - } - return consistentChildBoxes; - } +// private JMenu addConsistentChildBoxMenus(List consistentChildBoxes) { +// JMenu newChildren = new JMenu("New Child Box"); +// +// for (String _type : consistentChildBoxes) { +// final JMenuItem menuItem = new JMenuItem(_type); +// +// menuItem.addActionListener(new ActionListener() { +// @Override +// public void actionPerformed(ActionEvent e) { +// new ConstructTemplateAction("Test").addChild(SessionEditorNode.this, menuItem.getText()); +// } +// }); +// +// +// +// newChildren.add(menuItem); +// } +// return newChildren; +// } - private List getConsistentParentBoxTypes(SessionNode thisNode) { - List consistentParentBoxes = new ArrayList<>(); +// private JMenu addConsistentParentMenuItems(JPopupMenu menu, List consistentParentNodes) { +// final JMenu newParents = new JMenu("New Parent Box"); +// +// for (final SessionNode node : consistentParentNodes) { +// final JMenuItem menuItem = new JMenuItem("Add Links: " + node.getDisplayName()); +// +// menuItem.addActionListener(new ActionListener() { +// @Override +// public void actionPerformed(ActionEvent e) { +// String displayName1 = node.getDisplayName(); +// String displayName2 = SessionEditorNode.this.getSessionNode().getDisplayName(); +// new ConstructTemplateAction("Test").addEdge(displayName1, displayName2); +// } +// }); +// +// menu.add(menuItem); +// } +// +// return newParents; +// } - for (String type : TetradApplicationConfig.getInstance().getConfigs().keySet()) { - SessionNodeConfig config = TetradApplicationConfig.getInstance().getSessionNodeConfig(type); - Class[] modelClasses = config.getModels(); - SessionNode newNode = new SessionNode(modelClasses); +// private List getConsistentChildBoxTypes(SessionNode thisNode, SessionModel model) { +// List consistentChildBoxes = new ArrayList<>(); +// +// List nodes = sessionWorkbench.getSessionWrapper().getNodes(); +// List sessionNodes = new ArrayList<>(); +// for (Node node : nodes) sessionNodes.add(((SessionNodeWrapper) node).getSessionNode()); +// +// Set strings = TetradApplicationConfig.getInstance().getConfigs().keySet(); +// +// for (String type : strings) { +// SessionNodeConfig config = TetradApplicationConfig.getInstance().getSessionNodeConfig(type); +// Class[] modelClasses = config.getModels(); +// +// SessionNode newNode = new SessionNode(modelClasses); +// +// if (newNode.isConsistentParent(thisNode, sessionNodes)) { +// consistentChildBoxes.add("Add " + type); +// } +// } +// +// return consistentChildBoxes; +// } - if (thisNode.isConsistentParent(newNode)) { - consistentParentBoxes.add(type); - } - } - return consistentParentBoxes; - } +// private List getConsistentParentBoxTypes(SessionNode thisNode) { +// List consistentParentBoxes = new ArrayList<>(); +// +// for (Node _node : getSessionWorkbench().getSessionWrapper().getNodes()) { +// SessionNode node = ((SessionNodeWrapper) _node).getSessionNode(); +// +// if (sessionWorkbench.getSessionWrapper().isAncestorOf(thisNode, node)) { +// continue; +// } +// +// if (!thisNode.getParents().contains(node) && thisNode.isConsistentParent(node)) { +// consistentParentBoxes.add(node); +// } +// } +// +// return consistentParentBoxes; +// } /** diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorToolbar.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorToolbar.java index 6279fcb76b..9855901435 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorToolbar.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorToolbar.java @@ -108,19 +108,19 @@ public SessionEditorToolbar(final SessionEditorWorkbench workbench) { "
to construct the object in the second node." + "
As a shortcut, hold down the Control key." + ""), - new ButtonInfo("Data", "Data & Simulation", "data", - "Add a node for a data object."), - new ButtonInfo("Search", "Search", "search", - "Add a node for a search algorithm."), - new ButtonInfo("Comparison", "Comparison", "compare", - "Add a node to compare graphs or SEM IM's."), new ButtonInfo("Graph", "Graph", "graph", "Add a graph node."), + new ButtonInfo("Compare", "Compare", "compare", + "Add a node to compare graphs or SEM IM's."), new ButtonInfo("PM", "Parametric Model", "pm", "Add a node for a parametric model."), new ButtonInfo("IM", "Instantiated Model", "im", "Add a node for an instantiated model."), new ButtonInfo("Estimator", "Estimator", "estimator", "Add a node for an estimator."), + new ButtonInfo("Data", "Data & Simulation", "data", + "Add a node for a data object."), + new ButtonInfo("Search", "Search", "search", + "Add a node for a search algorithm."), new ButtonInfo("Updater", "Updater", "updater", "Add a node for an updater."), new ButtonInfo("Classify", "Classify", "search", diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorToolbarHarry.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorToolbarHarry.java new file mode 100644 index 0000000000..4738cd44de --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/SessionEditorToolbarHarry.java @@ -0,0 +1,476 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph // +// Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.app; + +import edu.cmu.tetradapp.util.ImageUtils; +import edu.cmu.tetradapp.workbench.AbstractWorkbench; + +import javax.swing.*; +import javax.swing.border.EmptyBorder; +import javax.swing.event.ChangeEvent; +import javax.swing.event.ChangeListener; +import java.awt.*; +import java.awt.event.KeyEvent; +import java.awt.event.MouseAdapter; +import java.awt.event.MouseEvent; +import java.beans.PropertyChangeEvent; +import java.beans.PropertyChangeListener; +import java.util.HashMap; +import java.util.Map; + +/** + * Displays a vertical list of buttons that determine the next action the user + * can take in the session editor workbench, whether it's selecting and moving a + * node, adding a node of a particular type, or adding an edge. + * + * @author Joseph Ramsey jdramsey@andrew.cmu.edu + * @see SessionEditor + * @see SessionEditorToolbarHarry + */ +final class SessionEditorToolbarHarry extends JPanel { + + //=========================MEMBER FIELDS============================// + + /** + * True iff the toolbar is responding to events. + */ + private boolean respondingToEvents = true; + /** + * The node type of the button that is used for the Select/Move tool. + */ + private final String selectType = "Select"; + + + /**\ + * The map from JToggleButtons to String node types. + */ + private final Map nodeTypes = new HashMap<>(); + + /** + * True iff the shift key was down on last click. + */ + private boolean shiftDown = false; + + /** + * The workbench this toolbar controls. + */ + private SessionEditorWorkbench workbench; + + //=============================CONSTRUCTORS==========================// + + /** + * Constructs a new session toolbar. + * + * @param workbench the workbench this toolbar controls. + */ + public SessionEditorToolbarHarry(final SessionEditorWorkbench workbench) { + if (workbench == null) { + throw new NullPointerException("Workbench must not be null."); + } + + this.workbench = workbench; + + // Set up panel. + Box buttonsPanel = Box.createVerticalBox(); +// buttonsPanel.setBackground(new Color(198, 232, 252)); + buttonsPanel.setBorder(new EmptyBorder(10, 10, 10, 10)); + + // Create buttons. + /* + Node infos for all of the nodes. + */ + ButtonInfo[] buttonInfos = new ButtonInfo[]{ + new ButtonInfo("Select", "Select and Move", "move", + "Select and move nodes or groups of nodes " + + "
on the workbench."), +// new ButtonInfo("Edge", "Draw Edge", "flow", +// "Add an edge from one node to another to declare" + +// "
that the object in the first node should be used " + +// "
to construct the object in the second node." + +// "
As a shortcut, hold down the Control key." + +// ""), + new ButtonInfo("Graph", "Graph", "graph", "Add a graph node."), + new ButtonInfo("Data", "Data", "data", + "Add a node for a data object."), +// new ButtonInfo("Data", "Data & Simulation", "data", +// "Add a node for a data object."), +// new ButtonInfo("Search", "Search", "search", +// "Add a node for a search algorithm."), + new ButtonInfo("Simulation", "Simulation", "data", + "Add a node for a data object."), +// new ButtonInfo("Comparison", "Comparison", "compare", +// "Add a node to compare graphs or SEM IM's."), +// new ButtonInfo("PM", "Parametric Model", "pm", +// "Add a node for a parametric model."), +// new ButtonInfo("IM", "Instantiated Model", "im", +// "Add a node for an instantiated model."), +// new ButtonInfo("Estimator", "Estimator", "estimator", +// "Add a node for an estimator."), +// new ButtonInfo("Updater", "Updater", "updater", +// "Add a node for an updater."), +// new ButtonInfo("Classify", "Classify", "search", +// "Add a node for a classifier."), +// new ButtonInfo("Regression", "Regression", "regression", +// "Add a node for a regression."), +// new ButtonInfo("Knowledge", "Knowledge", "knowledge", "Add a knowledge box node."), + new ButtonInfo("Note", "Note", "note", + "Add a note to the session.") + }; + JToggleButton[] buttons = new JToggleButton[buttonInfos.length]; + + for (int i = 0; i < buttonInfos.length; i++) { + buttons[i] = constructButton(buttonInfos[i]); + } + + // Add all buttons to a button group. + ButtonGroup buttonGroup = new ButtonGroup(); + + for (int i = 0; i < buttonInfos.length; i++) { + buttonGroup.add(buttons[i]); + } + + // This seems to be fixed. Now creating weirdness. jdramsey 3/4/2014 +// // Add a focus listener to help buttons not deselect when the +// // mouse slides away from the button. +// FocusListener focusListener = new FocusAdapter() { +// public void focusGained(FocusEvent e) { +// JToggleButton component = (JToggleButton) e.getComponent(); +// component.getModel().setSelected(true); +// } +// }; +// +// for (int i = 0; i < buttonInfos.length; i++) { +// buttons[i].addFocusListener(focusListener); +// } + + // Add an action listener to help send messages to the + // workbench. + ChangeListener changeListener = new ChangeListener() { + public void stateChanged(ChangeEvent e) { + JToggleButton _button = (JToggleButton) e.getSource(); + + if (_button.getModel().isSelected()) { + setWorkbenchMode(_button); +// setCursor(workbench.getCursor()); + } + } + }; + + for (int i = 0; i < buttonInfos.length; i++) { + buttons[i].addChangeListener(changeListener); + } + + // Select the Select button. + JToggleButton button = getButtonForType(this.selectType); + + button.getModel().setSelected(true); + + // Add the buttons to the workbench. + for (int i = 0; i < buttonInfos.length; i++) { + if (!buttonInfos[i].getNodeTypeName().equals("Select")) { + buttonsPanel.add(buttons[i]); + buttonsPanel.add(Box.createVerticalStrut(5)); + } + } + + // Put the panel in a scrollpane. + this.setLayout(new BorderLayout()); + JScrollPane scroll = new JScrollPane(buttonsPanel); + scroll.setPreferredSize(new Dimension(130, 1000)); + add(scroll, BorderLayout.CENTER); + + // Add property change listener so that selection can be moved + // back to "SELECT_MOVE" after an action. + workbench.addPropertyChangeListener(new PropertyChangeListener() { + public void propertyChange(PropertyChangeEvent e) { + if (!isRespondingToEvents()) { + return; + } + + String propertyName = e.getPropertyName(); + + if ("nodeAdded".equals(propertyName)) { + if (!isShiftDown()) { + resetSelectMove(); + } + } + } + }); + + KeyboardFocusManager.getCurrentKeyboardFocusManager() + .addKeyEventDispatcher(new KeyEventDispatcher() { + public boolean dispatchKeyEvent(KeyEvent e) { + int keyCode = e.getKeyCode(); + int id = e.getID(); + + if (keyCode == KeyEvent.VK_SHIFT) { + if (id == KeyEvent.KEY_PRESSED) { + setShiftDown(true); + } else if (id == KeyEvent.KEY_RELEASED) { + setShiftDown(false); + resetSelectMove(); + } + } + + return false; + } + }); + + resetSelectMove(); + } + + /** + * Sets the selection back to move/select. + */ + private void resetSelectMove() { + JToggleButton selectButton = getButtonForType(selectType); + if (!(selectButton.isSelected())) { + selectButton.doClick(); + selectButton.requestFocus(); + } + } + +// /** +// * Sets the selection back to Flowchart. +// */ +// public void resetFlowchart() { +// JToggleButton edgeButton = getButtonForType(edgeType); +// edgeButton.doClick(); +// edgeButton.requestFocus(); +// } + + /** + * True iff the toolbar is responding to events. This may need to be turned + * off temporarily. + */ + private boolean isRespondingToEvents() { + return respondingToEvents; + } + + /** + * Sets whether the toolbar should react to events. This may need to be + * turned off temporarily. + */ + public void setRespondingToEvents(boolean respondingToEvents) { + this.respondingToEvents = respondingToEvents; + } + + protected void processKeyEvent(KeyEvent e) { + System.out.println("process key event " + e); + super.processKeyEvent(e); + } + + //===========================PRIVATE METHODS=========================// + + /** + * Constructs the button with the given node type and image prefix. If the + * node type is "Select", constructs a button that allows nodes to be + * selected and moved. If the node type is "Edge", constructs a button that + * allows edges to be drawn. For other node types, constructs buttons that + * allow those type of nodes to be added to the workbench. If a non-null + * image prefix is provided, images for Up.gif, Down.gif, + * Off.gif and Roll.gif are loaded from the /images + * directory relative to this compiled class and used to provide up, down, + * off, and rollover images for the constructed button. On construction, + * nodes are mapped to their node types in the Map, nodeTypes. + * Listeners are added to the node. + * + * @param buttonInfo contains the info needed to construct the button. + */ + private JToggleButton constructButton(ButtonInfo buttonInfo) { + String imagePrefix = buttonInfo.getImagePrefix(); + + if (imagePrefix == null) { + throw new NullPointerException("Image prefix must not be null."); + } + + JToggleButton button = new JToggleButton(); + + button.addMouseListener(new MouseAdapter() { + public void mouseClicked(MouseEvent e) { + super.mouseClicked(e); + setShiftDown(e.isShiftDown()); +// setControlDown(e.isControlDown()); + } + }); + + if ("Select".equals(buttonInfo.getNodeTypeName())) { + button.setIcon(new ImageIcon(ImageUtils.getImage(this, "move.gif"))); + } else if ("Edge".equals(buttonInfo.getNodeTypeName())) { + button.setIcon( + new ImageIcon(ImageUtils.getImage(this, "flow.gif"))); + } else { + button.setName(buttonInfo.getNodeTypeName()); + button.setText("
" + buttonInfo.getDisplayName() + + "
"); + } + + button.setMaximumSize(new Dimension(110, 40)); // For a vertical box. + button.setToolTipText(buttonInfo.getToolTipText()); + this.nodeTypes.put(button, buttonInfo.getNodeTypeName()); + + return button; + } + + /** + * Sets the state of the workbench in response to a button press. + * + * @param button the JToggleButton whose workbench state is to be set. + */ + private void setWorkbenchMode(JToggleButton button) { + String nodeType = this.nodeTypes.get(button); + + /* + The node type of the button that is used for the edge-drawing tool. + */ + String edgeType = "Edge"; + if (selectType.equals(nodeType)) { + workbench.setWorkbenchMode(AbstractWorkbench.SELECT_MOVE); + workbench.setNextButtonType(null); + setCursor(new Cursor(Cursor.HAND_CURSOR)); + workbench.setCursor(new Cursor(Cursor.HAND_CURSOR)); + } else if (edgeType.equals(nodeType)) { + workbench.setWorkbenchMode(AbstractWorkbench.ADD_EDGE); + workbench.setNextButtonType(null); +// setCursor(workbench.getCursor()); + +// Toolkit toolkit = Toolkit.getDefaultToolkit(); +// Image image = ImageUtils.getImage(this, "arrowCursorImage.png"); +// Cursor c = toolkit.createCustomCursor(image, new Point(10, 10), "img"); +// setCursor(c); +// workbench.setCursor(c); + + setCursor(new Cursor(Cursor.DEFAULT_CURSOR)); + workbench.setCursor(new Cursor(Cursor.DEFAULT_CURSOR)); + } else { + workbench.setWorkbenchMode(AbstractWorkbench.ADD_NODE); + workbench.setNextButtonType(nodeType); + +// Toolkit toolkit = Toolkit.getDefaultToolkit(); +// Image image = ImageUtils.getImage(this, "cursorImage.png"); +// Cursor c = toolkit.createCustomCursor(image, new Point(10, 10), "img"); +// setCursor(c); +// workbench.setCursor(c); + +// setCursor(workbench.getCursor()); + + setCursor(new Cursor(Cursor.CROSSHAIR_CURSOR)); + workbench.setCursor(new Cursor(Cursor.CROSSHAIR_CURSOR)); + + } + } + + /** + * @return the JToggleButton for the given node type, or null if no such + * button exists. + */ + private JToggleButton getButtonForType(String nodeType) { + for (Object o : nodeTypes.keySet()) { + JToggleButton button = (JToggleButton) o; + + if (nodeType.equals(nodeTypes.get(button))) { + return button; + } + } + + return null; + } + + private boolean isShiftDown() { + return shiftDown; + } + + private void setShiftDown(boolean shiftDown) { + this.shiftDown = shiftDown; + } + +// public boolean isControlDown() { +// return shiftDown; +// } +// +// private void setControlDown(boolean shiftDown) { +// this.shiftDown = shiftDown; +// } + + /** + * Holds info for constructing a single button. + */ + private static final class ButtonInfo { + + /** + * This is the name used to construct nodes on the graph of this type. + * Need to coordinate with session. + */ + private String nodeTypeName; + + /** + * The name displayed on the button. + */ + private final String displayName; + + /** + * The prefixes for images for this button. It is assumed that files + * Up.gif, Down.gif, Off.gif and + * Roll.gif are located in the /images directory relative to + * this compiled class. + */ + private final String imagePrefix; + + /** + * Tool tip text displayed for the button. + */ + private final String toolTipText; + + public ButtonInfo(String nodeTypeName, String displayName, + String imagePrefix, String toolTipText) { + this.nodeTypeName = nodeTypeName; + this.displayName = displayName; + this.imagePrefix = imagePrefix; + this.toolTipText = toolTipText; + } + + public String getNodeTypeName() { + return nodeTypeName; + } + + public String getDisplayName() { + return displayName; + } + + public void setNodeTypeName(String nodeTypeName) { + this.nodeTypeName = nodeTypeName; + } + + public String getImagePrefix() { + return imagePrefix; + } + + public String getToolTipText() { + return toolTipText; + } + } +} + + + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradApplicationConfig.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradApplicationConfig.java index e854513022..bac87a9c72 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradApplicationConfig.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradApplicationConfig.java @@ -133,7 +133,7 @@ public SessionNodeConfig getSessionNodeConfig(Class model) { private static Map buildConfiguration(Element root) { Elements elements = root.getChildElements(); ClassLoader loader = getClassLoader(); - Map configs = new HashMap<>(); + Map configs = new LinkedHashMap<>(); for (int i = 0; i < elements.size(); i++) { Element node = elements.get(i); String id = node.getAttributeValue("id"); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradDesktop.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradDesktop.java index 9131cdb1ce..3d8a8d6ee7 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradDesktop.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradDesktop.java @@ -136,11 +136,6 @@ public TetradDesktop() { // ===========================PUBLIC METHODS============================// - public void buildHpcJobActivityPane() { - - } - - public void newSessionEditor() { String newName = getNewSessionName(); SessionEditor editor = new SessionEditor(newName); @@ -369,8 +364,7 @@ public SessionEditor getFrontmostSessionEditor() { * Reacts to property change events 'editorClosing', 'closeFrame', and * 'name'. * - * @param e - * the property change event. + * @param e the property change event. */ public void propertyChange(PropertyChangeEvent e) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradMenuBar.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradMenuBar.java index 533f54d79b..b86b672ebe 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradMenuBar.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradMenuBar.java @@ -59,7 +59,7 @@ final class TetradMenuBar extends JMenuBar { /** * Creates the main menubar for Tetrad. */ - public TetradMenuBar(final TetradDesktop desktop) { + public TetradMenuBar(TetradDesktop desktop) { this.desktop = desktop; setBorder(new EtchedBorder()); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/WindowMenuListener.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/WindowMenuListener.java index 0d5e6c903d..94e403ad07 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/WindowMenuListener.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/WindowMenuListener.java @@ -66,7 +66,7 @@ final class WindowMenuListener implements MenuListener, ActionListener { * Constructs the window menu listener. Requires to be told which object * the window menu is and which object the desktop pane is. */ - public WindowMenuListener(JMenu windowMenu, final TetradDesktop desktop) { + public WindowMenuListener(JMenu windowMenu, TetradDesktop desktop) { if (windowMenu == null) { throw new NullPointerException("Window menu must not be null."); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/DeleteHpcJobInfoAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/DeleteHpcJobInfoAction.java index 959f75d226..67fb6029f8 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/DeleteHpcJobInfoAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/DeleteHpcJobInfoAction.java @@ -23,49 +23,45 @@ */ public class DeleteHpcJobInfoAction extends AbstractAction { - private static final long serialVersionUID = 7915068087861233608L; + private static final long serialVersionUID = 7915068087861233608L; - private final Component parentComp; + private final Component parentComp; - public DeleteHpcJobInfoAction(final Component parentComp) { - this.parentComp = parentComp; - } + public DeleteHpcJobInfoAction(final Component parentComp) { + this.parentComp = parentComp; + } - @Override - public void actionPerformed(ActionEvent e) { + @Override + public void actionPerformed(ActionEvent e) { - JTable table = (JTable) e.getSource(); - int modelRow = Integer.valueOf(e.getActionCommand()); - DefaultTableModel finishedJobTableModel = (DefaultTableModel) table - .getModel(); + JTable table = (JTable) e.getSource(); + int modelRow = Integer.valueOf(e.getActionCommand()); + DefaultTableModel finishedJobTableModel = (DefaultTableModel) table.getModel(); - long jobId = Long.valueOf( - finishedJobTableModel.getValueAt(modelRow, - HpcJobActivityEditor.ID_COLUMN).toString()).longValue(); + long jobId = Long.valueOf(finishedJobTableModel.getValueAt(modelRow, HpcJobActivityEditor.ID_COLUMN).toString()) + .longValue(); - int answer = JOptionPane.showConfirmDialog(parentComp, - "Would you like to delete this HPC job id: " + jobId + "?", - "Delete HPC job", JOptionPane.YES_NO_OPTION); + int answer = JOptionPane.showConfirmDialog(parentComp, + "Would you like to delete this HPC job id: " + jobId + "?", "Delete HPC job", + JOptionPane.YES_NO_OPTION); - if (answer == JOptionPane.NO_OPTION) - return; + if (answer == JOptionPane.NO_OPTION) + return; - TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); - final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); + TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); + final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - HpcJobInfo hpcJobInfo = hpcJobManager - .findHpcJobInfoById(Long.valueOf( - finishedJobTableModel.getValueAt(modelRow, - HpcJobActivityEditor.ID_COLUMN).toString()) - .longValue()); + HpcJobInfo hpcJobInfo = hpcJobManager.findHpcJobInfoById( + Long.valueOf(finishedJobTableModel.getValueAt(modelRow, HpcJobActivityEditor.ID_COLUMN).toString()) + .longValue()); - if (hpcJobInfo != null) { - // Update table - finishedJobTableModel.removeRow(modelRow); - table.updateUI(); - hpcJobManager.removeHpcJobInfoTransaction(hpcJobInfo); - } + if (hpcJobInfo != null) { + // Update table + finishedJobTableModel.removeRow(modelRow); + table.updateUI(); + hpcJobManager.removeHpcJobInfoTransaction(hpcJobInfo); + } - } + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcAccountSelectionAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcAccountSelectionAction.java index f890e32611..66b6992a42 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcAccountSelectionAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcAccountSelectionAction.java @@ -18,39 +18,36 @@ */ public class HpcAccountSelectionAction extends AbstractAction { - private static final long serialVersionUID = -5506074283478552872L; - - private final List hpcAccounts; - - private final List checkedHpcAccountList; - - private final JTabbedPane tabbedPane; - - public HpcAccountSelectionAction(final List hpcAccounts, - final List checkedHpcAccountList, - final JTabbedPane tabbedPane) { - this.hpcAccounts = hpcAccounts; - this.checkedHpcAccountList = checkedHpcAccountList; - this.tabbedPane = tabbedPane; - } - - @Override - public void actionPerformed(ActionEvent e) { - final JCheckBox checkBox = (JCheckBox) e.getSource(); - for (HpcAccount hpcAccount : hpcAccounts) { - if (checkBox.getText().equals(hpcAccount.getConnectionName())) { - if (checkBox.isSelected() - && !checkedHpcAccountList.contains(hpcAccount)) { - checkedHpcAccountList.add(hpcAccount); - } else if (!checkBox.isSelected() - && checkedHpcAccountList.contains(hpcAccount)) { - checkedHpcAccountList.remove(hpcAccount); + private static final long serialVersionUID = -5506074283478552872L; + + private final List hpcAccounts; + + private final List checkedHpcAccountList; + + private final JTabbedPane tabbedPane; + + public HpcAccountSelectionAction(final List hpcAccounts, final List checkedHpcAccountList, + final JTabbedPane tabbedPane) { + this.hpcAccounts = hpcAccounts; + this.checkedHpcAccountList = checkedHpcAccountList; + this.tabbedPane = tabbedPane; + } + + @Override + public void actionPerformed(ActionEvent e) { + final JCheckBox checkBox = (JCheckBox) e.getSource(); + for (HpcAccount hpcAccount : hpcAccounts) { + if (checkBox.getText().equals(hpcAccount.getConnectionName())) { + if (checkBox.isSelected() && !checkedHpcAccountList.contains(hpcAccount)) { + checkedHpcAccountList.add(hpcAccount); + } else if (!checkBox.isSelected() && checkedHpcAccountList.contains(hpcAccount)) { + checkedHpcAccountList.remove(hpcAccount); + } + } } - } + int index = tabbedPane.getSelectedIndex(); + tabbedPane.setSelectedIndex(-1); + tabbedPane.setSelectedIndex(index); } - int index = tabbedPane.getSelectedIndex(); - tabbedPane.setSelectedIndex(-1); - tabbedPane.setSelectedIndex(index); - } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcAccountSettingAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcAccountSettingAction.java index 59b0ab6444..3f67429268 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcAccountSettingAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcAccountSettingAction.java @@ -35,139 +35,130 @@ */ public class HpcAccountSettingAction extends AbstractAction { - private static final long serialVersionUID = -4084211497363128243L; - - public HpcAccountSettingAction() { - super("HPC Account"); - } - - @Override - public void actionPerformed(ActionEvent e) { - TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); - HpcAccountManager manager = desktop.getHpcAccountManager(); - - JComponent comp = buildHpcAccountSettingComponent(manager); - JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), comp, - "High-Performance Computing Account Setting", - JOptionPane.PLAIN_MESSAGE); - } - - private static JComponent buildHpcAccountSettingComponent( - final HpcAccountManager manager) { - // Get ComputingAccount from DB - final DefaultListModel listModel = new DefaultListModel(); - - for (HpcAccount account : manager.getHpcAccounts()) { - listModel.addElement(account); + private static final long serialVersionUID = -4084211497363128243L; + + public HpcAccountSettingAction() { + super("HPC Account"); } - - // JSplitPane - final JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT); - - // Left pane -> JList (parent pane) - JPanel leftPanel = new JPanel(new BorderLayout()); - - // Right pane -> ComputingAccountEditor - final JPanel accountDetailPanel = new JPanel(new BorderLayout()); - accountDetailPanel.add(new HpcAccountEditor(splitPane, listModel, manager, - new HpcAccount()), BorderLayout.CENTER); - - splitPane.setLeftComponent(leftPanel); - splitPane.setRightComponent(accountDetailPanel); - - // Center Panel - final JList accountList = new JList<>(listModel); - accountList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - accountList.setLayoutOrientation(JList.VERTICAL); - accountList.setSelectedIndex(-1); - accountList.addListSelectionListener(new ListSelectionListener() { - - @Override - public void valueChanged(ListSelectionEvent e) { - if (e.getValueIsAdjusting()) - return; - int selectedIndex = ((JList) e.getSource()) - .getSelectedIndex(); - // Show or remove the detail - accountDetailPanel.removeAll(); - if (selectedIndex > -1) { - HpcAccount computingAccount = listModel - .get(selectedIndex); - System.out.println(computingAccount); - accountDetailPanel.add(new HpcAccountEditor(splitPane, listModel, manager, - computingAccount), BorderLayout.CENTER); - } - accountDetailPanel.updateUI(); - } - }); - - // Left Panel - JPanel buttonPanel = new JPanel(new BorderLayout()); - JButton addButton = new JButton("Add"); - addButton.setSize(new Dimension(14, 8)); - addButton.addActionListener(new ActionListener() { - - @Override - public void actionPerformed(ActionEvent e) { - // Show the empty ComputingAccountEditor - accountDetailPanel.removeAll(); - accountDetailPanel.add(new HpcAccountEditor(splitPane, listModel, manager, - new HpcAccount()), BorderLayout.CENTER); - accountDetailPanel.updateUI(); - } - }); - buttonPanel.add(addButton, BorderLayout.WEST); - - JButton removeButton = new JButton("Remove"); - removeButton.setSize(new Dimension(14, 8)); - removeButton.addActionListener(new ActionListener() { - - @Override - public void actionPerformed(ActionEvent e) { - if (accountList.isSelectionEmpty()) - return; - int selectedIndex = accountList.getSelectedIndex(); - if (selectedIndex > -1) { - HpcAccount computingAccount = listModel - .get(selectedIndex); - // Pop up the confirm dialog - int option = JOptionPane.showConfirmDialog( - accountDetailPanel, "Are you sure that you want to delete " - + computingAccount + " ?", - "HPC Account Setting", JOptionPane.YES_NO_OPTION, - JOptionPane.QUESTION_MESSAGE); - - // If yes, remove it from DB and listModel - if (option == JOptionPane.YES_OPTION) { - manager.removeAccount(computingAccount); - listModel.remove(selectedIndex); - } + @Override + public void actionPerformed(ActionEvent e) { + TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); + HpcAccountManager manager = desktop.getHpcAccountManager(); + + JComponent comp = buildHpcAccountSettingComponent(manager); + JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), comp, "High-Performance Computing Account Setting", + JOptionPane.PLAIN_MESSAGE); + } + + private static JComponent buildHpcAccountSettingComponent(final HpcAccountManager manager) { + // Get ComputingAccount from DB + final DefaultListModel listModel = new DefaultListModel(); + + for (HpcAccount account : manager.getHpcAccounts()) { + listModel.addElement(account); } - } - }); - buttonPanel.add(removeButton, BorderLayout.EAST); - leftPanel.add(buttonPanel, BorderLayout.NORTH); - - JScrollPane accountListScroller = new JScrollPane(accountList); - leftPanel.add(accountListScroller, BorderLayout.CENTER); - - int minWidth = 300; - int minHeight = 200; - int screenWidth = Toolkit.getDefaultToolkit().getScreenSize().width; - int screenHeight = Toolkit.getDefaultToolkit().getScreenSize().height; - int frameWidth = screenWidth / 2; - int frameHeight = screenHeight / 2; - frameWidth = minWidth > frameWidth ? minWidth : frameWidth; - frameHeight = minHeight > frameHeight ? minHeight : frameHeight; - - splitPane.setDividerLocation(frameWidth / 4); - accountListScroller.setPreferredSize(new Dimension(frameWidth / 4, - frameHeight)); - accountDetailPanel.setPreferredSize(new Dimension(frameWidth * 3 / 4, - frameHeight)); - - return splitPane; - } + + // JSplitPane + final JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT); + + // Left pane -> JList (parent pane) + JPanel leftPanel = new JPanel(new BorderLayout()); + + // Right pane -> ComputingAccountEditor + final JPanel accountDetailPanel = new JPanel(new BorderLayout()); + accountDetailPanel.add(new HpcAccountEditor(splitPane, listModel, manager, new HpcAccount()), + BorderLayout.CENTER); + + splitPane.setLeftComponent(leftPanel); + splitPane.setRightComponent(accountDetailPanel); + + // Center Panel + final JList accountList = new JList<>(listModel); + accountList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); + accountList.setLayoutOrientation(JList.VERTICAL); + accountList.setSelectedIndex(-1); + accountList.addListSelectionListener(new ListSelectionListener() { + + @Override + public void valueChanged(ListSelectionEvent e) { + if (e.getValueIsAdjusting()) + return; + int selectedIndex = ((JList) e.getSource()).getSelectedIndex(); + // Show or remove the detail + accountDetailPanel.removeAll(); + if (selectedIndex > -1) { + HpcAccount computingAccount = listModel.get(selectedIndex); + System.out.println(computingAccount); + accountDetailPanel.add(new HpcAccountEditor(splitPane, listModel, manager, computingAccount), + BorderLayout.CENTER); + } + accountDetailPanel.updateUI(); + } + }); + + // Left Panel + JPanel buttonPanel = new JPanel(new BorderLayout()); + JButton addButton = new JButton("Add"); + addButton.setSize(new Dimension(14, 8)); + addButton.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent e) { + // Show the empty ComputingAccountEditor + accountDetailPanel.removeAll(); + accountDetailPanel.add(new HpcAccountEditor(splitPane, listModel, manager, new HpcAccount()), + BorderLayout.CENTER); + accountDetailPanel.updateUI(); + } + }); + buttonPanel.add(addButton, BorderLayout.WEST); + + JButton removeButton = new JButton("Remove"); + removeButton.setSize(new Dimension(14, 8)); + removeButton.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent e) { + if (accountList.isSelectionEmpty()) + return; + int selectedIndex = accountList.getSelectedIndex(); + if (selectedIndex > -1) { + HpcAccount computingAccount = listModel.get(selectedIndex); + // Pop up the confirm dialog + int option = JOptionPane.showConfirmDialog(accountDetailPanel, + "Are you sure that you want to delete " + computingAccount + " ?", "HPC Account Setting", + JOptionPane.YES_NO_OPTION, JOptionPane.QUESTION_MESSAGE); + + // If yes, remove it from DB and listModel + if (option == JOptionPane.YES_OPTION) { + manager.removeAccount(computingAccount); + listModel.remove(selectedIndex); + } + + } + } + }); + buttonPanel.add(removeButton, BorderLayout.EAST); + leftPanel.add(buttonPanel, BorderLayout.NORTH); + + JScrollPane accountListScroller = new JScrollPane(accountList); + leftPanel.add(accountListScroller, BorderLayout.CENTER); + + int minWidth = 300; + int minHeight = 200; + int screenWidth = Toolkit.getDefaultToolkit().getScreenSize().width; + int screenHeight = Toolkit.getDefaultToolkit().getScreenSize().height; + int frameWidth = screenWidth / 2; + int frameHeight = screenHeight / 2; + frameWidth = minWidth > frameWidth ? minWidth : frameWidth; + frameHeight = minHeight > frameHeight ? minHeight : frameHeight; + + splitPane.setDividerLocation(frameWidth / 4); + accountListScroller.setPreferredSize(new Dimension(frameWidth / 4, frameHeight)); + accountDetailPanel.setPreferredSize(new Dimension(frameWidth * 3 / 4, frameHeight)); + + return splitPane; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcJobActivityAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcJobActivityAction.java index 2975d77a23..3cdbe0edbb 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcJobActivityAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/HpcJobActivityAction.java @@ -1,5 +1,6 @@ package edu.cmu.tetradapp.app.hpc.action; +import java.awt.Frame; import java.awt.HeadlessException; import java.awt.event.ActionEvent; @@ -19,25 +20,25 @@ */ public class HpcJobActivityAction extends AbstractAction { - private static final long serialVersionUID = -8500391011385619809L; + private static final long serialVersionUID = -8500391011385619809L; - private static final String TITLE = "High-Performance Computing Job Activity"; + private static final String TITLE = "High-Performance Computing Job Activity"; - public HpcJobActivityAction(String actionTitle) { - super(actionTitle); - } + public HpcJobActivityAction(String actionTitle) { + super(actionTitle); + } - @Override - public void actionPerformed(ActionEvent e) { - try { - JComponent comp = new HpcJobActivityEditor(); - JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), comp, - TITLE, JOptionPane.PLAIN_MESSAGE); - } catch (HeadlessException e1) { - //e1.printStackTrace(); - } catch (Exception e1) { - //e1.printStackTrace(); + @Override + public void actionPerformed(ActionEvent e) { + try { + Frame ancestor = (Frame) JOptionUtils.centeringComp().getTopLevelAncestor(); + JComponent comp = new HpcJobActivityEditor(); + JOptionPane.showMessageDialog(ancestor, comp, TITLE, JOptionPane.PLAIN_MESSAGE); + } catch (HeadlessException e1) { + // e1.printStackTrace(); + } catch (Exception e1) { + // e1.printStackTrace(); + } } - } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/KillHpcJobAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/KillHpcJobAction.java index 19b600a56a..a8ed351854 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/KillHpcJobAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/KillHpcJobAction.java @@ -25,101 +25,93 @@ */ public class KillHpcJobAction extends AbstractAction { - private static final long serialVersionUID = 8275717978736439467L; - - private final Component parentComp; - - public KillHpcJobAction(Component parentComp) { - this.parentComp = parentComp; - } - - @Override - public void actionPerformed(ActionEvent e) { - - JTable table = (JTable) e.getSource(); - int modelRow = Integer.valueOf(e.getActionCommand()); - DefaultTableModel activeJobTableModel = (DefaultTableModel) table - .getModel(); - - long jobId = Long.valueOf( - activeJobTableModel.getValueAt(modelRow, HpcJobActivityEditor.ID_COLUMN).toString()) - .longValue(); - - int answer = JOptionPane.showConfirmDialog(parentComp, - "Would you like to cancel this HPC job id: " + jobId + "?", - "Cancel HPC job", JOptionPane.YES_NO_OPTION); - - if (answer == JOptionPane.NO_OPTION) - return; - - TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); - final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - - HpcJobInfo hpcJobInfo = hpcJobManager.findHpcJobInfoById(Long.valueOf( - activeJobTableModel.getValueAt(modelRow, HpcJobActivityEditor.ID_COLUMN).toString()) - .longValue()); - - if (hpcJobInfo != null) { - try { - if (hpcJobInfo.getPid() != null) { - // Update table - activeJobTableModel.setValueAt("Kill Request", modelRow, 1); - table.updateUI(); - - hpcJobInfo = hpcJobManager.requestHpcJobKilled(hpcJobInfo); - // Update hpcJobInfo instance - hpcJobManager.updateHpcJobInfo(hpcJobInfo); - - // Update hpcJobLog instance - HpcJobLog hpcJobLog = hpcJobManager - .getHpcJobLog(hpcJobInfo); - - if (hpcJobLog != null) { - hpcJobLog.setLastUpdatedTime(new Date(System - .currentTimeMillis())); - hpcJobManager.updateHpcJobLog(hpcJobLog); - - // Update hpcJobLogDetail instance - String log = "Requested job id " + hpcJobLog.getId() - + " killed"; - - hpcJobManager.logHpcJobLogDetail(hpcJobLog, 2, log); - } - } else { - // Update table - activeJobTableModel.removeRow(modelRow); - table.updateUI(); - - hpcJobManager.removePendingHpcJob(hpcJobInfo); - - hpcJobInfo.setStatus(4); // Killed - - // Update hpcJobInfo instance - hpcJobManager.updateHpcJobInfo(hpcJobInfo); - - // Update hpcJobLog instance - HpcJobLog hpcJobLog = hpcJobManager - .getHpcJobLog(hpcJobInfo); - if (hpcJobLog != null) { - hpcJobLog.setCanceledTime(new Date(System - .currentTimeMillis())); - hpcJobLog.setLastUpdatedTime(new Date(System - .currentTimeMillis())); - hpcJobManager.updateHpcJobLog(hpcJobLog); - - // Update hpcJobLogDetail instance - String log = "Killed job id " + hpcJobLog.getId(); - - hpcJobManager.logHpcJobLogDetail(hpcJobLog, 4, log); - } - } + private static final long serialVersionUID = 8275717978736439467L; - } catch (Exception e1) { - // TODO Auto-generated catch block - e1.printStackTrace(); - } + private final Component parentComp; + public KillHpcJobAction(Component parentComp) { + this.parentComp = parentComp; } - } + @Override + public void actionPerformed(ActionEvent e) { + + JTable table = (JTable) e.getSource(); + int modelRow = Integer.valueOf(e.getActionCommand()); + DefaultTableModel activeJobTableModel = (DefaultTableModel) table.getModel(); + + long jobId = Long.valueOf(activeJobTableModel.getValueAt(modelRow, HpcJobActivityEditor.ID_COLUMN).toString()) + .longValue(); + + int answer = JOptionPane.showConfirmDialog(parentComp, + "Would you like to cancel this HPC job id: " + jobId + "?", "Cancel HPC job", + JOptionPane.YES_NO_OPTION); + + if (answer == JOptionPane.NO_OPTION) + return; + + TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); + final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); + + HpcJobInfo hpcJobInfo = hpcJobManager.findHpcJobInfoById( + Long.valueOf(activeJobTableModel.getValueAt(modelRow, HpcJobActivityEditor.ID_COLUMN).toString()) + .longValue()); + + if (hpcJobInfo != null) { + try { + if (hpcJobInfo.getPid() != null) { + // Update table + activeJobTableModel.setValueAt("Kill Request", modelRow, 1); + table.updateUI(); + + hpcJobInfo = hpcJobManager.requestHpcJobKilled(hpcJobInfo); + // Update hpcJobInfo instance + hpcJobManager.updateHpcJobInfo(hpcJobInfo); + + // Update hpcJobLog instance + HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); + + if (hpcJobLog != null) { + hpcJobLog.setLastUpdatedTime(new Date(System.currentTimeMillis())); + hpcJobManager.updateHpcJobLog(hpcJobLog); + + // Update hpcJobLogDetail instance + String log = "Requested job id " + hpcJobLog.getId() + " killed"; + + hpcJobManager.logHpcJobLogDetail(hpcJobLog, 2, log); + } + } else { + // Update table + activeJobTableModel.removeRow(modelRow); + table.updateUI(); + + hpcJobManager.removePendingHpcJob(hpcJobInfo); + + hpcJobInfo.setStatus(4); // Killed + + // Update hpcJobInfo instance + hpcJobManager.updateHpcJobInfo(hpcJobInfo); + + // Update hpcJobLog instance + HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); + if (hpcJobLog != null) { + hpcJobLog.setCanceledTime(new Date(System.currentTimeMillis())); + hpcJobLog.setLastUpdatedTime(new Date(System.currentTimeMillis())); + hpcJobManager.updateHpcJobLog(hpcJobLog); + + // Update hpcJobLogDetail instance + String log = "Killed job id " + hpcJobLog.getId(); + + hpcJobManager.logHpcJobLogDetail(hpcJobLog, 4, log); + } + } + + } catch (Exception e1) { + // TODO Auto-generated catch block + e1.printStackTrace(); + } + + } + + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/LoadHpcGraphJsonAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/LoadHpcGraphJsonAction.java index 54025e3d31..8017986506 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/LoadHpcGraphJsonAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/action/LoadHpcGraphJsonAction.java @@ -29,6 +29,7 @@ import edu.cmu.tetrad.util.JOptionUtils; import edu.cmu.tetrad.util.JsonUtils; import edu.cmu.tetradapp.app.TetradDesktop; +import edu.cmu.tetradapp.app.hpc.editor.LoadHpcGraphJsonTableModel; import edu.cmu.tetradapp.app.hpc.manager.HpcAccountManager; import edu.cmu.tetradapp.app.hpc.manager.HpcAccountService; import edu.cmu.tetradapp.app.hpc.manager.HpcJobManager; @@ -49,231 +50,204 @@ */ public class LoadHpcGraphJsonAction extends AbstractAction { - private static final long serialVersionUID = 3640705055173728331L; + private static final long serialVersionUID = 3640705055173728331L; - /** - * The component whose image is to be saved. - */ - private GraphEditable graphEditable; + /** + * The component whose image is to be saved. + */ + private GraphEditable graphEditable; - private String jsonFileName = null; + private String jsonFileName = null; - private HpcAccount hpcAccount = null; + private HpcAccount hpcAccount = null; - public LoadHpcGraphJsonAction(GraphEditable graphEditable, String title) { - super(title); + public LoadHpcGraphJsonAction(GraphEditable graphEditable, String title) { + super(title); - if (graphEditable == null) { - throw new NullPointerException("Component must not be null."); - } + if (graphEditable == null) { + throw new NullPointerException("Component must not be null."); + } - this.graphEditable = graphEditable; - } - - @Override - public void actionPerformed(ActionEvent e) { - TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); - final HpcAccountManager hpcAccountManager = desktop.getHpcAccountManager(); - final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - - JComponent comp = buildHpcJsonChooserComponent(desktop); - int option = JOptionPane.showConfirmDialog( - JOptionUtils.centeringComp(), comp, - "High-Performance Computing Account Json Results Chooser", - JOptionPane.OK_CANCEL_OPTION, JOptionPane.PLAIN_MESSAGE); - - if (option == JOptionPane.OK_OPTION && jsonFileName != null - && hpcAccount != null) { - - try { - HpcAccountService hpcAccountService = hpcJobManager - .getHpcAccountService(hpcAccount); - - ResultService resultService = hpcAccountService.getResultService(); - - String json = resultService.downloadAlgorithmResultFile( - jsonFileName, HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); - - Graph graph = JsonUtils.parseJSONObjectToTetradGraph(json); - GraphUtils.circleLayout(graph, 300, 300, 150); - graphEditable.setGraph(graph); - graphEditable.setName(jsonFileName); - } catch (Exception e1) { - // TODO Auto-generated catch block - e1.printStackTrace(); - } - } else { - System.out.println("Option: OK " - + (option == JOptionPane.OK_OPTION)); - System.out - .println("Option: jsonFileName " + (jsonFileName != null)); - System.out.println("Option: computingAccount " - + (hpcAccount != null)); + this.graphEditable = graphEditable; } - } - private JComponent buildHpcJsonChooserComponent(final TetradDesktop desktop) { - final HpcAccountManager hpcAccountManager = desktop.getHpcAccountManager(); - final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - // Get ComputingAccount from DB - final DefaultListModel listModel = new DefaultListModel(); + @Override + public void actionPerformed(ActionEvent e) { + TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); + final HpcAccountManager hpcAccountManager = desktop.getHpcAccountManager(); + final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); + + JComponent comp = buildHpcJsonChooserComponent(desktop); + int option = JOptionPane.showConfirmDialog(JOptionUtils.centeringComp(), comp, + "High-Performance Computing Account Json Results Chooser", JOptionPane.OK_CANCEL_OPTION, + JOptionPane.PLAIN_MESSAGE); + + if (option == JOptionPane.OK_OPTION && jsonFileName != null && hpcAccount != null) { + + try { + HpcAccountService hpcAccountService = hpcJobManager.getHpcAccountService(hpcAccount); + + ResultService resultService = hpcAccountService.getResultService(); + + String json = resultService.downloadAlgorithmResultFile(jsonFileName, + HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); - for (HpcAccount account : hpcAccountManager.getHpcAccounts()) { - listModel.addElement(account); + Graph graph = JsonUtils.parseJSONObjectToTetradGraph(json); + GraphUtils.circleLayout(graph, 300, 300, 150); + graphEditable.setGraph(graph); + graphEditable.setName(jsonFileName); + } catch (Exception e1) { + // TODO Auto-generated catch block + e1.printStackTrace(); + } + } else { + System.out.println("Option: OK " + (option == JOptionPane.OK_OPTION)); + System.out.println("Option: jsonFileName " + (jsonFileName != null)); + System.out.println("Option: computingAccount " + (hpcAccount != null)); + } } - // JSplitPane - final JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT); - - // Left pane -> JList (parent pane) - JPanel leftPanel = new JPanel(new BorderLayout()); - - // Right pane -> ComputingAccountResultList - final JPanel jsonResultListPanel = new JPanel(new BorderLayout()); - - int minWidth = 800; - int minHeight = 600; - int screenWidth = Toolkit.getDefaultToolkit().getScreenSize().width; - int screenHeight = Toolkit.getDefaultToolkit().getScreenSize().height; - int frameWidth = screenWidth * 3 / 4; - int frameHeight = screenHeight * 3 / 4; - final int paneWidth = minWidth > frameWidth ? minWidth : frameWidth; - final int paneHeight = minHeight > frameHeight ? minHeight - : frameHeight; - - // JTable - final Vector columnNames = new Vector<>(); - columnNames.addElement("Name"); - columnNames.addElement("Created"); - columnNames.addElement("Last Modified"); - columnNames.addElement("Size"); - - Vector> rowData = new Vector<>(); - - final DefaultTableModel tableModel = new DefaultTableModel(rowData, - columnNames); - final JTable jsonResultTable = new JTable(tableModel); - jsonResultTable.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - - // Resize table's column width - jsonResultTable.getColumnModel().getColumn(0) - .setPreferredWidth(paneWidth * 2 / 5); - jsonResultTable.getColumnModel().getColumn(1) - .setPreferredWidth(paneWidth * 2 / 15); - jsonResultTable.getColumnModel().getColumn(2) - .setPreferredWidth(paneWidth * 2 / 15); - jsonResultTable.getColumnModel().getColumn(3) - .setPreferredWidth(paneWidth * 2 / 15); - - ListSelectionModel selectionModel = jsonResultTable.getSelectionModel(); - selectionModel.addListSelectionListener(new ListSelectionListener() { - - @Override - public void valueChanged(ListSelectionEvent e) { - int row = jsonResultTable.getSelectedRow(); - if (row >= 0) { - DefaultTableModel model = (DefaultTableModel) jsonResultTable - .getModel(); - jsonFileName = (String) model.getValueAt(row, 0); + private JComponent buildHpcJsonChooserComponent(final TetradDesktop desktop) { + final HpcAccountManager hpcAccountManager = desktop.getHpcAccountManager(); + final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); + // Get ComputingAccount from DB + final DefaultListModel listModel = new DefaultListModel(); + + for (HpcAccount account : hpcAccountManager.getHpcAccounts()) { + listModel.addElement(account); } - } - }); - - final JScrollPane scrollTablePane = new JScrollPane(jsonResultTable); - - jsonResultListPanel.add(scrollTablePane, BorderLayout.CENTER); - - splitPane.setLeftComponent(leftPanel); - splitPane.setRightComponent(jsonResultListPanel); - - // Center Panel - final JList accountList = new JList<>(listModel); - accountList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - accountList.setLayoutOrientation(JList.VERTICAL); - accountList.setSelectedIndex(-1); - accountList.addListSelectionListener(new ListSelectionListener() { - - @Override - public void valueChanged(ListSelectionEvent e) { - if (e.getValueIsAdjusting()) - return; - int selectedIndex = ((JList) e.getSource()) - .getSelectedIndex(); - // Show or remove the json list - if (selectedIndex > -1) { - jsonFileName = null; - hpcAccount = listModel.get(selectedIndex); - - TableColumnModel columnModel = jsonResultTable - .getColumnModel(); - List columnWidthList = new ArrayList<>(); - for (int i = 0; i < columnModel.getColumnCount(); i++) { - int width = columnModel.getColumn(i).getPreferredWidth(); - columnWidthList.add(width); - } - - jsonResultTable.clearSelection(); - - try { - HpcAccountService hpcAccountService = hpcJobManager - .getHpcAccountService(hpcAccount); - - ResultService resultService = hpcAccountService.getResultService(); - - Set results = resultService - .listAlgorithmResultFiles(HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); - - Vector> jsonFiles = new Vector<>(); - - for (ResultFile resultFile : results) { - if (resultFile.getName().endsWith(".json")) { - Vector rowData = new Vector<>(); - rowData.addElement(resultFile.getName()); - rowData.addElement(FilePrint - .fileTimestamp(resultFile - .getCreationTime().getTime())); - rowData.addElement(FilePrint - .fileTimestamp(resultFile - .getLastModifiedTime() - .getTime())); - rowData.addElement(FilePrint.humanReadableSize( - resultFile.getFileSize(), false)); - - jsonFiles.add(rowData); - } + + // JSplitPane + final JSplitPane splitPane = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT); + + // Left pane -> JList (parent pane) + JPanel leftPanel = new JPanel(new BorderLayout()); + + // Right pane -> ComputingAccountResultList + final JPanel jsonResultListPanel = new JPanel(new BorderLayout()); + + int minWidth = 800; + int minHeight = 600; + int screenWidth = Toolkit.getDefaultToolkit().getScreenSize().width; + int screenHeight = Toolkit.getDefaultToolkit().getScreenSize().height; + int frameWidth = screenWidth * 3 / 4; + int frameHeight = screenHeight * 3 / 4; + final int paneWidth = minWidth > frameWidth ? minWidth : frameWidth; + final int paneHeight = minHeight > frameHeight ? minHeight : frameHeight; + + // JTable + final Vector columnNames = new Vector<>(); + columnNames.addElement("Name"); + columnNames.addElement("Created"); + columnNames.addElement("Last Modified"); + columnNames.addElement("Size"); + + Vector> rowData = new Vector<>(); + + final DefaultTableModel tableModel = new LoadHpcGraphJsonTableModel(rowData, columnNames); + final JTable jsonResultTable = new JTable(tableModel); + jsonResultTable.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); + + // Resize table's column width + jsonResultTable.getColumnModel().getColumn(0).setPreferredWidth(paneWidth * 2 / 5); + jsonResultTable.getColumnModel().getColumn(1).setPreferredWidth(paneWidth * 2 / 15); + jsonResultTable.getColumnModel().getColumn(2).setPreferredWidth(paneWidth * 2 / 15); + jsonResultTable.getColumnModel().getColumn(3).setPreferredWidth(paneWidth * 2 / 15); + + ListSelectionModel selectionModel = jsonResultTable.getSelectionModel(); + selectionModel.addListSelectionListener(new ListSelectionListener() { + + @Override + public void valueChanged(ListSelectionEvent e) { + int row = jsonResultTable.getSelectedRow(); + if (row >= 0) { + DefaultTableModel model = (DefaultTableModel) jsonResultTable.getModel(); + jsonFileName = (String) model.getValueAt(row, 0); + } } + }); - tableModel.setDataVector(jsonFiles, columnNames); + final JScrollPane scrollTablePane = new JScrollPane(jsonResultTable); - } catch (Exception e1) { - // TODO Auto-generated catch block - e1.printStackTrace(); - } + jsonResultListPanel.add(scrollTablePane, BorderLayout.CENTER); - // Resize table's column width - for (int i = 0; i < columnModel.getColumnCount(); i++) { - jsonResultTable - .getColumnModel() - .getColumn(i) - .setPreferredWidth( - columnWidthList.get(i).intValue()); - } + splitPane.setLeftComponent(leftPanel); + splitPane.setRightComponent(jsonResultListPanel); - } - } - }); + // Center Panel + final JList accountList = new JList<>(listModel); + accountList.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); + accountList.setLayoutOrientation(JList.VERTICAL); + accountList.setSelectedIndex(-1); + accountList.addListSelectionListener(new ListSelectionListener() { + + @Override + public void valueChanged(ListSelectionEvent e) { + if (e.getValueIsAdjusting()) + return; + int selectedIndex = ((JList) e.getSource()).getSelectedIndex(); + // Show or remove the json list + if (selectedIndex > -1) { + jsonFileName = null; + hpcAccount = listModel.get(selectedIndex); + + TableColumnModel columnModel = jsonResultTable.getColumnModel(); + List columnWidthList = new ArrayList<>(); + for (int i = 0; i < columnModel.getColumnCount(); i++) { + int width = columnModel.getColumn(i).getPreferredWidth(); + columnWidthList.add(width); + } - // Left Panel - JScrollPane accountListScroller = new JScrollPane(accountList); - leftPanel.add(accountListScroller, BorderLayout.CENTER); + jsonResultTable.clearSelection(); - splitPane.setDividerLocation(paneWidth / 5); - accountListScroller.setPreferredSize(new Dimension(paneWidth / 5, - paneHeight)); - jsonResultListPanel.setPreferredSize(new Dimension(paneWidth * 4 / 5, - paneHeight)); + try { + HpcAccountService hpcAccountService = hpcJobManager.getHpcAccountService(hpcAccount); - return splitPane; - } + ResultService resultService = hpcAccountService.getResultService(); + + Set results = resultService.listAlgorithmResultFiles( + HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); + + Vector> jsonFiles = new Vector<>(); + + for (ResultFile resultFile : results) { + if (resultFile.getName().endsWith(".json")) { + Vector rowData = new Vector<>(); + rowData.addElement(resultFile.getName()); + rowData.addElement(FilePrint.fileTimestamp(resultFile.getCreationTime().getTime())); + rowData.addElement(FilePrint.fileTimestamp(resultFile.getLastModifiedTime().getTime())); + rowData.addElement(FilePrint.humanReadableSize(resultFile.getFileSize(), false)); + + jsonFiles.add(rowData); + } + } + + tableModel.setDataVector(jsonFiles, columnNames); + + } catch (Exception e1) { + // TODO Auto-generated catch block + e1.printStackTrace(); + } + + // Resize table's column width + for (int i = 0; i < columnModel.getColumnCount(); i++) { + jsonResultTable.getColumnModel().getColumn(i) + .setPreferredWidth(columnWidthList.get(i).intValue()); + } + + } + } + }); + + // Left Panel + JScrollPane accountListScroller = new JScrollPane(accountList); + leftPanel.add(accountListScroller, BorderLayout.CENTER); + + splitPane.setDividerLocation(paneWidth / 5); + accountListScroller.setPreferredSize(new Dimension(paneWidth / 5, paneHeight)); + jsonResultListPanel.setPreferredSize(new Dimension(paneWidth * 4 / 5, paneHeight)); + + return splitPane; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcAccountEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcAccountEditor.java index 7ea02958b4..a6e0d12db3 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcAccountEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcAccountEditor.java @@ -34,280 +34,273 @@ */ public class HpcAccountEditor extends JPanel { - private static final long serialVersionUID = -3028667139958773226L; - - private final JComponent parentComponent; - - private final DefaultListModel listModel; - - private final HpcAccountManager hpcAccountManager; - - private final HpcAccount hpcAccount; - - public HpcAccountEditor( - final JComponent parentComponent, - final DefaultListModel listModel, - final HpcAccountManager hpcAccountManager, - final HpcAccount hpcAccount) { - this.parentComponent = parentComponent; - this.listModel = listModel; - this.hpcAccountManager = hpcAccountManager; - this.hpcAccount = hpcAccount; - initiateUI(); - } - - private void initiateUI() { - setLayout(new BorderLayout()); - - // Header - JPanel headerPanel = new JPanel(new BorderLayout()); - JLabel headerLabel = new JLabel("High-Performance Computing Account"); - headerLabel.setFont(new Font(headerLabel.getFont().getName(), - Font.BOLD, 20)); - headerPanel.add(headerLabel, BorderLayout.CENTER); - JPanel spacePanel = new JPanel(); - spacePanel.setSize(300, 100); - headerPanel.add(spacePanel, BorderLayout.SOUTH); - add(headerPanel, BorderLayout.NORTH); - - // Content - Box contentBox = Box.createVerticalBox(); - - // Connection Name - Box connBox = Box.createHorizontalBox(); - JLabel connLabel = new JLabel("Connection", JLabel.TRAILING); - connLabel.setPreferredSize(new Dimension(100, 5)); - connBox.add(connLabel); - - final JTextField connField = new JTextField(20); - connField.setText(hpcAccount.getConnectionName()); - connField.addKeyListener(new KeyListener() { - - @Override - public void keyTyped(KeyEvent e) { - } - - @Override - public void keyReleased(KeyEvent e) { - hpcAccount.setConnectionName(connField.getText()); - } - - @Override - public void keyPressed(KeyEvent e) { - } - }); - connLabel.setLabelFor(connField); - connBox.add(connField); - - contentBox.add(connBox); - - // Username - Box userBox = Box.createHorizontalBox(); - JLabel userLabel = new JLabel("Username", JLabel.TRAILING); - userLabel.setPreferredSize(new Dimension(100, 5)); - userBox.add(userLabel); - - final JTextField userField = new JTextField(20); - userField.setText(hpcAccount.getUsername()); - userField.addKeyListener(new KeyListener() { - - @Override - public void keyTyped(KeyEvent e) { - } - - @Override - public void keyReleased(KeyEvent e) { - hpcAccount.setUsername(userField.getText()); - } - - @Override - public void keyPressed(KeyEvent e) { - } - }); - userLabel.setLabelFor(userField); - userBox.add(userField); - - contentBox.add(userBox); - - // Password - Box passBox = Box.createHorizontalBox(); - JLabel passLabel = new JLabel("Password", JLabel.TRAILING); - passLabel.setPreferredSize(new Dimension(100, 5)); - passBox.add(passLabel); - - final JPasswordField passField = new JPasswordField(20); - passField.setText(hpcAccount.getPassword()); - passField.addKeyListener(new KeyListener() { - - @Override - public void keyTyped(KeyEvent e) { - } - - @Override - public void keyReleased(KeyEvent e) { - hpcAccount.setPassword(new String(passField.getPassword())); - } - - @Override - public void keyPressed(KeyEvent e) { - } - }); - passLabel.setLabelFor(passField); - passBox.add(passField); - - contentBox.add(passBox); - - // Scheme - JPanel schemePanel = new JPanel(new BorderLayout()); - JLabel schemeLabel = new JLabel("Scheme", JLabel.TRAILING); - schemeLabel.setPreferredSize(new Dimension(100, 5)); - schemePanel.add(schemeLabel, BorderLayout.WEST); - - final JRadioButton httpRadioButton = new JRadioButton("http"); - final JRadioButton httpsRadioButton = new JRadioButton("https"); - if (hpcAccount.getScheme().equalsIgnoreCase("https")) { - httpsRadioButton.setSelected(true); - } else { - httpRadioButton.setSelected(true); + private static final long serialVersionUID = -3028667139958773226L; + + private final JComponent parentComponent; + + private final DefaultListModel listModel; + + private final HpcAccountManager hpcAccountManager; + + private final HpcAccount hpcAccount; + + public HpcAccountEditor(final JComponent parentComponent, final DefaultListModel listModel, + final HpcAccountManager hpcAccountManager, final HpcAccount hpcAccount) { + this.parentComponent = parentComponent; + this.listModel = listModel; + this.hpcAccountManager = hpcAccountManager; + this.hpcAccount = hpcAccount; + initiateUI(); } - ButtonGroup schemeGroup = new ButtonGroup(); - schemeGroup.add(httpRadioButton); - schemeGroup.add(httpsRadioButton); - Box schemeRadioBox = Box.createHorizontalBox(); - schemeRadioBox.add(httpRadioButton); - schemeRadioBox.add(httpsRadioButton); - schemeLabel.setLabelFor(schemeRadioBox); - ActionListener schemeActionListener = new ActionListener() { - - @Override - public void actionPerformed(ActionEvent e) { - if (httpRadioButton.isSelected()) { - hpcAccount.setScheme("http"); + + private void initiateUI() { + setLayout(new BorderLayout()); + + // Header + JPanel headerPanel = new JPanel(new BorderLayout()); + JLabel headerLabel = new JLabel("High-Performance Computing Account"); + headerLabel.setFont(new Font(headerLabel.getFont().getName(), Font.BOLD, 20)); + headerPanel.add(headerLabel, BorderLayout.CENTER); + JPanel spacePanel = new JPanel(); + spacePanel.setSize(300, 100); + headerPanel.add(spacePanel, BorderLayout.SOUTH); + add(headerPanel, BorderLayout.NORTH); + + // Content + Box contentBox = Box.createVerticalBox(); + + // Connection Name + Box connBox = Box.createHorizontalBox(); + JLabel connLabel = new JLabel("Connection", JLabel.TRAILING); + connLabel.setPreferredSize(new Dimension(100, 5)); + connBox.add(connLabel); + + final JTextField connField = new JTextField(20); + connField.setText(hpcAccount.getConnectionName()); + connField.addKeyListener(new KeyListener() { + + @Override + public void keyTyped(KeyEvent e) { + } + + @Override + public void keyReleased(KeyEvent e) { + hpcAccount.setConnectionName(connField.getText()); + } + + @Override + public void keyPressed(KeyEvent e) { + } + }); + connLabel.setLabelFor(connField); + connBox.add(connField); + + contentBox.add(connBox); + + // Username + Box userBox = Box.createHorizontalBox(); + JLabel userLabel = new JLabel("Username", JLabel.TRAILING); + userLabel.setPreferredSize(new Dimension(100, 5)); + userBox.add(userLabel); + + final JTextField userField = new JTextField(20); + userField.setText(hpcAccount.getUsername()); + userField.addKeyListener(new KeyListener() { + + @Override + public void keyTyped(KeyEvent e) { + } + + @Override + public void keyReleased(KeyEvent e) { + hpcAccount.setUsername(userField.getText()); + } + + @Override + public void keyPressed(KeyEvent e) { + } + }); + userLabel.setLabelFor(userField); + userBox.add(userField); + + contentBox.add(userBox); + + // Password + Box passBox = Box.createHorizontalBox(); + JLabel passLabel = new JLabel("Password", JLabel.TRAILING); + passLabel.setPreferredSize(new Dimension(100, 5)); + passBox.add(passLabel); + + final JPasswordField passField = new JPasswordField(20); + passField.setText(hpcAccount.getPassword()); + passField.addKeyListener(new KeyListener() { + + @Override + public void keyTyped(KeyEvent e) { + } + + @Override + public void keyReleased(KeyEvent e) { + hpcAccount.setPassword(new String(passField.getPassword())); + } + + @Override + public void keyPressed(KeyEvent e) { + } + }); + passLabel.setLabelFor(passField); + passBox.add(passField); + + contentBox.add(passBox); + + // Scheme + JPanel schemePanel = new JPanel(new BorderLayout()); + JLabel schemeLabel = new JLabel("Scheme", JLabel.TRAILING); + schemeLabel.setPreferredSize(new Dimension(100, 5)); + schemePanel.add(schemeLabel, BorderLayout.WEST); + + final JRadioButton httpRadioButton = new JRadioButton("http"); + final JRadioButton httpsRadioButton = new JRadioButton("https"); + if (hpcAccount.getScheme().equalsIgnoreCase("https")) { + httpsRadioButton.setSelected(true); } else { - hpcAccount.setScheme("https"); - } - } - }; - httpRadioButton.addActionListener(schemeActionListener); - httpsRadioButton.addActionListener(schemeActionListener); - schemePanel.add(schemeRadioBox, BorderLayout.CENTER); - - contentBox.add(schemePanel); - - // Host - Box hostBox = Box.createHorizontalBox(); - JLabel hostLabel = new JLabel("Host Name", JLabel.TRAILING); - hostLabel.setPreferredSize(new Dimension(100, 5)); - hostBox.add(hostLabel); - - final JTextField hostField = new JTextField(20); - hostField.setText(hpcAccount.getHostname()); - hostField.addKeyListener(new KeyListener() { - - @Override - public void keyTyped(KeyEvent e) { - } - - @Override - public void keyReleased(KeyEvent e) { - hpcAccount.setHostname(hostField.getText()); - } - - @Override - public void keyPressed(KeyEvent e) { - } - }); - hostLabel.setLabelFor(hostField); - hostBox.add(hostField); - - contentBox.add(hostBox); - - // Port number - Box portBox = Box.createHorizontalBox(); - JLabel portLabel = new JLabel("Port Number", JLabel.TRAILING); - portLabel.setPreferredSize(new Dimension(100, 5)); - portBox.add(portLabel); - - final JTextField portField = new JTextField(20); - portField.setText(String.valueOf(hpcAccount.getPort())); - portField.addKeyListener(new KeyListener() { - - @Override - public void keyTyped(KeyEvent e) { - } - - @Override - public void keyReleased(KeyEvent e) { - try { - int port = Integer.parseInt(portField.getText()); - hpcAccount.setPort(port); - } catch (NumberFormatException e1) { - // TODO Auto-generated catch block - if (portField.getText().trim().length() > 0) { - JOptionPane.showMessageDialog(portField, - "Port number is decimal number only!"); - } + httpRadioButton.setSelected(true); } - } - - @Override - public void keyPressed(KeyEvent e) { - } - }); - portLabel.setLabelFor(portField); - portBox.add(portField); - - contentBox.add(portBox); - - JPanel contentPanel = new JPanel(new BorderLayout()); - contentPanel.add(contentBox, BorderLayout.NORTH); - add(contentPanel, BorderLayout.CENTER); - - // Footer -> Test and Save buttons - JPanel footerPanel = new JPanel(new BorderLayout()); - final JButton testConnButton = new JButton("Test Connection"); - testConnButton.addActionListener(new ActionListener() { - - @Override - public void actionPerformed(ActionEvent e) { - JButton button = (JButton) e.getSource(); - button.setText("Testing..."); - parentComponent.updateUI(); - button.setEnabled(false); - boolean success = HpcAccountUtils - .testConnection(hpcAccountManager, hpcAccount); - // Pop-up the test result - JOptionPane.showMessageDialog(null, "" - + hpcAccount + " Connection " - + (success ? "Successful" : "Failed"), - "HPC Account Setting", JOptionPane.INFORMATION_MESSAGE); - button.setEnabled(true); - button.setText("Test Connection"); - hpcAccount.setLastLoginDate(new Date()); - hpcAccountManager.saveAccount(hpcAccount); - } - }); - JButton saveButton = new JButton("Save"); - saveButton.addActionListener(new ActionListener() { - - @Override - public void actionPerformed(ActionEvent e) { - JButton button = (JButton) e.getSource(); - button.setText("Saving..."); - parentComponent.updateUI(); - button.setEnabled(false); - hpcAccountManager.saveAccount(hpcAccount); - if (listModel.indexOf(hpcAccount) == -1) { - listModel.addElement(hpcAccount); - } - button.setEnabled(true); - button.setText("Save"); - parentComponent.updateUI(); - } - }); - footerPanel.add(testConnButton, BorderLayout.WEST); - footerPanel.add(saveButton, BorderLayout.EAST); - add(footerPanel, BorderLayout.SOUTH); - } + ButtonGroup schemeGroup = new ButtonGroup(); + schemeGroup.add(httpRadioButton); + schemeGroup.add(httpsRadioButton); + Box schemeRadioBox = Box.createHorizontalBox(); + schemeRadioBox.add(httpRadioButton); + schemeRadioBox.add(httpsRadioButton); + schemeLabel.setLabelFor(schemeRadioBox); + ActionListener schemeActionListener = new ActionListener() { + + @Override + public void actionPerformed(ActionEvent e) { + if (httpRadioButton.isSelected()) { + hpcAccount.setScheme("http"); + } else { + hpcAccount.setScheme("https"); + } + } + }; + httpRadioButton.addActionListener(schemeActionListener); + httpsRadioButton.addActionListener(schemeActionListener); + schemePanel.add(schemeRadioBox, BorderLayout.CENTER); + + contentBox.add(schemePanel); + + // Host + Box hostBox = Box.createHorizontalBox(); + JLabel hostLabel = new JLabel("Host Name", JLabel.TRAILING); + hostLabel.setPreferredSize(new Dimension(100, 5)); + hostBox.add(hostLabel); + + final JTextField hostField = new JTextField(20); + hostField.setText(hpcAccount.getHostname()); + hostField.addKeyListener(new KeyListener() { + + @Override + public void keyTyped(KeyEvent e) { + } + + @Override + public void keyReleased(KeyEvent e) { + hpcAccount.setHostname(hostField.getText()); + } + + @Override + public void keyPressed(KeyEvent e) { + } + }); + hostLabel.setLabelFor(hostField); + hostBox.add(hostField); + + contentBox.add(hostBox); + + // Port number + Box portBox = Box.createHorizontalBox(); + JLabel portLabel = new JLabel("Port Number", JLabel.TRAILING); + portLabel.setPreferredSize(new Dimension(100, 5)); + portBox.add(portLabel); + + final JTextField portField = new JTextField(20); + portField.setText(String.valueOf(hpcAccount.getPort())); + portField.addKeyListener(new KeyListener() { + + @Override + public void keyTyped(KeyEvent e) { + } + + @Override + public void keyReleased(KeyEvent e) { + try { + int port = Integer.parseInt(portField.getText()); + hpcAccount.setPort(port); + } catch (NumberFormatException e1) { + // TODO Auto-generated catch block + if (portField.getText().trim().length() > 0) { + JOptionPane.showMessageDialog(portField, "Port number is decimal number only!"); + } + } + } + + @Override + public void keyPressed(KeyEvent e) { + } + }); + portLabel.setLabelFor(portField); + portBox.add(portField); + + contentBox.add(portBox); + + JPanel contentPanel = new JPanel(new BorderLayout()); + contentPanel.add(contentBox, BorderLayout.NORTH); + add(contentPanel, BorderLayout.CENTER); + + // Footer -> Test and Save buttons + JPanel footerPanel = new JPanel(new BorderLayout()); + final JButton testConnButton = new JButton("Test Connection"); + testConnButton.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent e) { + JButton button = (JButton) e.getSource(); + button.setText("Testing..."); + parentComponent.updateUI(); + button.setEnabled(false); + boolean success = HpcAccountUtils.testConnection(hpcAccountManager, hpcAccount); + // Pop-up the test result + JOptionPane.showMessageDialog(null, + "" + hpcAccount + " Connection " + (success ? "Successful" : "Failed"), "HPC Account Setting", + JOptionPane.INFORMATION_MESSAGE); + button.setEnabled(true); + button.setText("Test Connection"); + hpcAccount.setLastLoginDate(new Date()); + hpcAccountManager.saveAccount(hpcAccount); + } + }); + JButton saveButton = new JButton("Save"); + saveButton.addActionListener(new ActionListener() { + + @Override + public void actionPerformed(ActionEvent e) { + JButton button = (JButton) e.getSource(); + button.setText("Saving..."); + parentComponent.updateUI(); + button.setEnabled(false); + hpcAccountManager.saveAccount(hpcAccount); + if (listModel.indexOf(hpcAccount) == -1) { + listModel.addElement(hpcAccount); + } + button.setEnabled(true); + button.setText("Save"); + parentComponent.updateUI(); + } + }); + footerPanel.add(testConnButton, BorderLayout.WEST); + footerPanel.add(saveButton, BorderLayout.EAST); + add(footerPanel, BorderLayout.SOUTH); + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcJobActivityEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcJobActivityEditor.java index 26d5c2de31..00fdc1a6ec 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcJobActivityEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcJobActivityEditor.java @@ -53,671 +53,621 @@ */ public class HpcJobActivityEditor extends JPanel implements FinalizingEditor { - private static final long serialVersionUID = -6178713484456753741L; - - private final List checkedHpcAccountList; - - private final Set pendingDisplayHpcJobInfoSet; - - private final Set submittedDisplayHpcJobInfoSet; - - private final Timer pendingTimer; - - private final Timer submittedTimer; - - private int PENDING_TIME_INTERVAL = 100; - - private int SUBMITTED_TIME_INTERVAL = 1000; - - private JTable jobsTable; - - private JTabbedPane tabbedPane; - - private PendingHpcJobUpdaterTask pendingJobUpdater; - - private SubmittedHpcJobUpdaterTask submittedJobUpdater; - - public final static int ID_COLUMN = 0; - public final static int STATUS_COLUMN = 1; - public final static int DATA_UPLOAD_COLUMN = 5; - public final static int KNOWLEDGE_UPLOAD_COLUMN = 6; - public final static int ACTIVE_SUBMITTED_COLUMN = 7; - public final static int ACTIVE_HPC_JOB_ID_COLUMN = 8; - public final static int ACTIVE_LAST_UPDATED_COLUMN = 9; - public final static int KILL_BUTTON_COLUMN = 10; - - public final static int DELETE_BUTTON_COLUMN = 11; - - public HpcJobActivityEditor() throws Exception{ - checkedHpcAccountList = new ArrayList<>(); - pendingDisplayHpcJobInfoSet = new HashSet<>(); - submittedDisplayHpcJobInfoSet = new HashSet<>(); - this.pendingTimer = new Timer(); - this.submittedTimer = new Timer(); - TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); - buildHpcJobActivityComponent(desktop); - } - - private void buildHpcJobActivityComponent(final TetradDesktop desktop) - throws Exception { - setLayout(new BorderLayout()); - - final JPanel controllerPane = new JPanel(new BorderLayout()); - add(controllerPane, BorderLayout.NORTH); - Dimension preferredSize = new Dimension(100, 100); - controllerPane.setPreferredSize(preferredSize); - buildController(controllerPane, desktop); - - final JPanel contentPanel = new JPanel(new BorderLayout()); - add(contentPanel, BorderLayout.CENTER); - buildActivityContent(contentPanel, desktop); - - int minWidth = 800; - int minHeight = 600; - int screenWidth = Toolkit.getDefaultToolkit().getScreenSize().width; - int screenHeight = Toolkit.getDefaultToolkit().getScreenSize().height; - int frameWidth = screenWidth * 5 / 6; - int frameHeight = screenHeight * 3 / 4; - final int paneWidth = minWidth > frameWidth ? minWidth : frameWidth; - final int paneHeight = minHeight > frameHeight ? minHeight - : frameHeight; - - setPreferredSize(new Dimension(paneWidth, paneHeight)); - } - - private class HpcAccountSelectionAction implements ActionListener { - - private final List hpcAccounts; - - public HpcAccountSelectionAction(final List hpcAccounts) { - this.hpcAccounts = hpcAccounts; + private static final long serialVersionUID = -6178713484456753741L; + + private final List checkedHpcAccountList; + + private final Set pendingDisplayHpcJobInfoSet; + + private final Set submittedDisplayHpcJobInfoSet; + + private final Timer pendingTimer; + + private final Timer submittedTimer; + + private int PENDING_TIME_INTERVAL = 100; + + private int SUBMITTED_TIME_INTERVAL = 1000; + + private JTable jobsTable; + + private JTabbedPane tabbedPane; + + private PendingHpcJobUpdaterTask pendingJobUpdater; + + private SubmittedHpcJobUpdaterTask submittedJobUpdater; + + public final static int ID_COLUMN = 0; + public final static int STATUS_COLUMN = 1; + public final static int DATA_UPLOAD_COLUMN = 5; + public final static int KNOWLEDGE_UPLOAD_COLUMN = 6; + public final static int ACTIVE_SUBMITTED_COLUMN = 7; + public final static int ACTIVE_HPC_JOB_ID_COLUMN = 8; + public final static int ACTIVE_LAST_UPDATED_COLUMN = 9; + public final static int KILL_BUTTON_COLUMN = 10; + + public final static int DELETE_BUTTON_COLUMN = 11; + + public HpcJobActivityEditor() throws Exception { + checkedHpcAccountList = new ArrayList<>(); + pendingDisplayHpcJobInfoSet = new HashSet<>(); + submittedDisplayHpcJobInfoSet = new HashSet<>(); + this.pendingTimer = new Timer(); + this.submittedTimer = new Timer(); + TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); + buildHpcJobActivityComponent(desktop); } - @Override - public void actionPerformed(ActionEvent e) { - final JCheckBox checkBox = (JCheckBox) e.getSource(); - for (HpcAccount hpcAccount : hpcAccounts) { - if (checkBox.getText().equals(hpcAccount.getConnectionName())) { - if (checkBox.isSelected() - && !checkedHpcAccountList.contains(hpcAccount)) { - checkedHpcAccountList.add(hpcAccount); - } else if (!checkBox.isSelected() - && checkedHpcAccountList.contains(hpcAccount)) { - checkedHpcAccountList.remove(hpcAccount); - } + private void buildHpcJobActivityComponent(final TetradDesktop desktop) throws Exception { + setLayout(new BorderLayout()); + + final JPanel controllerPane = new JPanel(new BorderLayout()); + add(controllerPane, BorderLayout.NORTH); + Dimension preferredSize = new Dimension(100, 100); + controllerPane.setPreferredSize(preferredSize); + buildController(controllerPane, desktop); + + final JPanel contentPanel = new JPanel(new BorderLayout()); + add(contentPanel, BorderLayout.CENTER); + buildActivityContent(contentPanel, desktop); + + int minWidth = 800; + int minHeight = 600; + int screenWidth = Toolkit.getDefaultToolkit().getScreenSize().width; + int screenHeight = Toolkit.getDefaultToolkit().getScreenSize().height; + int frameWidth = screenWidth * 5 / 6; + int frameHeight = screenHeight * 3 / 4; + final int paneWidth = minWidth > frameWidth ? minWidth : frameWidth; + final int paneHeight = minHeight > frameHeight ? minHeight : frameHeight; + + setPreferredSize(new Dimension(paneWidth, paneHeight)); + } + + private class HpcAccountSelectionAction implements ActionListener { + + private final List hpcAccounts; + + public HpcAccountSelectionAction(final List hpcAccounts) { + this.hpcAccounts = hpcAccounts; + } + + @Override + public void actionPerformed(ActionEvent e) { + final JCheckBox checkBox = (JCheckBox) e.getSource(); + for (HpcAccount hpcAccount : hpcAccounts) { + if (checkBox.getText().equals(hpcAccount.getConnectionName())) { + if (checkBox.isSelected() && !checkedHpcAccountList.contains(hpcAccount)) { + checkedHpcAccountList.add(hpcAccount); + } else if (!checkBox.isSelected() && checkedHpcAccountList.contains(hpcAccount)) { + checkedHpcAccountList.remove(hpcAccount); + } + } + } + int index = tabbedPane.getSelectedIndex(); + tabbedPane.setSelectedIndex(-1); + tabbedPane.setSelectedIndex(index); } - } - int index = tabbedPane.getSelectedIndex(); - tabbedPane.setSelectedIndex(-1); - tabbedPane.setSelectedIndex(index); + } - } + private void buildController(final JPanel controllerPane, final TetradDesktop desktop) { + // Content + Box contentBox = Box.createVerticalBox(); + + JPanel hpcPanel = new JPanel(new BorderLayout()); - private void buildController(final JPanel controllerPane, - final TetradDesktop desktop) { - // Content - Box contentBox = Box.createVerticalBox(); + JLabel hpcAccountLabel = new JLabel("HPC Account: ", JLabel.TRAILING); + hpcAccountLabel.setPreferredSize(new Dimension(100, 5)); + hpcPanel.add(hpcAccountLabel, BorderLayout.WEST); - JPanel hpcPanel = new JPanel(new BorderLayout()); + Box hpcAccountCheckBox = Box.createHorizontalBox(); + final HpcAccountManager hpcAccountManager = desktop.getHpcAccountManager(); + List hpcAccounts = hpcAccountManager.getHpcAccounts(); - JLabel hpcAccountLabel = new JLabel("HPC Account: ", JLabel.TRAILING); - hpcAccountLabel.setPreferredSize(new Dimension(100, 5)); - hpcPanel.add(hpcAccountLabel, BorderLayout.WEST); + HpcAccountSelectionAction hpcAccountSelectionAction = new HpcAccountSelectionAction(hpcAccounts); - Box hpcAccountCheckBox = Box.createHorizontalBox(); - final HpcAccountManager hpcAccountManager = desktop - .getHpcAccountManager(); - List hpcAccounts = hpcAccountManager.getHpcAccounts(); + for (HpcAccount hpcAccount : hpcAccounts) { + checkedHpcAccountList.add(hpcAccount); + final JCheckBox hpcCheckBox = new JCheckBox(hpcAccount.getConnectionName(), true); + hpcCheckBox.addActionListener(hpcAccountSelectionAction); + hpcAccountCheckBox.add(hpcCheckBox); - HpcAccountSelectionAction hpcAccountSelectionAction = new HpcAccountSelectionAction( - hpcAccounts); + } + hpcPanel.add(hpcAccountCheckBox, BorderLayout.CENTER); - for (HpcAccount hpcAccount : hpcAccounts) { - checkedHpcAccountList.add(hpcAccount); - final JCheckBox hpcCheckBox = new JCheckBox( - hpcAccount.getConnectionName(), true); - hpcCheckBox.addActionListener(hpcAccountSelectionAction); - hpcAccountCheckBox.add(hpcCheckBox); + contentBox.add(hpcPanel); + controllerPane.add(contentBox, BorderLayout.CENTER); } - hpcPanel.add(hpcAccountCheckBox, BorderLayout.CENTER); - contentBox.add(hpcPanel); + private void buildActivityContent(final JPanel activityPanel, final TetradDesktop desktop) throws Exception { - controllerPane.add(contentBox, BorderLayout.CENTER); - } + jobsTable = new JTable(); + jobsTable.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); - private void buildActivityContent(final JPanel activityPanel, - final TetradDesktop desktop) throws Exception { + final JScrollPane scrollTablePane = new JScrollPane(jobsTable); - jobsTable = new JTable(); - jobsTable.setSelectionMode(ListSelectionModel.SINGLE_SELECTION); + tabbedPane = new JTabbedPane(); - final JScrollPane scrollTablePane = new JScrollPane(jobsTable); + JPanel activeJobsPanel = new JPanel(new BorderLayout()); + activeJobsPanel.add(scrollTablePane, BorderLayout.CENTER); + tabbedPane.add("Active Jobs", activeJobsPanel); - tabbedPane = new JTabbedPane(); + final KillHpcJobAction killJobAction = new KillHpcJobAction(this); - JPanel activeJobsPanel = new JPanel(new BorderLayout()); - activeJobsPanel.add(scrollTablePane, BorderLayout.CENTER); - tabbedPane.add("Active Jobs", activeJobsPanel); + JPanel finishedJobsPanel = new JPanel(new BorderLayout()); - final KillHpcJobAction killJobAction = new KillHpcJobAction( - this); + tabbedPane.add("Finished Jobs", finishedJobsPanel); - JPanel finishedJobsPanel = new JPanel(new BorderLayout()); + final DeleteHpcJobInfoAction deleteJobAction = new DeleteHpcJobInfoAction(this); - tabbedPane.add("Finished Jobs", finishedJobsPanel); + ChangeListener changeListener = new ChangeListener() { - final DeleteHpcJobInfoAction deleteJobAction = new DeleteHpcJobInfoAction( - this); + @Override + public void stateChanged(ChangeEvent e) { + JTabbedPane sourceTabbedPane = (JTabbedPane) e.getSource(); + int index = sourceTabbedPane.getSelectedIndex(); + if (index == 0) { + finishedJobsPanel.remove(scrollTablePane); + activeJobsPanel.add(scrollTablePane, BorderLayout.CENTER); + try { + final Vector activeColumnNames = genActiveJobColumnNames(); + final Vector> activeRowData = getActiveRowData(desktop, checkedHpcAccountList); - ChangeListener changeListener = new ChangeListener() { + final DefaultTableModel activeJobTableModel = new HpcJobInfoTableModel(activeRowData, + activeColumnNames, KILL_BUTTON_COLUMN); - @Override - public void stateChanged(ChangeEvent e) { - JTabbedPane sourceTabbedPane = (JTabbedPane) e.getSource(); - int index = sourceTabbedPane.getSelectedIndex(); - if (index == 0) { - finishedJobsPanel.remove(scrollTablePane); - activeJobsPanel.add(scrollTablePane, BorderLayout.CENTER); - try { - final Vector activeColumnNames = genActiveJobColumnNames(); - final Vector> activeRowData = getActiveRowData( - desktop, checkedHpcAccountList); + jobsTable.setModel(activeJobTableModel); - final DefaultTableModel activeJobTableModel = new HpcJobInfoTableModel( - activeRowData, activeColumnNames, - KILL_BUTTON_COLUMN); + if (activeRowData.size() > 0) { + new ButtonColumn(jobsTable, killJobAction, KILL_BUTTON_COLUMN); + } - jobsTable.setModel(activeJobTableModel); + adjustActiveJobsWidthColumns(jobsTable); + jobsTable.updateUI(); + } catch (Exception e1) { + // TODO Auto-generated catch block + e1.printStackTrace(); + } + } else if (index == 1) { + activeJobsPanel.remove(scrollTablePane); + finishedJobsPanel.add(scrollTablePane, BorderLayout.CENTER); + try { + final Vector finishedColumnNames = genFinishedJobColumnNames(); + final Vector> finishedRowData = getFinishedRowData(desktop, + checkedHpcAccountList); - if (activeRowData.size() > 0) { - new ButtonColumn(jobsTable, killJobAction, - KILL_BUTTON_COLUMN); - } + final DefaultTableModel finishedJobTableModel = new HpcJobInfoTableModel(finishedRowData, + finishedColumnNames, DELETE_BUTTON_COLUMN); - adjustActiveJobsWidthColumns(jobsTable); - jobsTable.updateUI(); - } catch (Exception e1) { - // TODO Auto-generated catch block - e1.printStackTrace(); - } - } else if (index == 1) { - activeJobsPanel.remove(scrollTablePane); - finishedJobsPanel.add(scrollTablePane, BorderLayout.CENTER); - try { - final Vector finishedColumnNames = genFinishedJobColumnNames(); - final Vector> finishedRowData = getFinishedRowData( - desktop, checkedHpcAccountList); - - final DefaultTableModel finishedJobTableModel = new HpcJobInfoTableModel( - finishedRowData, finishedColumnNames, - DELETE_BUTTON_COLUMN); - - jobsTable.setModel(finishedJobTableModel); - - if (finishedRowData.size() > 0) { - new ButtonColumn(jobsTable, deleteJobAction, - DELETE_BUTTON_COLUMN); + jobsTable.setModel(finishedJobTableModel); + + if (finishedRowData.size() > 0) { + new ButtonColumn(jobsTable, deleteJobAction, DELETE_BUTTON_COLUMN); + } + adjustFinishedJobsWidthColumns(jobsTable); + jobsTable.updateUI(); + } catch (Exception e1) { + // TODO Auto-generated catch block + e1.printStackTrace(); + } + + } } - adjustFinishedJobsWidthColumns(jobsTable); - jobsTable.updateUI(); - } catch (Exception e1) { - // TODO Auto-generated catch block - e1.printStackTrace(); - } - } - } - - }; - tabbedPane.addChangeListener(changeListener); - - activityPanel.add(tabbedPane, BorderLayout.CENTER); - - - final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - - // Start active job updater - pendingJobUpdater = new PendingHpcJobUpdaterTask( - hpcJobManager, this); - - submittedJobUpdater = new SubmittedHpcJobUpdaterTask( - hpcJobManager, this); - - tabbedPane.setSelectedIndex(-1); - tabbedPane.setSelectedIndex(0); - - startUpdaters(); - } - - private void startUpdaters(){ - pendingTimer.schedule(pendingJobUpdater, 0, PENDING_TIME_INTERVAL); - submittedTimer.schedule(submittedJobUpdater, 0, SUBMITTED_TIME_INTERVAL); - } - - private void stopUpdaters(){ - pendingTimer.cancel(); - submittedTimer.cancel(); - } - - private void adjustActiveJobsWidthColumns(final JTable jobsTable) { - jobsTable.getColumnModel().getColumn(0).setPreferredWidth(20); - jobsTable.getColumnModel().getColumn(1).setPreferredWidth(30); - jobsTable.getColumnModel().getColumn(3).setPreferredWidth(20); - jobsTable.getColumnModel().getColumn(4).setPreferredWidth(40); - jobsTable.getColumnModel().getColumn(8).setPreferredWidth(35); - } - - private void adjustFinishedJobsWidthColumns(final JTable jobsTable) { - jobsTable.getColumnModel().getColumn(0).setPreferredWidth(20); - jobsTable.getColumnModel().getColumn(1).setPreferredWidth(30); - jobsTable.getColumnModel().getColumn(3).setPreferredWidth(20); - jobsTable.getColumnModel().getColumn(4).setPreferredWidth(40); - jobsTable.getColumnModel().getColumn(6).setPreferredWidth(35); - } - - private Vector genActiveJobColumnNames() { - final Vector columnNames = new Vector<>(); - - columnNames.addElement("Job ID"); - columnNames.addElement("Status"); - columnNames.addElement("Added"); - columnNames.addElement("HPC"); - columnNames.addElement("Algorithm"); - columnNames.addElement("Data Upload"); - columnNames.addElement("Knowledge Upload"); - columnNames.addElement("Submitted"); - columnNames.addElement("HPC Job ID"); - columnNames.addElement("lastUpdated"); - columnNames.addElement(""); - - return columnNames; - } - - private Vector genFinishedJobColumnNames() { - final Vector columnNames = new Vector<>(); - - columnNames.addElement("Job ID"); - columnNames.addElement("Status"); - columnNames.addElement("Added"); - columnNames.addElement("HPC"); - columnNames.addElement("Algorithm"); - columnNames.addElement("Submitted"); - columnNames.addElement("HPC Job ID"); - columnNames.addElement("Result Name"); - columnNames.addElement("Finished"); - columnNames.addElement("Canceled"); - columnNames.addElement("lastUpdated"); - columnNames.addElement(""); - - return columnNames; - } - - private Vector> getActiveRowData( - final TetradDesktop desktop, - final List exclusiveHpcAccounts) throws Exception { - final Vector> activeRowData = new Vector<>(); - final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - Map activeHpcJobInfoMap = null; - - // Pending - Map> pendingHpcJobInfoMap = hpcJobManager - .getPendingHpcJobInfoMap(); - - pendingDisplayHpcJobInfoSet.clear(); - - for (HpcAccount hpcAccount : pendingHpcJobInfoMap.keySet()) { - - if (exclusiveHpcAccounts != null - && !exclusiveHpcAccounts.contains(hpcAccount)) { - continue; - } - - Set pendingHpcJobSet = pendingHpcJobInfoMap - .get(hpcAccount); - for (HpcJobInfo hpcJobInfo : pendingHpcJobSet) { - // For monitoring purpose - pendingDisplayHpcJobInfoSet.add(hpcJobInfo); - - if (activeHpcJobInfoMap == null) { - activeHpcJobInfoMap = new HashMap<>(); - } - activeHpcJobInfoMap.put(hpcJobInfo.getId(), hpcJobInfo); - } + }; + tabbedPane.addChangeListener(changeListener); + + activityPanel.add(tabbedPane, BorderLayout.CENTER); + + final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); + + // Start active job updater + pendingJobUpdater = new PendingHpcJobUpdaterTask(hpcJobManager, this); + + submittedJobUpdater = new SubmittedHpcJobUpdaterTask(hpcJobManager, this); + + tabbedPane.setSelectedIndex(-1); + tabbedPane.setSelectedIndex(0); + + startUpdaters(); } - // Submitted - Map> submittedHpcJobInfoMap = hpcJobManager - .getSubmittedHpcJobInfoMap(); - - submittedDisplayHpcJobInfoSet.clear(); - - for (HpcAccount hpcAccount : submittedHpcJobInfoMap.keySet()) { - - if (exclusiveHpcAccounts != null - && !exclusiveHpcAccounts.contains(hpcAccount)) { - continue; - } - - Set submittedHpcJobSet = submittedHpcJobInfoMap - .get(hpcAccount); - for (HpcJobInfo hpcJobInfo : submittedHpcJobSet) { - // For monitoring purpose - submittedDisplayHpcJobInfoSet.add(hpcJobInfo); - - if (activeHpcJobInfoMap == null) { - activeHpcJobInfoMap = new HashMap<>(); - } - activeHpcJobInfoMap.put(hpcJobInfo.getId(), hpcJobInfo); - } + private void startUpdaters() { + pendingTimer.schedule(pendingJobUpdater, 0, PENDING_TIME_INTERVAL); + submittedTimer.schedule(submittedJobUpdater, 0, SUBMITTED_TIME_INTERVAL); } - if (activeHpcJobInfoMap != null) { + private void stopUpdaters() { + pendingTimer.cancel(); + submittedTimer.cancel(); + } - List activeJobIds = new ArrayList<>( - activeHpcJobInfoMap.keySet()); + private void adjustActiveJobsWidthColumns(final JTable jobsTable) { + jobsTable.getColumnModel().getColumn(0).setPreferredWidth(20); + jobsTable.getColumnModel().getColumn(1).setPreferredWidth(30); + jobsTable.getColumnModel().getColumn(3).setPreferredWidth(20); + jobsTable.getColumnModel().getColumn(4).setPreferredWidth(40); + jobsTable.getColumnModel().getColumn(8).setPreferredWidth(35); + } - Collections.sort(activeJobIds); - Collections.reverse(activeJobIds); + private void adjustFinishedJobsWidthColumns(final JTable jobsTable) { + jobsTable.getColumnModel().getColumn(0).setPreferredWidth(20); + jobsTable.getColumnModel().getColumn(1).setPreferredWidth(30); + jobsTable.getColumnModel().getColumn(3).setPreferredWidth(20); + jobsTable.getColumnModel().getColumn(4).setPreferredWidth(40); + jobsTable.getColumnModel().getColumn(6).setPreferredWidth(35); + } - for (Long jobId : activeJobIds) { + private Vector genActiveJobColumnNames() { + final Vector columnNames = new Vector<>(); + + columnNames.addElement("Job ID"); + columnNames.addElement("Status"); + columnNames.addElement("Added"); + columnNames.addElement("HPC"); + columnNames.addElement("Algorithm"); + columnNames.addElement("Data Upload"); + columnNames.addElement("Knowledge Upload"); + columnNames.addElement("Submitted"); + columnNames.addElement("HPC Job ID"); + columnNames.addElement("lastUpdated"); + columnNames.addElement(""); + + return columnNames; + } - final HpcJobInfo hpcJobInfo = activeHpcJobInfoMap.get(jobId); + private Vector genFinishedJobColumnNames() { + final Vector columnNames = new Vector<>(); + + columnNames.addElement("Job ID"); + columnNames.addElement("Status"); + columnNames.addElement("Added"); + columnNames.addElement("HPC"); + columnNames.addElement("Algorithm"); + columnNames.addElement("Submitted"); + columnNames.addElement("HPC Job ID"); + columnNames.addElement("Result Name"); + columnNames.addElement("Finished"); + columnNames.addElement("Canceled"); + columnNames.addElement("lastUpdated"); + columnNames.addElement(""); + + return columnNames; + } - Vector rowData = new Vector<>(); + private Vector> getActiveRowData(final TetradDesktop desktop, + final List exclusiveHpcAccounts) throws Exception { + final Vector> activeRowData = new Vector<>(); + final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); + Map activeHpcJobInfoMap = null; - HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); + // Pending + Map> pendingHpcJobInfoMap = hpcJobManager.getPendingHpcJobInfoMap(); - // Local job id - rowData.add(hpcJobInfo.getId().toString()); + pendingDisplayHpcJobInfoSet.clear(); - int status = hpcJobInfo.getStatus(); + for (HpcAccount hpcAccount : pendingHpcJobInfoMap.keySet()) { - switch (status) { - case -1: - rowData.add("Pending"); - break; - case 0: - rowData.add("Submitted"); - break; - case 1: - rowData.add("Running"); - break; - case 2: - rowData.add("Kill Request"); - break; + if (exclusiveHpcAccounts != null && !exclusiveHpcAccounts.contains(hpcAccount)) { + continue; + } + + Set pendingHpcJobSet = pendingHpcJobInfoMap.get(hpcAccount); + for (HpcJobInfo hpcJobInfo : pendingHpcJobSet) { + // For monitoring purpose + pendingDisplayHpcJobInfoSet.add(hpcJobInfo); + + if (activeHpcJobInfoMap == null) { + activeHpcJobInfoMap = new HashMap<>(); + } + activeHpcJobInfoMap.put(hpcJobInfo.getId(), hpcJobInfo); + } } - // Locally added time - rowData.add(FilePrint.fileTimestamp(hpcJobLog.getAddedTime() - .getTime())); + // Submitted + Map> submittedHpcJobInfoMap = hpcJobManager.getSubmittedHpcJobInfoMap(); - // HPC node name - HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - rowData.add(hpcAccount.getConnectionName()); + submittedDisplayHpcJobInfoSet.clear(); - // Algorithm - rowData.add(hpcJobInfo.getAlgorithmName()); + for (HpcAccount hpcAccount : submittedHpcJobInfoMap.keySet()) { - // Dataset uploading progress - AlgorithmParamRequest algorParamReq = hpcJobInfo - .getAlgorithmParamRequest(); - String datasetPath = algorParamReq.getDatasetPath(); + if (exclusiveHpcAccounts != null && !exclusiveHpcAccounts.contains(hpcAccount)) { + continue; + } - int progress = hpcJobManager.getUploadFileProgress(datasetPath); - if (progress > -1 && progress < 100) { - rowData.add("" + progress + "%"); - } else { - rowData.add("Done"); - } + Set submittedHpcJobSet = submittedHpcJobInfoMap.get(hpcAccount); + for (HpcJobInfo hpcJobInfo : submittedHpcJobSet) { + // For monitoring purpose + submittedDisplayHpcJobInfoSet.add(hpcJobInfo); - // Prior Knowledge uploading progress - String priorKnowledgePath = algorParamReq - .getPriorKnowledgePath(); - if (priorKnowledgePath != null) { - progress = hpcJobManager - .getUploadFileProgress(priorKnowledgePath); - if (progress > -1 && progress < 100) { - rowData.add("" + progress + "%"); - } else { - rowData.add("Done"); - } - } else { - rowData.add("Skipped"); + if (activeHpcJobInfoMap == null) { + activeHpcJobInfoMap = new HashMap<>(); + } + activeHpcJobInfoMap.put(hpcJobInfo.getId(), hpcJobInfo); + } } - if (status > -1) { - // Submitted time - rowData.add(FilePrint.fileTimestamp(hpcJobInfo - .getSubmittedTime().getTime())); + if (activeHpcJobInfoMap != null) { - // HPC job id - rowData.add(hpcJobInfo.getPid() != null?"" + hpcJobInfo.getPid():""); + List activeJobIds = new ArrayList<>(activeHpcJobInfoMap.keySet()); - } else { - rowData.add(""); - rowData.add(""); - } + Collections.sort(activeJobIds); + Collections.reverse(activeJobIds); - // Last update time - rowData.add(FilePrint.fileTimestamp(hpcJobLog - .getLastUpdatedTime().getTime())); + for (Long jobId : activeJobIds) { - // Cancel job - rowData.add("Cancel"); + final HpcJobInfo hpcJobInfo = activeHpcJobInfoMap.get(jobId); - activeRowData.add(rowData); - } - } + Vector rowData = new Vector<>(); - return activeRowData; - } - - private Vector> getFinishedRowData( - final TetradDesktop desktop, - final List exclusiveHpcAccounts) throws Exception { - final Vector> finishedRowData = new Vector<>(); - HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - Map finishedHpcJobIdMap = null; - - // Finished jobs - Map> finishedHpcJobInfoMap = hpcJobManager - .getFinishedHpcJobInfoMap(); - for (HpcAccount hpcAccount : finishedHpcJobInfoMap.keySet()) { - - if (exclusiveHpcAccounts != null - && !exclusiveHpcAccounts.contains(hpcAccount)) { - continue; - } - - Set finishedHpcJobSet = finishedHpcJobInfoMap - .get(hpcAccount); - for (HpcJobInfo hpcJobInfo : finishedHpcJobSet) { - if (finishedHpcJobIdMap == null) { - finishedHpcJobIdMap = new HashMap<>(); - } - finishedHpcJobIdMap.put(hpcJobInfo.getId(), hpcJobInfo); - } - } + HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); - if (finishedHpcJobIdMap != null) { + // Local job id + rowData.add(hpcJobInfo.getId().toString()); - List finishedJobIds = new ArrayList<>( - finishedHpcJobIdMap.keySet()); + int status = hpcJobInfo.getStatus(); - Collections.sort(finishedJobIds); - Collections.reverse(finishedJobIds); + switch (status) { + case -1: + rowData.add("Pending"); + break; + case 0: + rowData.add("Submitted"); + break; + case 1: + rowData.add("Running"); + break; + case 2: + rowData.add("Kill Request"); + break; + } - for (Long jobId : finishedJobIds) { - final HpcJobInfo hpcJobInfo = finishedHpcJobIdMap.get(jobId); + // Locally added time + rowData.add(FilePrint.fileTimestamp(hpcJobLog.getAddedTime().getTime())); - Vector rowData = new Vector<>(); + // HPC node name + HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); + rowData.add(hpcAccount.getConnectionName()); - HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); + // Algorithm + rowData.add(hpcJobInfo.getAlgorithmName()); - // Local job id - rowData.add(hpcJobInfo.getId().toString()); + // Dataset uploading progress + AlgorithmParamRequest algorParamReq = hpcJobInfo.getAlgorithmParamRequest(); + String datasetPath = algorParamReq.getDatasetPath(); - int status = hpcJobInfo.getStatus(); + int progress = hpcJobManager.getUploadFileProgress(datasetPath); + if (progress > -1 && progress < 100) { + rowData.add("" + progress + "%"); + } else { + rowData.add("Done"); + } - switch (status) { - case 3: - rowData.add("Finished"); - break; - case 4: - rowData.add("Canceled"); - break; - case 5: - rowData.add("Finished"); - break; - case 6: - rowData.add("Error"); - break; - } + // Prior Knowledge uploading progress + String priorKnowledgePath = algorParamReq.getPriorKnowledgePath(); + if (priorKnowledgePath != null) { + progress = hpcJobManager.getUploadFileProgress(priorKnowledgePath); + if (progress > -1 && progress < 100) { + rowData.add("" + progress + "%"); + } else { + rowData.add("Done"); + } + } else { + rowData.add("Skipped"); + } - // Locally added time - rowData.add(FilePrint.fileTimestamp(hpcJobLog.getAddedTime() - .getTime())); - - // HPC node name - HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - rowData.add(hpcAccount.getConnectionName()); - - // Algorithm - rowData.add(hpcJobInfo.getAlgorithmName()); - - // Submitted time - rowData.add(hpcJobInfo.getSubmittedTime() != null ? FilePrint - .fileTimestamp(hpcJobInfo.getSubmittedTime().getTime()) - : ""); - - // HPC job id - rowData.add("" + hpcJobInfo.getPid()); - - // Result Name - switch (status) { - case 3: - rowData.add(hpcJobInfo.getResultFileName()); - break; - case 4: - rowData.add(""); - break; - case 5: - rowData.add(hpcJobInfo.getResultFileName()); - break; - case 6: - rowData.add(hpcJobInfo.getErrorResultFileName()); - break; - } + if (status > -1) { + // Submitted time + rowData.add(FilePrint.fileTimestamp(hpcJobInfo.getSubmittedTime().getTime())); + + // HPC job id + rowData.add(hpcJobInfo.getPid() != null ? "" + hpcJobInfo.getPid() : ""); + + } else { + rowData.add(""); + rowData.add(""); + } + + // Last update time + rowData.add(FilePrint.fileTimestamp(hpcJobLog.getLastUpdatedTime().getTime())); + + // Cancel job + rowData.add("Cancel"); - // Finished time - if (status != 4) { - rowData.add(FilePrint.fileTimestamp(hpcJobLog - .getEndedTime().getTime())); - } else { - rowData.add(""); + activeRowData.add(rowData); + } } - // Canceled time - if (status == 4) { - rowData.add(hpcJobLog.getCanceledTime() != null ? FilePrint - .fileTimestamp(hpcJobLog.getCanceledTime() - .getTime()) : ""); - } else { - rowData.add(""); + return activeRowData; + } + + private Vector> getFinishedRowData(final TetradDesktop desktop, + final List exclusiveHpcAccounts) throws Exception { + final Vector> finishedRowData = new Vector<>(); + HpcJobManager hpcJobManager = desktop.getHpcJobManager(); + Map finishedHpcJobIdMap = null; + + // Finished jobs + Map> finishedHpcJobInfoMap = hpcJobManager.getFinishedHpcJobInfoMap(); + for (HpcAccount hpcAccount : finishedHpcJobInfoMap.keySet()) { + + if (exclusiveHpcAccounts != null && !exclusiveHpcAccounts.contains(hpcAccount)) { + continue; + } + + Set finishedHpcJobSet = finishedHpcJobInfoMap.get(hpcAccount); + for (HpcJobInfo hpcJobInfo : finishedHpcJobSet) { + if (finishedHpcJobIdMap == null) { + finishedHpcJobIdMap = new HashMap<>(); + } + finishedHpcJobIdMap.put(hpcJobInfo.getId(), hpcJobInfo); + } } - // Last update time - rowData.add(FilePrint.fileTimestamp(hpcJobLog - .getLastUpdatedTime().getTime())); + if (finishedHpcJobIdMap != null) { + + List finishedJobIds = new ArrayList<>(finishedHpcJobIdMap.keySet()); + + Collections.sort(finishedJobIds); + Collections.reverse(finishedJobIds); + + for (Long jobId : finishedJobIds) { + final HpcJobInfo hpcJobInfo = finishedHpcJobIdMap.get(jobId); + + Vector rowData = new Vector<>(); + + HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); + + // Local job id + rowData.add(hpcJobInfo.getId().toString()); + + int status = hpcJobInfo.getStatus(); + + switch (status) { + case 3: + rowData.add("Finished"); + break; + case 4: + rowData.add("Canceled"); + break; + case 5: + rowData.add("Finished"); + break; + case 6: + rowData.add("Error"); + break; + } + + // Locally added time + rowData.add(FilePrint.fileTimestamp(hpcJobLog.getAddedTime().getTime())); + + // HPC node name + HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); + rowData.add(hpcAccount.getConnectionName()); + + // Algorithm + rowData.add(hpcJobInfo.getAlgorithmName()); + + // Submitted time + rowData.add(hpcJobInfo.getSubmittedTime() != null + ? FilePrint.fileTimestamp(hpcJobInfo.getSubmittedTime().getTime()) : ""); + + // HPC job id + rowData.add("" + hpcJobInfo.getPid()); + + // Result Name + switch (status) { + case 3: + rowData.add(hpcJobInfo.getResultFileName()); + break; + case 4: + rowData.add(""); + break; + case 5: + rowData.add(hpcJobInfo.getResultFileName()); + break; + case 6: + rowData.add(hpcJobInfo.getErrorResultFileName()); + break; + } + + // Finished time + if (status != 4) { + rowData.add(FilePrint.fileTimestamp(hpcJobLog.getEndedTime().getTime())); + } else { + rowData.add(""); + } + + // Canceled time + if (status == 4) { + rowData.add(hpcJobLog.getCanceledTime() != null + ? FilePrint.fileTimestamp(hpcJobLog.getCanceledTime().getTime()) : ""); + } else { + rowData.add(""); + } + + // Last update time + rowData.add(FilePrint.fileTimestamp(hpcJobLog.getLastUpdatedTime().getTime())); + + // Delete job from db + rowData.add("Delete"); + + finishedRowData.add(rowData); + } - // Delete job from db - rowData.add("Delete"); + } - finishedRowData.add(rowData); - } + return finishedRowData; + } + public synchronized Set getPendingDisplayHpcJobInfoSet() { + + return pendingDisplayHpcJobInfoSet; } - return finishedRowData; - } - - public synchronized Set getPendingDisplayHpcJobInfoSet() { - - return pendingDisplayHpcJobInfoSet; - } - - public synchronized void removePendingDisplayHpcJobInfo( - final Set removingJobSet) { - for (final HpcJobInfo hpcJobInfo : removingJobSet) { - for (HpcJobInfo pendingJob: pendingDisplayHpcJobInfoSet) { - if (hpcJobInfo.getId() == pendingJob.getId()) { - pendingDisplayHpcJobInfoSet.remove(pendingJob); - continue; + public synchronized void removePendingDisplayHpcJobInfo(final Set removingJobSet) { + for (final HpcJobInfo hpcJobInfo : removingJobSet) { + for (HpcJobInfo pendingJob : pendingDisplayHpcJobInfoSet) { + if (hpcJobInfo.getId() == pendingJob.getId()) { + pendingDisplayHpcJobInfoSet.remove(pendingJob); + continue; + } + } } - } } - } - public Set getSubmittedDisplayHpcJobInfoSet() { - return submittedDisplayHpcJobInfoSet; - } + public Set getSubmittedDisplayHpcJobInfoSet() { + return submittedDisplayHpcJobInfoSet; + } - public synchronized void addSubmittedDisplayHpcJobInfo( - final Set submittedJobSet) { - for(HpcJobInfo job : submittedJobSet){ - System.out.println("addSubmittedDisplayHpcJobInfo: job: " + job.getId()); + public synchronized void addSubmittedDisplayHpcJobInfo(final Set submittedJobSet) { + for (HpcJobInfo job : submittedJobSet) { + System.out.println("addSubmittedDisplayHpcJobInfo: job: " + job.getId()); + } + submittedDisplayHpcJobInfoSet.addAll(submittedJobSet); } - submittedDisplayHpcJobInfoSet.addAll(submittedJobSet); - } - - public synchronized void removeSubmittedDisplayHpcJobInfo( - final Set removingJobSet) { - for (final HpcJobInfo hpcJobInfo : removingJobSet) { - for (Iterator it = submittedDisplayHpcJobInfoSet - .iterator(); it.hasNext();) { - final HpcJobInfo submittedJob = it.next(); - if (hpcJobInfo.getId() == submittedJob.getId()) { - submittedDisplayHpcJobInfoSet.remove(hpcJobInfo); - continue; + + public synchronized void removeSubmittedDisplayHpcJobInfo(final Set removingJobSet) { + for (final HpcJobInfo hpcJobInfo : removingJobSet) { + for (Iterator it = submittedDisplayHpcJobInfoSet.iterator(); it.hasNext();) { + final HpcJobInfo submittedJob = it.next(); + if (hpcJobInfo.getId() == submittedJob.getId()) { + submittedDisplayHpcJobInfoSet.remove(hpcJobInfo); + continue; + } + } } - } + } - } + public synchronized void removeSubmittedDisplayJobFromActiveTableModel(final Set finishedJobSet) { + DefaultTableModel model = (DefaultTableModel) jobsTable.getModel(); + Map rowMap = new HashMap<>(); + for (int row = 0; row < model.getRowCount(); row++) { + rowMap.put(Long.valueOf(model.getValueAt(row, ID_COLUMN).toString()), row); + } + + for (final HpcJobInfo hpcJobInfo : finishedJobSet) { + if (rowMap.containsKey(hpcJobInfo.getId())) { + model.removeRow(rowMap.get(hpcJobInfo.getId())); + } + } + + } - public synchronized void removeSubmittedDisplayJobFromActiveTableModel( - final Set finishedJobSet) { - DefaultTableModel model = (DefaultTableModel) jobsTable.getModel(); - Map rowMap = new HashMap<>(); - for (int row = 0; row < model.getRowCount(); row++) { - rowMap.put( - Long.valueOf(model.getValueAt(row, - ID_COLUMN).toString()), row); + public TableModel getJobsTableModel() { + return jobsTable.getModel(); } - for (final HpcJobInfo hpcJobInfo : finishedJobSet) { - if (rowMap.containsKey(hpcJobInfo.getId())) { - model.removeRow(rowMap.get(hpcJobInfo.getId())); - } + public int selectedTabbedPaneIndex() { + return tabbedPane.getSelectedIndex(); } - } - - public TableModel getJobsTableModel() { - return jobsTable.getModel(); - } - - public int selectedTabbedPaneIndex(){ - return tabbedPane.getSelectedIndex(); - } - - @Override - public boolean finalizeEditor() { - stopUpdaters(); - return true; - } + @Override + public boolean finalizeEditor() { + stopUpdaters(); + return true; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcJobInfoTableModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcJobInfoTableModel.java index 07c5a71692..3ba8116ab2 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcJobInfoTableModel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/HpcJobInfoTableModel.java @@ -17,16 +17,16 @@ public class HpcJobInfoTableModel extends DefaultTableModel { private final int buttonColumn; - public HpcJobInfoTableModel(final Vector> activeRowData, - final Vector activeColumnNames, final int buttonColumn) { - super(activeRowData, activeColumnNames); - this.buttonColumn = buttonColumn; + public HpcJobInfoTableModel(final Vector> activeRowData, final Vector activeColumnNames, + final int buttonColumn) { + super(activeRowData, activeColumnNames); + this.buttonColumn = buttonColumn; } public boolean isCellEditable(int row, int column) { - if (column == buttonColumn) - return true; - return false; + if (column == buttonColumn) + return true; + return false; } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/LoadHpcGraphJsonTableModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/LoadHpcGraphJsonTableModel.java new file mode 100644 index 0000000000..aa9a553661 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/editor/LoadHpcGraphJsonTableModel.java @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2015 University of Pittsburgh. + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, + * MA 02110-1301 USA + */ +package edu.cmu.tetradapp.app.hpc.editor; + +import java.util.Vector; + +import javax.swing.table.DefaultTableModel; + +/** + * + * Feb 14, 2017 7:22:42 PM + * + * @author Chirayu (Kong) Wongchokprasitti + * + */ +public class LoadHpcGraphJsonTableModel extends DefaultTableModel { + + private static final long serialVersionUID = 2896909588298923241L; + + public LoadHpcGraphJsonTableModel(final Vector> rowData, final Vector columnNames) { + super(rowData, columnNames); + } + + public boolean isCellEditable(int row, int column) { + return false; + } + +} diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcAccountManager.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcAccountManager.java index 8dccceaafe..2464731812 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcAccountManager.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcAccountManager.java @@ -14,30 +14,30 @@ */ public class HpcAccountManager { - private final HpcAccountService hpcAccountService; - - private final JsonWebTokenManager jsonWebTokenManager; - - public HpcAccountManager(final org.hibernate.Session session) { - this.hpcAccountService = new HpcAccountService(session); - this.jsonWebTokenManager = new JsonWebTokenManager(); - } - - public List getHpcAccounts() { - List hpcAccounts = hpcAccountService.get(); - return hpcAccounts; - } - - public void saveAccount(final HpcAccount hpcAccount) { - hpcAccountService.update(hpcAccount); - } - - public void removeAccount(final HpcAccount hpcAccount) { - hpcAccountService.remove(hpcAccount); - } - - public JsonWebTokenManager getJsonWebTokenManager() { - return jsonWebTokenManager; - } + private final HpcAccountService hpcAccountService; + + private final JsonWebTokenManager jsonWebTokenManager; + + public HpcAccountManager(final org.hibernate.Session session) { + this.hpcAccountService = new HpcAccountService(session); + this.jsonWebTokenManager = new JsonWebTokenManager(); + } + + public List getHpcAccounts() { + List hpcAccounts = hpcAccountService.get(); + return hpcAccounts; + } + + public void saveAccount(final HpcAccount hpcAccount) { + hpcAccountService.update(hpcAccount); + } + + public void removeAccount(final HpcAccount hpcAccount) { + hpcAccountService.remove(hpcAccount); + } + + public JsonWebTokenManager getJsonWebTokenManager() { + return jsonWebTokenManager; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcAccountService.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcAccountService.java index a4ef9e3b25..556e3fe368 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcAccountService.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcAccountService.java @@ -16,54 +16,48 @@ */ public class HpcAccountService { - private final RemoteDataFileService remoteDataService; + private final RemoteDataFileService remoteDataService; - private final DataUploadService dataUploadService; + private final DataUploadService dataUploadService; - private final JobQueueService jobQueueService; + private final JobQueueService jobQueueService; - private final ResultService resultService; - - public HpcAccountService(final HpcAccount hpcAccount, - final int simultaneousUpload) throws Exception { + private final ResultService resultService; - final String username = hpcAccount.getUsername(); - final String password = hpcAccount.getPassword(); - final String scheme = hpcAccount.getScheme(); - final String hostname = hpcAccount.getHostname(); - final int port = hpcAccount.getPort(); + public HpcAccountService(final HpcAccount hpcAccount, final int simultaneousUpload) throws Exception { - RestHttpsClient restHttpsClient = new RestHttpsClient(username, password, scheme, - hostname, port); + final String username = hpcAccount.getUsername(); + final String password = hpcAccount.getPassword(); + final String scheme = hpcAccount.getScheme(); + final String hostname = hpcAccount.getHostname(); + final int port = hpcAccount.getPort(); - this.remoteDataService = new RemoteDataFileService(restHttpsClient, - scheme, hostname, port); + RestHttpsClient restHttpsClient = new RestHttpsClient(username, password, scheme, hostname, port); - this.dataUploadService = new DataUploadService(restHttpsClient, - simultaneousUpload, scheme, hostname, port); + this.remoteDataService = new RemoteDataFileService(restHttpsClient, scheme, hostname, port); - this.jobQueueService = new JobQueueService(restHttpsClient, scheme, - hostname, port); + this.dataUploadService = new DataUploadService(restHttpsClient, simultaneousUpload, scheme, hostname, port); - this.resultService = new ResultService(restHttpsClient, - scheme, hostname, port); - - } + this.jobQueueService = new JobQueueService(restHttpsClient, scheme, hostname, port); - public RemoteDataFileService getRemoteDataService() { - return remoteDataService; - } + this.resultService = new ResultService(restHttpsClient, scheme, hostname, port); - public DataUploadService getDataUploadService() { - return dataUploadService; - } + } - public JobQueueService getJobQueueService() { - return jobQueueService; - } + public RemoteDataFileService getRemoteDataService() { + return remoteDataService; + } - public ResultService getResultService() { - return resultService; - } + public DataUploadService getDataUploadService() { + return dataUploadService; + } + + public JobQueueService getJobQueueService() { + return jobQueueService; + } + + public ResultService getResultService() { + return resultService; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcJobManager.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcJobManager.java index 8ebdc5d468..d1625471ab 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcJobManager.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/HpcJobManager.java @@ -41,442 +41,399 @@ */ public class HpcJobManager { - private final HpcJobLogService hpcJobLogService; + private final HpcJobLogService hpcJobLogService; - private final HpcJobLogDetailService hpcJobLogDetailService; + private final HpcJobLogDetailService hpcJobLogDetailService; - private final HpcJobInfoService hpcJobInfoService; + private final HpcJobInfoService hpcJobInfoService; - private final int simultaneousUpload; + private final int simultaneousUpload; - private final ExecutorService executorService; + private final ExecutorService executorService; - private final Timer timer; + private final Timer timer; - private int TIME_INTERVAL = 10000; + private int TIME_INTERVAL = 10000; - private final Map uploadFileProgressMap; + private final Map uploadFileProgressMap; - private final Map hpcGraphResultMap; + private final Map hpcGraphResultMap; - private final Map> pendingHpcJobInfoMap; + private final Map> pendingHpcJobInfoMap; - private final Map> submittedHpcJobInfoMap; + private final Map> submittedHpcJobInfoMap; - private final Map hpcAccountServiceMap; + private final Map hpcAccountServiceMap; - public HpcJobManager(final org.hibernate.Session session, - final int simultaneousUpload) { - this.hpcJobLogService = new HpcJobLogService(session); - this.hpcJobLogDetailService = new HpcJobLogDetailService(session); - this.hpcJobInfoService = new HpcJobInfoService(session); - this.simultaneousUpload = simultaneousUpload; + public HpcJobManager(final org.hibernate.Session session, final int simultaneousUpload) { + this.hpcJobLogService = new HpcJobLogService(session); + this.hpcJobLogDetailService = new HpcJobLogDetailService(session); + this.hpcJobInfoService = new HpcJobInfoService(session); + this.simultaneousUpload = simultaneousUpload; - executorService = Executors.newFixedThreadPool(simultaneousUpload); + executorService = Executors.newFixedThreadPool(simultaneousUpload); - uploadFileProgressMap = new HashMap<>(); - pendingHpcJobInfoMap = new HashMap<>(); - submittedHpcJobInfoMap = new HashMap<>(); - hpcGraphResultMap = new HashMap<>(); - hpcAccountServiceMap = new HashMap<>(); + uploadFileProgressMap = new HashMap<>(); + pendingHpcJobInfoMap = new HashMap<>(); + submittedHpcJobInfoMap = new HashMap<>(); + hpcGraphResultMap = new HashMap<>(); + hpcAccountServiceMap = new HashMap<>(); - resumePreProcessJobs(); - resumeSubmittedHpcJobInfos(); + resumePreProcessJobs(); + resumeSubmittedHpcJobInfos(); - this.timer = new Timer(); + this.timer = new Timer(); - startHpcJobScheduler(); - } + startHpcJobScheduler(); + } + + public Map> getPendingHpcJobInfoMap() { + return pendingHpcJobInfoMap; + } + + private synchronized void resumePreProcessJobs() { + // Lookup on DB for HpcJobInfo with status -1 (Pending) + + List pendingHpcJobInfo = hpcJobInfoService.findByStatus(-1); + if (pendingHpcJobInfo != null) { + for (HpcJobInfo hpcJobInfo : pendingHpcJobInfo) { + System.out.println("resumePreProcessJobs: " + hpcJobInfo.getAlgorithmName() + " : " + + hpcJobInfo.getHpcAccount().getConnectionName() + " : " + + hpcJobInfo.getAlgorithmParamRequest().getDatasetPath()); + + final HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); + + Set hpcJobInfos = pendingHpcJobInfoMap.get(hpcAccount); + if (hpcJobInfos == null) { + hpcJobInfos = new LinkedHashSet<>(); + } + hpcJobInfos.add(hpcJobInfo); + pendingHpcJobInfoMap.put(hpcAccount, hpcJobInfos); + + HpcJobPreProcessTask preProcessTask = new HpcJobPreProcessTask(hpcJobInfo); + executorService.submit(preProcessTask); + } + } else { + System.out.println("resumePreProcessJobs: no pending jobs to be resumed"); + } + } + + public void startHpcJobScheduler() { + System.out.println("startHpcJobScheduler"); + HpcJobsScheduledTask hpcScheduledTask = new HpcJobsScheduledTask(); + timer.schedule(hpcScheduledTask, 1000, TIME_INTERVAL); + } + + public synchronized void submitNewHpcJobToQueue(final HpcJobInfo hpcJobInfo, + final GeneralAlgorithmEditor generalAlgorithmEditor) { + + hpcJobInfoService.add(hpcJobInfo); + System.out.println("hpcJobInfo: id: " + hpcJobInfo.getId()); + + HpcJobLog hpcJobLog = new HpcJobLog(); + hpcJobLog.setAddedTime(new Date(System.currentTimeMillis())); + hpcJobLog.setHpcJobInfo(hpcJobInfo); + hpcJobLogService.update(hpcJobLog); + System.out.println("HpcJobLog: id: " + hpcJobLog.getId()); - public Map> getPendingHpcJobInfoMap() { - return pendingHpcJobInfoMap; - } + HpcJobLogDetail hpcJobLogDetail = new HpcJobLogDetail(); + hpcJobLogDetail.setAddedTime(new Date()); + hpcJobLogDetail.setHpcJobLog(hpcJobLog); + hpcJobLogDetail.setJobState(-1);// Pending + hpcJobLogDetail.setProgress("Pending"); + hpcJobLogDetailService.add(hpcJobLogDetail); + System.out.println("HpcJobLogDetail: id: " + hpcJobLogDetail.getId()); - private synchronized void resumePreProcessJobs() { - // Lookup on DB for HpcJobInfo with status -1 (Pending) + hpcGraphResultMap.put(hpcJobInfo, generalAlgorithmEditor); - List pendingHpcJobInfo = hpcJobInfoService.findByStatus(-1); - if (pendingHpcJobInfo != null) { - for (HpcJobInfo hpcJobInfo : pendingHpcJobInfo) { - System.out.println("resumePreProcessJobs: " - + hpcJobInfo.getAlgorithmName() - + " : " - + hpcJobInfo.getHpcAccount().getConnectionName() - + " : " - + hpcJobInfo.getAlgorithmParamRequest() - .getDatasetPath()); + // Put a new pre-process task into hpc job queue + HpcJobPreProcessTask preProcessTask = new HpcJobPreProcessTask(hpcJobInfo); + // Added a job to the pending list final HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - Set hpcJobInfos = pendingHpcJobInfoMap - .get(hpcAccount); + Set hpcJobInfos = pendingHpcJobInfoMap.get(hpcAccount); if (hpcJobInfos == null) { - hpcJobInfos = new LinkedHashSet<>(); + hpcJobInfos = new LinkedHashSet<>(); } hpcJobInfos.add(hpcJobInfo); pendingHpcJobInfoMap.put(hpcAccount, hpcJobInfos); - HpcJobPreProcessTask preProcessTask = new HpcJobPreProcessTask( - hpcJobInfo); - executorService.submit(preProcessTask); - } - } else { - System.out - .println("resumePreProcessJobs: no pending jobs to be resumed"); + executorService.execute(preProcessTask); + } + + public void stopHpcJobScheduler() { + timer.cancel(); + } + + public void restartHpcJobScheduler() { + stopHpcJobScheduler(); + startHpcJobScheduler(); } - } - - public void startHpcJobScheduler() { - System.out.println("startHpcJobScheduler"); - HpcJobsScheduledTask hpcScheduledTask = new HpcJobsScheduledTask(); - timer.schedule(hpcScheduledTask, 1000, TIME_INTERVAL); - } - - public synchronized void submitNewHpcJobToQueue( - final HpcJobInfo hpcJobInfo, - final GeneralAlgorithmEditor generalAlgorithmEditor) { - - hpcJobInfoService.add(hpcJobInfo); - System.out.println("hpcJobInfo: id: " + hpcJobInfo.getId()); - - HpcJobLog hpcJobLog = new HpcJobLog(); - hpcJobLog.setAddedTime(new Date(System.currentTimeMillis())); - hpcJobLog.setHpcJobInfo(hpcJobInfo); - hpcJobLogService.update(hpcJobLog); - System.out.println("HpcJobLog: id: " + hpcJobLog.getId()); - - HpcJobLogDetail hpcJobLogDetail = new HpcJobLogDetail(); - hpcJobLogDetail.setAddedTime(new Date()); - hpcJobLogDetail.setHpcJobLog(hpcJobLog); - hpcJobLogDetail.setJobState(-1);// Pending - hpcJobLogDetail.setProgress("Pending"); - hpcJobLogDetailService.add(hpcJobLogDetail); - System.out.println("HpcJobLogDetail: id: " + hpcJobLogDetail.getId()); - - hpcGraphResultMap.put(hpcJobInfo, generalAlgorithmEditor); - - // Put a new pre-process task into hpc job queue - HpcJobPreProcessTask preProcessTask = new HpcJobPreProcessTask( - hpcJobInfo); - - // Added a job to the pending list - final HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - - Set hpcJobInfos = pendingHpcJobInfoMap.get(hpcAccount); - if (hpcJobInfos == null) { - hpcJobInfos = new LinkedHashSet<>(); + + public HpcJobLog getHpcJobLog(final HpcJobInfo hpcJobInfo) { + return hpcJobLogService.findByHpcJobInfo(hpcJobInfo); } - hpcJobInfos.add(hpcJobInfo); - pendingHpcJobInfoMap.put(hpcAccount, hpcJobInfos); - - executorService.execute(preProcessTask); - } - - public void stopHpcJobScheduler() { - timer.cancel(); - } - - public void restartHpcJobScheduler() { - stopHpcJobScheduler(); - startHpcJobScheduler(); - } - - public HpcJobLog getHpcJobLog(final HpcJobInfo hpcJobInfo) { - return hpcJobLogService.findByHpcJobInfo(hpcJobInfo); - } - - public void appendHpcJobLogDetail(final HpcJobLogDetail hpcJobLogDetail) { - hpcJobLogDetailService.add(hpcJobLogDetail); - } - - public HpcJobInfo findHpcJobInfoById(final long id) { - return hpcJobInfoService.findById(id); - } - - public void updateHpcJobInfo(final HpcJobInfo hpcJobInfo) { - hpcJobInfoService.update(hpcJobInfo); - updateSubmittedHpcJobInfo(hpcJobInfo); - } - - public void updateHpcJobLog(final HpcJobLog hpcJobLog) { - hpcJobLogService.update(hpcJobLog); - } - - public void logHpcJobLogDetail(final HpcJobLog hpcJobLog, int jobStatus, - String jobProgress) { - Date now = new Date(System.currentTimeMillis()); - hpcJobLog.setLastUpdatedTime(now); - if (jobStatus == 3) {// Finished - hpcJobLog.setEndedTime(now); + + public void appendHpcJobLogDetail(final HpcJobLogDetail hpcJobLogDetail) { + hpcJobLogDetailService.add(hpcJobLogDetail); } - if (jobStatus == 4) {// Killed - hpcJobLog.setCanceledTime(now); + + public HpcJobInfo findHpcJobInfoById(final long id) { + return hpcJobInfoService.findById(id); + } + + public void updateHpcJobInfo(final HpcJobInfo hpcJobInfo) { + hpcJobInfoService.update(hpcJobInfo); + updateSubmittedHpcJobInfo(hpcJobInfo); } - updateHpcJobLog(hpcJobLog); - - HpcJobLogDetail hpcJobLogDetail = new HpcJobLogDetail(); - hpcJobLogDetail.setAddedTime(new Date(System.currentTimeMillis())); - hpcJobLogDetail.setHpcJobLog(hpcJobLog); - hpcJobLogDetail.setJobState(jobStatus); - hpcJobLogDetail.setProgress(jobProgress); - appendHpcJobLogDetail(hpcJobLogDetail); - } - - public synchronized void updateUploadFileProgress(final String datasetPath, - int percentageProgress) { - uploadFileProgressMap.put(datasetPath, percentageProgress); - } - - public int getUploadFileProgress(final String dataPath) { - int progress = -1; - if (uploadFileProgressMap.containsKey(dataPath)) { - progress = uploadFileProgressMap.get(dataPath).intValue(); + + public void updateHpcJobLog(final HpcJobLog hpcJobLog) { + hpcJobLogService.update(hpcJobLog); } - return progress; - } - - public void resumeSubmittedHpcJobInfos() { - // Lookup on DB for HpcJobInfo with status 0 (Submitted); 1 (Running); 2 - // (Kill Request) - for (int status = 0; status <= 2; status++) { - //System.out.println("resumeSubmittedHpcJobInfos: " - // + "looping status: " + status); - List submittedHpcJobInfo = hpcJobInfoService - .findByStatus(status); - if (submittedHpcJobInfo != null) { - for (HpcJobInfo hpcJobInfo : submittedHpcJobInfo) { - addNewSubmittedHpcJob(hpcJobInfo); + + public void logHpcJobLogDetail(final HpcJobLog hpcJobLog, int jobStatus, String jobProgress) { + Date now = new Date(System.currentTimeMillis()); + hpcJobLog.setLastUpdatedTime(now); + if (jobStatus == 3) {// Finished + hpcJobLog.setEndedTime(now); + } + if (jobStatus == 4) {// Killed + hpcJobLog.setCanceledTime(now); } - } + updateHpcJobLog(hpcJobLog); + + HpcJobLogDetail hpcJobLogDetail = new HpcJobLogDetail(); + hpcJobLogDetail.setAddedTime(new Date(System.currentTimeMillis())); + hpcJobLogDetail.setHpcJobLog(hpcJobLog); + hpcJobLogDetail.setJobState(jobStatus); + hpcJobLogDetail.setProgress(jobProgress); + appendHpcJobLogDetail(hpcJobLogDetail); } - } - - public GeneralAlgorithmEditor getGeneralAlgorithmEditor( - final HpcJobInfo hpcJobInfo) { - return hpcGraphResultMap.get(hpcJobInfo); - } - - public synchronized void addNewSubmittedHpcJob(final HpcJobInfo hpcJobInfo) { - HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - System.out.println("addNewSubmittedHpcJob: connection: " - + hpcAccount.getConnectionName()); - System.out.println("addNewSubmittedHpcJob: algorithm: " - + hpcJobInfo.getAlgorithmName()); - System.out.println("addNewSubmittedHpcJob: status: " - + hpcJobInfo.getStatus()); - System.out.println("addNewSubmittedHpcJob: " + "pid: " - + hpcJobInfo.getPid()); - - Set hpcJobInfos = submittedHpcJobInfoMap.get(hpcAccount); - if (hpcJobInfos == null) { - hpcJobInfos = new LinkedHashSet<>(); + + public synchronized void updateUploadFileProgress(final String datasetPath, int percentageProgress) { + uploadFileProgressMap.put(datasetPath, percentageProgress); } - hpcJobInfos.add(hpcJobInfo); - submittedHpcJobInfoMap.put(hpcAccount, hpcJobInfos); - - removePendingHpcJob(hpcJobInfo); - } - - public synchronized void removeFinishedHpcJob(final HpcJobInfo hpcJobInfo) { - HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - System.out.println("removedFinishedHpcJob: connection: " - + hpcAccount.getConnectionName()); - System.out.println("removedFinishedHpcJob: algorithm: " - + hpcJobInfo.getAlgorithmName()); - System.out.println("removedFinishedHpcJob: status: " - + hpcJobInfo.getStatus()); - System.out - .println("removedFinishedHpcJob: pid: " + hpcJobInfo.getPid()); - Set hpcJobInfos = submittedHpcJobInfoMap.get(hpcAccount); - if (hpcJobInfos != null) { - - //System.out.println("removeFinishedHpcJob: hpcJobInfos not null"); - - for (HpcJobInfo jobInfo : hpcJobInfos) { - if (jobInfo.getId() == hpcJobInfo.getId()) { - - //System.out.println("removeFinishedHpcJob: Found hpcJobInfo in the submittedHpcJobInfoMap & removed it!"); - - hpcJobInfos.remove(jobInfo); + + public int getUploadFileProgress(final String dataPath) { + int progress = -1; + if (uploadFileProgressMap.containsKey(dataPath)) { + progress = uploadFileProgressMap.get(dataPath).intValue(); } - } + return progress; + } + + public void resumeSubmittedHpcJobInfos() { + // Lookup on DB for HpcJobInfo with status 0 (Submitted); 1 (Running); 2 + // (Kill Request) + for (int status = 0; status <= 2; status++) { + // System.out.println("resumeSubmittedHpcJobInfos: " + // + "looping status: " + status); + List submittedHpcJobInfo = hpcJobInfoService.findByStatus(status); + if (submittedHpcJobInfo != null) { + for (HpcJobInfo hpcJobInfo : submittedHpcJobInfo) { + addNewSubmittedHpcJob(hpcJobInfo); + } + } + } + } + + public GeneralAlgorithmEditor getGeneralAlgorithmEditor(final HpcJobInfo hpcJobInfo) { + return hpcGraphResultMap.get(hpcJobInfo); + } - if (hpcJobInfos.isEmpty()) { - submittedHpcJobInfoMap.remove(hpcAccount); - } else { + public synchronized void addNewSubmittedHpcJob(final HpcJobInfo hpcJobInfo) { + HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); + System.out.println("addNewSubmittedHpcJob: connection: " + hpcAccount.getConnectionName()); + System.out.println("addNewSubmittedHpcJob: algorithm: " + hpcJobInfo.getAlgorithmName()); + System.out.println("addNewSubmittedHpcJob: status: " + hpcJobInfo.getStatus()); + System.out.println("addNewSubmittedHpcJob: " + "pid: " + hpcJobInfo.getPid()); + + Set hpcJobInfos = submittedHpcJobInfoMap.get(hpcAccount); + if (hpcJobInfos == null) { + hpcJobInfos = new LinkedHashSet<>(); + } + hpcJobInfos.add(hpcJobInfo); submittedHpcJobInfoMap.put(hpcAccount, hpcJobInfos); - } + + removePendingHpcJob(hpcJobInfo); } - hpcGraphResultMap.remove(hpcJobInfo); - } - - public synchronized void removePendingHpcJob(final HpcJobInfo hpcJobInfo) { - HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - System.out.println("removedPendingHpcJob: connection: " - + hpcAccount.getConnectionName()); - System.out.println("removedPendingHpcJob: algorithm: " - + hpcJobInfo.getAlgorithmName()); - System.out.println("removedPendingHpcJob: status: " - + hpcJobInfo.getStatus()); - System.out.println("removedPendingHpcJob: pid: " + hpcJobInfo.getPid()); - - Set hpcJobInfos = pendingHpcJobInfoMap.get(hpcAccount); - if (hpcJobInfos != null) { - - //System.out.println("removedPendingHpcJob: hpcJobInfos not null"); - - for (HpcJobInfo jobInfo : hpcJobInfos) { - if (jobInfo.getId() == hpcJobInfo.getId()) { - - //System.out.println("removedPendingHpcJob: Found hpcJobInfo in the pendingHpcJobInfoMap & removed it!"); - - hpcJobInfos.remove(jobInfo); + + public synchronized void removeFinishedHpcJob(final HpcJobInfo hpcJobInfo) { + HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); + System.out.println("removedFinishedHpcJob: connection: " + hpcAccount.getConnectionName()); + System.out.println("removedFinishedHpcJob: algorithm: " + hpcJobInfo.getAlgorithmName()); + System.out.println("removedFinishedHpcJob: status: " + hpcJobInfo.getStatus()); + System.out.println("removedFinishedHpcJob: pid: " + hpcJobInfo.getPid()); + Set hpcJobInfos = submittedHpcJobInfoMap.get(hpcAccount); + if (hpcJobInfos != null) { + + // System.out.println("removeFinishedHpcJob: hpcJobInfos not null"); + + for (HpcJobInfo jobInfo : hpcJobInfos) { + if (jobInfo.getId() == hpcJobInfo.getId()) { + + // System.out.println("removeFinishedHpcJob: Found + // hpcJobInfo in the submittedHpcJobInfoMap & removed it!"); + + hpcJobInfos.remove(jobInfo); + } + } + + if (hpcJobInfos.isEmpty()) { + submittedHpcJobInfoMap.remove(hpcAccount); + } else { + submittedHpcJobInfoMap.put(hpcAccount, hpcJobInfos); + } } - } + hpcGraphResultMap.remove(hpcJobInfo); + } - if (hpcJobInfos.isEmpty()) { - pendingHpcJobInfoMap.remove(hpcAccount); - } else { - pendingHpcJobInfoMap.put(hpcAccount, hpcJobInfos); - } + public synchronized void removePendingHpcJob(final HpcJobInfo hpcJobInfo) { + HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); + System.out.println("removedPendingHpcJob: connection: " + hpcAccount.getConnectionName()); + System.out.println("removedPendingHpcJob: algorithm: " + hpcJobInfo.getAlgorithmName()); + System.out.println("removedPendingHpcJob: status: " + hpcJobInfo.getStatus()); + System.out.println("removedPendingHpcJob: pid: " + hpcJobInfo.getPid()); + + Set hpcJobInfos = pendingHpcJobInfoMap.get(hpcAccount); + if (hpcJobInfos != null) { + + // System.out.println("removedPendingHpcJob: hpcJobInfos not null"); + + for (HpcJobInfo jobInfo : hpcJobInfos) { + if (jobInfo.getId() == hpcJobInfo.getId()) { + + // System.out.println("removedPendingHpcJob: Found + // hpcJobInfo in the pendingHpcJobInfoMap & removed it!"); + + hpcJobInfos.remove(jobInfo); + } + } + + if (hpcJobInfos.isEmpty()) { + pendingHpcJobInfoMap.remove(hpcAccount); + } else { + pendingHpcJobInfoMap.put(hpcAccount, hpcJobInfos); + } + } } - } - - public Map> getSubmittedHpcJobInfoMap() { - return submittedHpcJobInfoMap; - } - - public synchronized void updateSubmittedHpcJobInfo( - final HpcJobInfo hpcJobInfo) { - final HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - Set hpcJobInfos = submittedHpcJobInfoMap.get(hpcAccount); - if (hpcJobInfos != null) { - - //System.out.println("updateSubmittedHpcJobInfo: hpcJobInfos not null"); - - for (HpcJobInfo jobInfo : hpcJobInfos) { - if (jobInfo.getId() == hpcJobInfo.getId()) { - - //System.out.println("updateSubmittedHpcJobInfo: Found hpcJobInfo in the submittedHpcJobInfoMap & removed it!"); - - hpcJobInfos.remove(jobInfo); + + public Map> getSubmittedHpcJobInfoMap() { + return submittedHpcJobInfoMap; + } + + public synchronized void updateSubmittedHpcJobInfo(final HpcJobInfo hpcJobInfo) { + final HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); + Set hpcJobInfos = submittedHpcJobInfoMap.get(hpcAccount); + if (hpcJobInfos != null) { + + // System.out.println("updateSubmittedHpcJobInfo: hpcJobInfos not + // null"); + + for (HpcJobInfo jobInfo : hpcJobInfos) { + if (jobInfo.getId() == hpcJobInfo.getId()) { + + // System.out.println("updateSubmittedHpcJobInfo: Found + // hpcJobInfo in the submittedHpcJobInfoMap & removed it!"); + + hpcJobInfos.remove(jobInfo); + } + } + + hpcJobInfos.add(hpcJobInfo); + submittedHpcJobInfoMap.put(hpcAccount, hpcJobInfos); } - } - - hpcJobInfos.add(hpcJobInfo); - submittedHpcJobInfoMap.put(hpcAccount, hpcJobInfos); } - } - - public synchronized HpcAccountService getHpcAccountService( - final HpcAccount hpcAccount) throws Exception { - HpcAccountService hpcAccountService = hpcAccountServiceMap - .get(hpcAccount); - if (hpcAccountService == null) { - hpcAccountService = new HpcAccountService(hpcAccount, - simultaneousUpload); - hpcAccountServiceMap.put(hpcAccount, hpcAccountService); + + public synchronized HpcAccountService getHpcAccountService(final HpcAccount hpcAccount) throws Exception { + HpcAccountService hpcAccountService = hpcAccountServiceMap.get(hpcAccount); + if (hpcAccountService == null) { + hpcAccountService = new HpcAccountService(hpcAccount, simultaneousUpload); + hpcAccountServiceMap.put(hpcAccount, hpcAccountService); + } + return hpcAccountService; } - return hpcAccountService; - } - - public synchronized void removeHpcAccountService(final HpcAccount hpcAccount) { - hpcAccountServiceMap.remove(hpcAccount); - } - - public synchronized Map> getFinishedHpcJobInfoMap() { - final Map> finishedHpcJobInfoMap = new HashMap<>(); - // Lookup on DB for HpcJobInfo with status 3 (Finished); 4 (Killed); - // 5 (Result Downloaded); 6 (Error Result Downloaded); - for (int status = 3; status <= 6; status++) { - //System.out.println("getFinishedHpcJobInfoMap: " - // + "looping status: " + status); - List finishedHpcJobInfo = hpcJobInfoService - .findByStatus(status); - if (finishedHpcJobInfo != null) { - for (HpcJobInfo hpcJobInfo : finishedHpcJobInfo) { - final HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - Set hpcJobInfos = finishedHpcJobInfoMap - .get(hpcAccount); - if (hpcJobInfos == null) { - hpcJobInfos = new LinkedHashSet<>(); - } - hpcJobInfos.add(hpcJobInfo); - finishedHpcJobInfoMap.put(hpcAccount, hpcJobInfos); + + public synchronized void removeHpcAccountService(final HpcAccount hpcAccount) { + hpcAccountServiceMap.remove(hpcAccount); + } + + public synchronized Map> getFinishedHpcJobInfoMap() { + final Map> finishedHpcJobInfoMap = new HashMap<>(); + // Lookup on DB for HpcJobInfo with status 3 (Finished); 4 (Killed); + // 5 (Result Downloaded); 6 (Error Result Downloaded); + for (int status = 3; status <= 6; status++) { + // System.out.println("getFinishedHpcJobInfoMap: " + // + "looping status: " + status); + List finishedHpcJobInfo = hpcJobInfoService.findByStatus(status); + if (finishedHpcJobInfo != null) { + for (HpcJobInfo hpcJobInfo : finishedHpcJobInfo) { + final HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); + Set hpcJobInfos = finishedHpcJobInfoMap.get(hpcAccount); + if (hpcJobInfos == null) { + hpcJobInfos = new LinkedHashSet<>(); + } + hpcJobInfos.add(hpcJobInfo); + finishedHpcJobInfoMap.put(hpcAccount, hpcJobInfos); + } + } + } + return finishedHpcJobInfoMap; + } + + public HpcJobInfo requestHpcJobKilled(final HpcJobInfo hpcJobInfo) throws Exception { + final HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); + + HpcAccountService hpcAccountService = getHpcAccountService(hpcAccount); + + JobQueueService jobQueueService = hpcAccountService.getJobQueueService(); + TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); + final HpcAccountManager hpcAccountManager = desktop.getHpcAccountManager(); + JsonWebTokenManager jsonWebTokenManager = hpcAccountManager.getJsonWebTokenManager(); + jobQueueService.requestJobKilled(hpcJobInfo.getPid(), jsonWebTokenManager.getJsonWebToken(hpcAccount)); + JobInfo jobInfo = jobQueueService.getJobStatus(hpcJobInfo.getPid(), + jsonWebTokenManager.getJsonWebToken(hpcAccount)); + + if (jobInfo != null) { + hpcJobInfo.setStatus(jobInfo.getStatus()); + return hpcJobInfo; } - } + + return null; + + } + + public List getRemoteActiveJobs(final HpcAccountManager hpcAccountManager, final HpcAccount hpcAccount) + throws ClientProtocolException, URISyntaxException, IOException, Exception { + HpcAccountService hpcAccountService = getHpcAccountService(hpcAccount); + JobQueueService jobQueueService = hpcAccountService.getJobQueueService(); + return jobQueueService.getActiveJobs(HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); } - return finishedHpcJobInfoMap; - } - - public HpcJobInfo requestHpcJobKilled(final HpcJobInfo hpcJobInfo) - throws Exception { - final HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - - HpcAccountService hpcAccountService = getHpcAccountService(hpcAccount); - - JobQueueService jobQueueService = hpcAccountService - .getJobQueueService(); - TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); - final HpcAccountManager hpcAccountManager = desktop - .getHpcAccountManager(); - JsonWebTokenManager jsonWebTokenManager = hpcAccountManager - .getJsonWebTokenManager(); - jobQueueService.requestJobKilled(hpcJobInfo.getPid(), - jsonWebTokenManager.getJsonWebToken(hpcAccount)); - JobInfo jobInfo = jobQueueService.getJobStatus(hpcJobInfo.getPid(), - jsonWebTokenManager.getJsonWebToken(hpcAccount)); - - if (jobInfo != null) { - hpcJobInfo.setStatus(jobInfo.getStatus()); - return hpcJobInfo; + + public Set listRemoteAlgorithmResultFiles(final HpcAccountManager hpcAccountManager, + final HpcAccount hpcAccount) throws ClientProtocolException, URISyntaxException, IOException, Exception { + HpcAccountService hpcAccountService = getHpcAccountService(hpcAccount); + ResultService resultService = hpcAccountService.getResultService(); + return resultService.listAlgorithmResultFiles(HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); } - return null; - - } - - public List getRemoteActiveJobs( - final HpcAccountManager hpcAccountManager, - final HpcAccount hpcAccount) throws ClientProtocolException, - URISyntaxException, IOException, Exception { - HpcAccountService hpcAccountService = getHpcAccountService(hpcAccount); - JobQueueService jobQueueService = hpcAccountService - .getJobQueueService(); - return jobQueueService.getActiveJobs(HpcAccountUtils.getJsonWebToken( - hpcAccountManager, hpcAccount)); - } - - public Set listRemoteAlgorithmResultFiles( - final HpcAccountManager hpcAccountManager, - final HpcAccount hpcAccount) throws ClientProtocolException, - URISyntaxException, IOException, Exception { - HpcAccountService hpcAccountService = getHpcAccountService(hpcAccount); - ResultService resultService = hpcAccountService.getResultService(); - return resultService.listAlgorithmResultFiles(HpcAccountUtils - .getJsonWebToken(hpcAccountManager, hpcAccount)); - } - - public String downloadAlgorithmResultFile( - final HpcAccountManager hpcAccountManager, - final HpcAccount hpcAccount, final String errorResultFileName) - throws ClientProtocolException, URISyntaxException, IOException, - Exception { - HpcAccountService hpcAccountService = getHpcAccountService(hpcAccount); - ResultService resultService = hpcAccountService.getResultService(); - return resultService.downloadAlgorithmResultFile(errorResultFileName, - HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); - } - - public synchronized void removeHpcJobInfoTransaction( - final HpcJobInfo hpcJobInfo) { - HpcJobLog hpcJobLog = hpcJobLogService.findByHpcJobInfo(hpcJobInfo); - List logDetailList = hpcJobLogDetailService - .findByHpcJobLog(hpcJobLog); - for (HpcJobLogDetail logDetail : logDetailList) { - hpcJobLogDetailService.remove(logDetail); + public String downloadAlgorithmResultFile(final HpcAccountManager hpcAccountManager, final HpcAccount hpcAccount, + final String errorResultFileName) + throws ClientProtocolException, URISyntaxException, IOException, Exception { + HpcAccountService hpcAccountService = getHpcAccountService(hpcAccount); + ResultService resultService = hpcAccountService.getResultService(); + return resultService.downloadAlgorithmResultFile(errorResultFileName, + HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); + } + + public synchronized void removeHpcJobInfoTransaction(final HpcJobInfo hpcJobInfo) { + HpcJobLog hpcJobLog = hpcJobLogService.findByHpcJobInfo(hpcJobInfo); + List logDetailList = hpcJobLogDetailService.findByHpcJobLog(hpcJobLog); + for (HpcJobLogDetail logDetail : logDetailList) { + hpcJobLogDetailService.remove(logDetail); + } + hpcJobLogService.remove(hpcJobLog); } - hpcJobLogService.remove(hpcJobLog); - } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/JsonWebTokenManager.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/JsonWebTokenManager.java index a92acd0e78..db3ea1c36f 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/JsonWebTokenManager.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/manager/JsonWebTokenManager.java @@ -18,53 +18,48 @@ */ public class JsonWebTokenManager { - private final Map jsonWebTokenMap; + private final Map jsonWebTokenMap; - private final Map jsonWebTokenRequestTimeMap; + private final Map jsonWebTokenRequestTimeMap; - private final static long TOKEN_VALID_TIME = 60 * 60 * 1000;// 1-hour - // expired time - // in - // millisecond - private boolean locked = false; + private final static long TOKEN_VALID_TIME = 60 * 60 * 1000;// 1-hour + // expired time + // in + // millisecond + private boolean locked = false; - public JsonWebTokenManager() { - jsonWebTokenMap = new HashMap<>(); - jsonWebTokenRequestTimeMap = new HashMap<>(); - } - - public synchronized JsonWebToken getJsonWebToken(final HpcAccount hpcAccount) - throws Exception { - if (locked) { - Thread.sleep(100); - getJsonWebToken(hpcAccount); + public JsonWebTokenManager() { + jsonWebTokenMap = new HashMap<>(); + jsonWebTokenRequestTimeMap = new HashMap<>(); } - locked = true; - long now = System.currentTimeMillis(); - JsonWebToken jsonWebToken = jsonWebTokenMap.get(hpcAccount); - if (jsonWebToken == null - || (now - jsonWebTokenRequestTimeMap.get(hpcAccount).getTime()) > TOKEN_VALID_TIME) { - final String username = hpcAccount.getUsername(); - final String password = hpcAccount.getPassword(); - final String scheme = hpcAccount.getScheme(); - final String hostname = hpcAccount.getHostname(); - final int port = hpcAccount.getPort(); + public synchronized JsonWebToken getJsonWebToken(final HpcAccount hpcAccount) throws Exception { + if (locked) { + Thread.sleep(100); + getJsonWebToken(hpcAccount); + } + locked = true; + long now = System.currentTimeMillis(); + JsonWebToken jsonWebToken = jsonWebTokenMap.get(hpcAccount); + if (jsonWebToken == null || (now - jsonWebTokenRequestTimeMap.get(hpcAccount).getTime()) > TOKEN_VALID_TIME) { + + final String username = hpcAccount.getUsername(); + final String password = hpcAccount.getPassword(); + final String scheme = hpcAccount.getScheme(); + final String hostname = hpcAccount.getHostname(); + final int port = hpcAccount.getPort(); - RestHttpsClient restClient = new RestHttpsClient(username, - password, scheme, hostname, port); + RestHttpsClient restClient = new RestHttpsClient(username, password, scheme, hostname, port); - // Authentication - UserService userService = new UserService(restClient, scheme, - hostname, port); - // JWT token is valid for 1 hour - jsonWebToken = userService.requestJWT(); - jsonWebTokenMap.put(hpcAccount, jsonWebToken); - jsonWebTokenRequestTimeMap.put(hpcAccount, - new Date(System.currentTimeMillis())); + // Authentication + UserService userService = new UserService(restClient, scheme, hostname, port); + // JWT token is valid for 1 hour + jsonWebToken = userService.requestJWT(); + jsonWebTokenMap.put(hpcAccount, jsonWebToken); + jsonWebTokenRequestTimeMap.put(hpcAccount, new Date(System.currentTimeMillis())); + } + locked = false; + return jsonWebToken; } - locked = false; - return jsonWebToken; - } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/HpcJobPreProcessTask.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/HpcJobPreProcessTask.java index b19746d605..eb0b0a414e 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/HpcJobPreProcessTask.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/HpcJobPreProcessTask.java @@ -33,253 +33,222 @@ */ public class HpcJobPreProcessTask implements Runnable { - private final HpcJobInfo hpcJobInfo; - - public HpcJobPreProcessTask(final HpcJobInfo hpcJobInfo) { - this.hpcJobInfo = hpcJobInfo; - } - - @Override - public void run() { - TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); - while (desktop == null) { - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - e.printStackTrace(); - } + private final HpcJobInfo hpcJobInfo; + + public HpcJobPreProcessTask(final HpcJobInfo hpcJobInfo) { + this.hpcJobInfo = hpcJobInfo; } - final HpcAccountManager hpcAccountManager = desktop - .getHpcAccountManager(); - final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - - HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); - - AlgorithmParamRequest algorParamReq = hpcJobInfo - .getAlgorithmParamRequest(); - String datasetPath = algorParamReq.getDatasetPath(); - String priorKnowledgePath = algorParamReq.getPriorKnowledgePath(); - - try { - HpcAccountService hpcAccountService = hpcJobManager - .getHpcAccountService(hpcAccount); - - HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); - - String log = "Initiated connection to " - + hpcAccount.getConnectionName(); - System.out.println(log); - hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); - - log = "datasetPath: " + datasetPath; - System.out.println(log); - Path file = Paths.get(datasetPath); - // Get file's MD5 hash and use it as its identifier - String md5 = algorParamReq.getDatasetMd5(); - - // Initiate data uploading progress - hpcJobManager.updateUploadFileProgress(datasetPath, 0); - - Path prior = null; - if (priorKnowledgePath != null) { - log = "priorKnowledgePath: " + priorKnowledgePath; - System.out.println(log); - prior = Paths.get(priorKnowledgePath); - - // Initiate prior knowledge uploading progress - hpcJobManager.updateUploadFileProgress(priorKnowledgePath, 0); - } - - // Check if this dataset already exists with this md5 hash - RemoteDataFileService remoteDataService = hpcAccountService - .getRemoteDataService(); - - DataFile dataFile = HpcAccountUtils.getRemoteDataFile( - hpcAccountManager, remoteDataService, hpcAccount, md5); - DataUploadService dataUploadService = hpcAccountService - .getDataUploadService(); - - // If not, upload the file - if (dataFile == null) { - log = "Started uploading " + file.getFileName().toString(); - System.out.println(log); - dataUploadService.startUpload(file, HpcAccountUtils - .getJsonWebToken(hpcAccountManager, hpcAccount)); - hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); - - int progress; - while ((progress = dataUploadService.getUploadJobStatus(file - .toAbsolutePath().toString())) < 100) { - // System.out.println("Uploading " - // + file.toAbsolutePath().toString() + " Progress: " - // + progress + "%"); - hpcJobManager.updateUploadFileProgress(datasetPath, - progress); - Thread.sleep(10); - } - hpcJobManager.updateUploadFileProgress(datasetPath, progress); - - log = "Finished uploading " + file.getFileName().toString(); - System.out.println(log); - hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); - - // Get remote datafile - dataFile = HpcAccountUtils.getRemoteDataFile(hpcAccountManager, - remoteDataService, hpcAccount, md5); - - HpcAccountUtils - .summarizeDataset(remoteDataService, algorParamReq, - dataFile.getId(), HpcAccountUtils - .getJsonWebToken(hpcAccountManager, - hpcAccount)); - log = "Summarized " + file.getFileName().toString(); - System.out.println(log); - hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); - } else { - log = "Skipped uploading " + file.getFileName().toString(); - System.out.println(log); - - hpcJobManager.updateUploadFileProgress(datasetPath, -1); - - hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); - - if (dataFile.getFileSummary().getVariableType() == null) { - HpcAccountUtils.summarizeDataset(remoteDataService, - algorParamReq, dataFile.getId(), HpcAccountUtils - .getJsonWebToken(hpcAccountManager, - hpcAccount)); - log = "Summarized " + file.getFileName().toString(); - System.out.println(log); - hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, - "Summarized " + file.getFileName().toString()); + @Override + public void run() { + TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); + while (desktop == null) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } } + final HpcAccountManager hpcAccountManager = desktop.getHpcAccountManager(); + final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - } - - DataFile priorKnowledgeFile = null; - - // Prior Knowledge File - if (prior != null) { - // Get prior knowledge file Id - md5 = algorParamReq.getPriorKnowledgeMd5(); - - priorKnowledgeFile = HpcAccountUtils - .getRemotePriorKnowledgeFile(hpcAccountManager, - remoteDataService, hpcAccount, md5); - - if (priorKnowledgeFile == null) { - // Upload prior knowledge file - dataUploadService.startUpload(prior, HpcAccountUtils - .getJsonWebToken(hpcAccountManager, hpcAccount)); + HpcAccount hpcAccount = hpcJobInfo.getHpcAccount(); + + AlgorithmParamRequest algorParamReq = hpcJobInfo.getAlgorithmParamRequest(); + String datasetPath = algorParamReq.getDatasetPath(); + String priorKnowledgePath = algorParamReq.getPriorKnowledgePath(); + + try { + HpcAccountService hpcAccountService = hpcJobManager.getHpcAccountService(hpcAccount); + + HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); + + String log = "Initiated connection to " + hpcAccount.getConnectionName(); + System.out.println(log); + hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); + + log = "datasetPath: " + datasetPath; + System.out.println(log); + Path file = Paths.get(datasetPath); + // Get file's MD5 hash and use it as its identifier + String md5 = algorParamReq.getDatasetMd5(); + + // Initiate data uploading progress + hpcJobManager.updateUploadFileProgress(datasetPath, 0); - log = "Started uploading Prior Knowledge File"; - System.out.println(log); - hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); + Path prior = null; + if (priorKnowledgePath != null) { + log = "priorKnowledgePath: " + priorKnowledgePath; + System.out.println(log); + prior = Paths.get(priorKnowledgePath); - int progress; - while ((progress = dataUploadService - .getUploadJobStatus(prior.toAbsolutePath() - .toString())) < 100) { - hpcJobManager.updateUploadFileProgress( - priorKnowledgePath, progress); - Thread.sleep(10); - } - - hpcJobManager.updateUploadFileProgress(priorKnowledgePath, - progress); - - priorKnowledgeFile = HpcAccountUtils - .getRemotePriorKnowledgeFile(hpcAccountManager, - remoteDataService, hpcAccount, md5); - - log = "Finished uploading Prior Knowledge File"; - System.out.println(log); - hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); + // Initiate prior knowledge uploading progress + hpcJobManager.updateUploadFileProgress(priorKnowledgePath, 0); + } + // Check if this dataset already exists with this md5 hash + RemoteDataFileService remoteDataService = hpcAccountService.getRemoteDataService(); + + DataFile dataFile = HpcAccountUtils.getRemoteDataFile(hpcAccountManager, remoteDataService, hpcAccount, + md5); + DataUploadService dataUploadService = hpcAccountService.getDataUploadService(); + + // If not, upload the file + if (dataFile == null) { + log = "Started uploading " + file.getFileName().toString(); + System.out.println(log); + dataUploadService.startUpload(file, HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); + hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); + + int progress; + while ((progress = dataUploadService.getUploadJobStatus(file.toAbsolutePath().toString())) < 100) { + // System.out.println("Uploading " + // + file.toAbsolutePath().toString() + " Progress: " + // + progress + "%"); + hpcJobManager.updateUploadFileProgress(datasetPath, progress); + Thread.sleep(10); + } + + hpcJobManager.updateUploadFileProgress(datasetPath, progress); + + log = "Finished uploading " + file.getFileName().toString(); + System.out.println(log); + hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); + + // Get remote datafile + dataFile = HpcAccountUtils.getRemoteDataFile(hpcAccountManager, remoteDataService, hpcAccount, md5); + + HpcAccountUtils.summarizeDataset(remoteDataService, algorParamReq, dataFile.getId(), + HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); + log = "Summarized " + file.getFileName().toString(); + System.out.println(log); + hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); + } else { + log = "Skipped uploading " + file.getFileName().toString(); + System.out.println(log); + + hpcJobManager.updateUploadFileProgress(datasetPath, -1); + + hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); + + if (dataFile.getFileSummary().getVariableType() == null) { + HpcAccountUtils.summarizeDataset(remoteDataService, algorParamReq, dataFile.getId(), + HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); + log = "Summarized " + file.getFileName().toString(); + System.out.println(log); + hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, "Summarized " + file.getFileName().toString()); + } + + } + + DataFile priorKnowledgeFile = null; + + // Prior Knowledge File + if (prior != null) { + // Get prior knowledge file Id + md5 = algorParamReq.getPriorKnowledgeMd5(); + + priorKnowledgeFile = HpcAccountUtils.getRemotePriorKnowledgeFile(hpcAccountManager, remoteDataService, + hpcAccount, md5); + + if (priorKnowledgeFile == null) { + // Upload prior knowledge file + dataUploadService.startUpload(prior, + HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); + + log = "Started uploading Prior Knowledge File"; + System.out.println(log); + hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); + + int progress; + while ((progress = dataUploadService.getUploadJobStatus(prior.toAbsolutePath().toString())) < 100) { + hpcJobManager.updateUploadFileProgress(priorKnowledgePath, progress); + Thread.sleep(10); + } + + hpcJobManager.updateUploadFileProgress(priorKnowledgePath, progress); + + priorKnowledgeFile = HpcAccountUtils.getRemotePriorKnowledgeFile(hpcAccountManager, + remoteDataService, hpcAccount, md5); + + log = "Finished uploading Prior Knowledge File"; + System.out.println(log); + hpcJobManager.logHpcJobLogDetail(hpcJobLog, -1, log); + + } + + } + + // Algorithm Job Preparation + edu.pitt.dbmi.ccd.rest.client.dto.algo.AlgorithmParamRequest paramRequest = new edu.pitt.dbmi.ccd.rest.client.dto.algo.AlgorithmParamRequest(); + paramRequest.setDatasetFileId(dataFile.getId()); + + Map dataValidation = new HashMap<>(); + dataValidation.put("skipUniqueVarName", false); + System.out.println("dataValidation: skipUniqueVarName: false"); + if (algorParamReq.getVariableType().equalsIgnoreCase("discrete")) { + dataValidation.put("skipNonzeroVariance", false); + System.out.println("dataValidation: skipNonzeroVariance: false"); + } else { + dataValidation.put("skipCategoryLimit", false); + System.out.println("dataValidation: skipCategoryLimit: false"); + } + paramRequest.setDataValidation(dataValidation); + + Map algorithmParameters = new HashMap<>(); + for (AlgorithmParameter param : algorParamReq.getAlgorithmParameters()) { + algorithmParameters.put(param.getParameter(), param.getValue()); + System.out.println("AlgorithmParameter: " + param.getParameter() + " : " + param.getValue()); + } + + if (priorKnowledgeFile != null) { + algorithmParameters.put("priorKnowledgeFileId", priorKnowledgeFile.getId()); + System.out.println("priorKnowledgeFileId: " + priorKnowledgeFile.getId()); + } + paramRequest.setAlgorithmParameters(algorithmParameters); + + Map jvmOptions = new HashMap<>(); + for (JvmOption jvmOption : algorParamReq.getJvmOptions()) { + jvmOptions.put(jvmOption.getParameter(), jvmOption.getValue()); + System.out.println("JvmOption: " + jvmOption.getParameter() + " : " + jvmOption.getValue()); + } + paramRequest.setJvmOptions(jvmOptions); + + // Submit a job + String algorithmName = hpcJobInfo.getAlgorithmName(); + JobQueueService jobQueueService = hpcAccountService.getJobQueueService(); + JobInfo jobInfo = jobQueueService.addToRemoteQueue(algorithmName, paramRequest, + HpcAccountUtils.getJsonWebToken(hpcAccountManager, hpcAccount)); + + // Log the job submission + hpcJobInfo.setSubmittedTime(new Date(System.currentTimeMillis())); + hpcJobInfo.setStatus(0); // Submitted + hpcJobInfo.setPid(jobInfo.getId()); + hpcJobInfo.setResultFileName(jobInfo.getResultFileName()); + hpcJobInfo.setResultJsonFileName(jobInfo.getResultJsonFileName()); + hpcJobInfo.setErrorResultFileName(jobInfo.getErrorResultFileName()); + + hpcJobManager.updateHpcJobInfo(hpcJobInfo); + + log = "Submitted job to " + hpcAccount.getConnectionName(); + System.out.println(log); + hpcJobManager.logHpcJobLogDetail(hpcJobLog, 0, log); + + System.out.println( + "HpcJobPreProcessTask: HpcJobInfo: id : " + hpcJobInfo.getId() + " : pid : " + hpcJobInfo.getPid() + + " : " + hpcJobInfo.getAlgorithmName() + " : " + hpcJobInfo.getResultFileName()); + + hpcJobManager.addNewSubmittedHpcJob(hpcJobInfo); + + } catch (Exception e) { + // TODO Auto-generated catch block + e.printStackTrace(); } - } - - // Algorithm Job Preparation - edu.pitt.dbmi.ccd.rest.client.dto.algo.AlgorithmParamRequest paramRequest = new edu.pitt.dbmi.ccd.rest.client.dto.algo.AlgorithmParamRequest(); - paramRequest.setDatasetFileId(dataFile.getId()); - - Map dataValidation = new HashMap<>(); - dataValidation.put("skipUniqueVarName", false); - System.out.println("dataValidation: skipUniqueVarName: false"); - if (algorParamReq.getVariableType().equalsIgnoreCase("discrete")) { - dataValidation.put("skipNonzeroVariance", false); - System.out - .println("dataValidation: skipNonzeroVariance: false"); - } else { - dataValidation.put("skipCategoryLimit", false); - System.out.println("dataValidation: skipCategoryLimit: false"); - } - paramRequest.setDataValidation(dataValidation); - - Map algorithmParameters = new HashMap<>(); - for (AlgorithmParameter param : algorParamReq - .getAlgorithmParameters()) { - algorithmParameters.put(param.getParameter(), param.getValue()); - System.out.println("AlgorithmParameter: " - + param.getParameter() + " : " + param.getValue()); - } - - if (priorKnowledgeFile != null) { - algorithmParameters.put("priorKnowledgeFileId", - priorKnowledgeFile.getId()); - System.out.println("priorKnowledgeFileId: " - + priorKnowledgeFile.getId()); - } - paramRequest.setAlgorithmParameters(algorithmParameters); - - Map jvmOptions = new HashMap<>(); - for (JvmOption jvmOption : algorParamReq.getJvmOptions()) { - jvmOptions.put(jvmOption.getParameter(), jvmOption.getValue()); - System.out.println("JvmOption: " + jvmOption.getParameter() - + " : " + jvmOption.getValue()); - } - paramRequest.setJvmOptions(jvmOptions); - - // Submit a job - String algorithmName = hpcJobInfo.getAlgorithmName(); - JobQueueService jobQueueService = hpcAccountService - .getJobQueueService(); - JobInfo jobInfo = jobQueueService.addToRemoteQueue(algorithmName, - paramRequest, HpcAccountUtils.getJsonWebToken( - hpcAccountManager, hpcAccount)); - - // Log the job submission - hpcJobInfo.setSubmittedTime(new Date(System.currentTimeMillis())); - hpcJobInfo.setStatus(0); // Submitted - hpcJobInfo.setPid(jobInfo.getId()); - hpcJobInfo.setResultFileName(jobInfo.getResultFileName()); - hpcJobInfo.setResultJsonFileName(jobInfo.getResultJsonFileName()); - hpcJobInfo.setErrorResultFileName(jobInfo.getErrorResultFileName()); - - hpcJobManager.updateHpcJobInfo(hpcJobInfo); - - log = "Submitted job to " + hpcAccount.getConnectionName(); - System.out.println(log); - hpcJobManager.logHpcJobLogDetail(hpcJobLog, 0, log); - - System.out.println("HpcJobPreProcessTask: HpcJobInfo: id : " - + hpcJobInfo.getId() + " : pid : " + hpcJobInfo.getPid() - + " : " + hpcJobInfo.getAlgorithmName() + " : " - + hpcJobInfo.getResultFileName()); - - hpcJobManager.addNewSubmittedHpcJob(hpcJobInfo); - - } catch (Exception e) { - // TODO Auto-generated catch block - e.printStackTrace(); } - } - - public HpcJobInfo getHpcJobInfo() { - return hpcJobInfo; - } + public HpcJobInfo getHpcJobInfo() { + return hpcJobInfo; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/HpcJobsScheduledTask.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/HpcJobsScheduledTask.java index 78a38f2b80..3440128912 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/HpcJobsScheduledTask.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/HpcJobsScheduledTask.java @@ -32,304 +32,251 @@ */ public class HpcJobsScheduledTask extends TimerTask { - public HpcJobsScheduledTask() { - } - - // Pooling job status from HPC nodes - @Override - public void run() { - TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); - if (desktop == null) - return; - - final HpcAccountManager hpcAccountManager = desktop - .getHpcAccountManager(); - final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - - System.out.println("HpcJobsScheduledTask: " - + new Date(System.currentTimeMillis())); - - // Load active jobs: Status (0 = Submitted; 1 = Running; 2 = Kill - // Request) - Map> submittedHpcJobInfos = hpcJobManager - .getSubmittedHpcJobInfoMap(); - if (submittedHpcJobInfos.size() == 0) { - System.out.println("Submitted job pool is empty!"); - } else { - System.out.println("Submitted job pool has " - + submittedHpcJobInfos.keySet().size() + " hpcAccount" - + (submittedHpcJobInfos.keySet().size() > 1 ? "s" : "")); + public HpcJobsScheduledTask() { } - for (HpcAccount hpcAccount : submittedHpcJobInfos.keySet()) { + // Pooling job status from HPC nodes + @Override + public void run() { + TetradDesktop desktop = (TetradDesktop) DesktopController.getInstance(); + if (desktop == null) + return; - System.out.println("HpcJobsScheduledTask: " - + hpcAccount.getConnectionName()); + final HpcAccountManager hpcAccountManager = desktop.getHpcAccountManager(); + final HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - Set hpcJobInfos = submittedHpcJobInfos.get(hpcAccount); - // Pid-HpcJobInfo map - Map hpcJobInfoMap = new HashMap<>(); - for (HpcJobInfo hpcJobInfo : hpcJobInfos) { - if (hpcJobInfo.getPid() != null) { - long pid = hpcJobInfo.getPid().longValue(); - hpcJobInfoMap.put(pid, hpcJobInfo); - - System.out.println("id: " + hpcJobInfo.getId() + " : " - + hpcJobInfo.getAlgorithmName() + ": pid: " + pid - + " : " + hpcJobInfo.getResultFileName()); + System.out.println("HpcJobsScheduledTask: " + new Date(System.currentTimeMillis())); + // Load active jobs: Status (0 = Submitted; 1 = Running; 2 = Kill + // Request) + Map> submittedHpcJobInfos = hpcJobManager.getSubmittedHpcJobInfoMap(); + if (submittedHpcJobInfos.size() == 0) { + System.out.println("Submitted job pool is empty!"); } else { + System.out.println("Submitted job pool has " + submittedHpcJobInfos.keySet().size() + " hpcAccount" + + (submittedHpcJobInfos.keySet().size() > 1 ? "s" : "")); + } - System.out.println("id: " + hpcJobInfo.getId() + " : " - + hpcJobInfo.getAlgorithmName() + ": no pid! : " - + hpcJobInfo.getResultFileName()); + for (HpcAccount hpcAccount : submittedHpcJobInfos.keySet()) { - hpcJobInfos.remove(hpcJobInfo); - } - } - - // Finished job map - HashMap finishedJobMap = new HashMap<>(); - for (HpcJobInfo job : hpcJobInfos) { - finishedJobMap.put(job.getPid(), job); - } - - try { - List jobInfos = hpcJobManager.getRemoteActiveJobs( - hpcAccountManager, hpcAccount); - - for (JobInfo jobInfo : jobInfos) { - System.out.println("Remote pid: " + jobInfo.getId() + " : " - + jobInfo.getAlgorithmName() + " : " - + jobInfo.getResultFileName()); - - long pid = jobInfo.getId(); - - if (finishedJobMap.containsKey(pid)) { - finishedJobMap.remove(pid); - } - - int remoteStatus = jobInfo.getStatus(); - String recentStatusText = (remoteStatus == 0 ? "Submitted" - : (remoteStatus == 1 ? "Running" : "Kill Request")); - HpcJobInfo hpcJobInfo = hpcJobInfoMap.get(pid);// Local job - // map - HpcJobLog hpcJobLog = hpcJobManager - .getHpcJobLog(hpcJobInfo); - if (hpcJobInfo != null) { - int status = hpcJobInfo.getStatus(); - if (status != remoteStatus) { - // Update status - hpcJobInfo.setStatus(remoteStatus); - - hpcJobManager.updateHpcJobInfo(hpcJobInfo); - hpcJobLog.setLastUpdatedTime(new Date(System - .currentTimeMillis())); - - String log = "Job status changed to " - + recentStatusText; - System.out.println(hpcJobInfo.getAlgorithmName() - + " : id : " + hpcJobInfo.getId() - + " : pid : " + pid); - System.out.println(log); - - hpcJobManager.logHpcJobLogDetail(hpcJobLog, - remoteStatus, log); + System.out.println("HpcJobsScheduledTask: " + hpcAccount.getConnectionName()); + + Set hpcJobInfos = submittedHpcJobInfos.get(hpcAccount); + // Pid-HpcJobInfo map + Map hpcJobInfoMap = new HashMap<>(); + for (HpcJobInfo hpcJobInfo : hpcJobInfos) { + if (hpcJobInfo.getPid() != null) { + long pid = hpcJobInfo.getPid().longValue(); + hpcJobInfoMap.put(pid, hpcJobInfo); + + System.out.println("id: " + hpcJobInfo.getId() + " : " + hpcJobInfo.getAlgorithmName() + ": pid: " + + pid + " : " + hpcJobInfo.getResultFileName()); + + } else { + + System.out.println("id: " + hpcJobInfo.getId() + " : " + hpcJobInfo.getAlgorithmName() + + ": no pid! : " + hpcJobInfo.getResultFileName()); + + hpcJobInfos.remove(hpcJobInfo); + } } - } - } - // Download finished jobs' results - if (finishedJobMap.size() > 0) { - Set resultFiles = hpcJobManager - .listRemoteAlgorithmResultFiles(hpcAccountManager, - hpcAccount); - - Set resultFileNames = new HashSet<>(); - for (ResultFile resultFile : resultFiles) { - resultFileNames.add(resultFile.getName()); - // System.out.println(hpcAccount.getConnectionName() - // + " Result : " + resultFile.getName()); - } - - for (HpcJobInfo hpcJobInfo : finishedJobMap.values()) {// Job - // is - // done - // or - // killed or - // time-out - HpcJobLog hpcJobLog = hpcJobManager - .getHpcJobLog(hpcJobInfo); - String recentStatusText = "Job finished"; - int recentStatus = 3; // Finished - if (hpcJobInfo.getStatus() == 2) { - recentStatusText = "Job killed"; - recentStatus = 4; // Killed + // Finished job map + HashMap finishedJobMap = new HashMap<>(); + for (HpcJobInfo job : hpcJobInfos) { + finishedJobMap.put(job.getPid(), job); } - hpcJobInfo.setStatus(recentStatus); - hpcJobManager.updateHpcJobInfo(hpcJobInfo); - // System.out.println("hpcJobInfo: id: " - // + hpcJobInfo.getId() + " : " - // + hpcJobInfo.getStatus()); + try { + List jobInfos = hpcJobManager.getRemoteActiveJobs(hpcAccountManager, hpcAccount); + + for (JobInfo jobInfo : jobInfos) { + System.out.println("Remote pid: " + jobInfo.getId() + " : " + jobInfo.getAlgorithmName() + " : " + + jobInfo.getResultFileName()); + + long pid = jobInfo.getId(); + + if (finishedJobMap.containsKey(pid)) { + finishedJobMap.remove(pid); + } + + int remoteStatus = jobInfo.getStatus(); + String recentStatusText = (remoteStatus == 0 ? "Submitted" + : (remoteStatus == 1 ? "Running" : "Kill Request")); + HpcJobInfo hpcJobInfo = hpcJobInfoMap.get(pid);// Local job + // map + HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); + if (hpcJobInfo != null) { + int status = hpcJobInfo.getStatus(); + if (status != remoteStatus) { + // Update status + hpcJobInfo.setStatus(remoteStatus); + + hpcJobManager.updateHpcJobInfo(hpcJobInfo); + hpcJobLog.setLastUpdatedTime(new Date(System.currentTimeMillis())); + + String log = "Job status changed to " + recentStatusText; + System.out.println(hpcJobInfo.getAlgorithmName() + " : id : " + hpcJobInfo.getId() + + " : pid : " + pid); + System.out.println(log); + + hpcJobManager.logHpcJobLogDetail(hpcJobLog, remoteStatus, log); + } + } + } - hpcJobManager.logHpcJobLogDetail(hpcJobLog, - recentStatus, recentStatusText); + // Download finished jobs' results + if (finishedJobMap.size() > 0) { + Set resultFiles = hpcJobManager.listRemoteAlgorithmResultFiles(hpcAccountManager, + hpcAccount); - System.out.println(hpcJobInfo.getAlgorithmName() - + " : id : " + hpcJobInfo.getId() + " : " - + recentStatusText); + Set resultFileNames = new HashSet<>(); + for (ResultFile resultFile : resultFiles) { + resultFileNames.add(resultFile.getName()); + // System.out.println(hpcAccount.getConnectionName() + // + " Result : " + resultFile.getName()); + } - GeneralAlgorithmEditor editor = hpcJobManager - .getGeneralAlgorithmEditor(hpcJobInfo); - if (editor != null) { - System.out - .println("GeneralAlgorithmEditor is not null"); + for (HpcJobInfo hpcJobInfo : finishedJobMap.values()) {// Job + // is + // done + // or + // killed or + // time-out + HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); + String recentStatusText = "Job finished"; + int recentStatus = 3; // Finished + if (hpcJobInfo.getStatus() == 2) { + recentStatusText = "Job killed"; + recentStatus = 4; // Killed + } + hpcJobInfo.setStatus(recentStatus); + hpcJobManager.updateHpcJobInfo(hpcJobInfo); - String resultJsonFileName = hpcJobInfo - .getResultJsonFileName(); - String errorResultFileName = hpcJobInfo - .getErrorResultFileName(); + // System.out.println("hpcJobInfo: id: " + // + hpcJobInfo.getId() + " : " + // + hpcJobInfo.getStatus()); - if (resultFileNames.contains(resultJsonFileName)) { - recentStatus = 5; // Result Downloaded + hpcJobManager.logHpcJobLogDetail(hpcJobLog, recentStatus, recentStatusText); - String json = downloadAlgorithmResultFile( - hpcAccountManager, hpcJobManager, - hpcAccount, resultJsonFileName, editor); + System.out.println(hpcJobInfo.getAlgorithmName() + " : id : " + hpcJobInfo.getId() + " : " + + recentStatusText); - if (!json.toLowerCase().contains("not found")) { - editor.setAlgorithmResult(json); - } + GeneralAlgorithmEditor editor = hpcJobManager.getGeneralAlgorithmEditor(hpcJobInfo); + if (editor != null) { + System.out.println("GeneralAlgorithmEditor is not null"); - String log = "Result downloaded"; - hpcJobManager.logHpcJobLogDetail(hpcJobLog, - recentStatus, log); + String resultJsonFileName = hpcJobInfo.getResultJsonFileName(); + String errorResultFileName = hpcJobInfo.getErrorResultFileName(); - System.out.println(hpcJobInfo - .getAlgorithmName() - + " : id : " - + hpcJobInfo.getId() + " : " + log); + if (resultFileNames.contains(resultJsonFileName)) { + recentStatus = 5; // Result Downloaded - } else if (resultFileNames - .contains(errorResultFileName)) { - recentStatus = 6; // Error Result Downloaded + String json = downloadAlgorithmResultFile(hpcAccountManager, hpcJobManager, hpcAccount, + resultJsonFileName, editor); - String error = downloadAlgorithmResultFile( - hpcAccountManager, hpcJobManager, - hpcAccount, errorResultFileName, editor); + if (!json.toLowerCase().contains("not found")) { + editor.setAlgorithmResult(json); + } - if (!error.toLowerCase().contains("not found")) { - editor.setAlgorithmErrorResult(error); - } + String log = "Result downloaded"; + hpcJobManager.logHpcJobLogDetail(hpcJobLog, recentStatus, log); - String log = "Error Result downloaded"; - hpcJobManager.logHpcJobLogDetail(hpcJobLog, - recentStatus, log); + System.out.println( + hpcJobInfo.getAlgorithmName() + " : id : " + hpcJobInfo.getId() + " : " + log); - System.out.println(hpcJobInfo - .getAlgorithmName() - + " : id : " - + hpcJobInfo.getId() + " : " + log); + } else if (resultFileNames.contains(errorResultFileName)) { + recentStatus = 6; // Error Result Downloaded - } else { + String error = downloadAlgorithmResultFile(hpcAccountManager, hpcJobManager, hpcAccount, + errorResultFileName, editor); - // Try again - Thread.sleep(5000); + if (!error.toLowerCase().contains("not found")) { + editor.setAlgorithmErrorResult(error); + } - String json = downloadAlgorithmResultFile( - hpcAccountManager, hpcJobManager, - hpcAccount, resultJsonFileName, editor); + String log = "Error Result downloaded"; + hpcJobManager.logHpcJobLogDetail(hpcJobLog, recentStatus, log); - if (!json.toLowerCase().contains("not found")) { - editor.setAlgorithmResult(json); + System.out.println( + hpcJobInfo.getAlgorithmName() + " : id : " + hpcJobInfo.getId() + " : " + log); - recentStatus = 5; // Result Downloaded + } else { - String log = "Result downloaded"; - hpcJobManager.logHpcJobLogDetail(hpcJobLog, - recentStatus, log); + // Try again + Thread.sleep(5000); - System.out.println(hpcJobInfo - .getAlgorithmName() - + " : id : " - + hpcJobInfo.getId() + " : " + log); - } else { - String error = downloadAlgorithmResultFile( - hpcAccountManager, hpcJobManager, - hpcAccount, errorResultFileName, - editor); - - if (!error.toLowerCase().contains( - "not found")) { - editor.setAlgorithmErrorResult(error); - - recentStatus = 6; // Error Result - // Downloaded - - String log = "Error Result downloaded"; - hpcJobManager.logHpcJobLogDetail( - hpcJobLog, recentStatus, log); - - System.out.println(hpcJobInfo - .getAlgorithmName() - + " : id : " - + hpcJobInfo.getId() - + " : " - + log); - } else { - recentStatus = 7; // Result Not Found - - String log = resultJsonFileName - + " not found"; - hpcJobManager.logHpcJobLogDetail( - hpcJobLog, recentStatus, log); - - System.out.println(hpcJobInfo - .getAlgorithmName() - + " : id : " - + hpcJobInfo.getId() - + " : " - + log); - } + String json = downloadAlgorithmResultFile(hpcAccountManager, hpcJobManager, hpcAccount, + resultJsonFileName, editor); - } + if (!json.toLowerCase().contains("not found")) { + editor.setAlgorithmResult(json); + + recentStatus = 5; // Result Downloaded + + String log = "Result downloaded"; + hpcJobManager.logHpcJobLogDetail(hpcJobLog, recentStatus, log); - } + System.out.println(hpcJobInfo.getAlgorithmName() + " : id : " + hpcJobInfo.getId() + + " : " + log); + } else { + String error = downloadAlgorithmResultFile(hpcAccountManager, hpcJobManager, + hpcAccount, errorResultFileName, editor); + if (!error.toLowerCase().contains("not found")) { + editor.setAlgorithmErrorResult(error); + + recentStatus = 6; // Error Result + // Downloaded + + String log = "Error Result downloaded"; + hpcJobManager.logHpcJobLogDetail(hpcJobLog, recentStatus, log); + + System.out.println(hpcJobInfo.getAlgorithmName() + " : id : " + + hpcJobInfo.getId() + " : " + log); + } else { + recentStatus = 7; // Result Not Found + + String log = resultJsonFileName + " not found"; + hpcJobManager.logHpcJobLogDetail(hpcJobLog, recentStatus, log); + + System.out.println(hpcJobInfo.getAlgorithmName() + " : id : " + + hpcJobInfo.getId() + " : " + log); + } + + } + + } + + } + hpcJobManager.removeFinishedHpcJob(hpcJobInfo); + } + } else { + System.out.println("No finished job yet."); + } + + } catch (Exception e) { + // TODO Auto-generated catch block + e.printStackTrace(); } - hpcJobManager.removeFinishedHpcJob(hpcJobInfo); - } - } else { - System.out.println("No finished job yet."); - } - } catch (Exception e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } + } } - } - - private String downloadAlgorithmResultFile( - final HpcAccountManager hpcAccountManager, - final HpcJobManager hpcJobManager, final HpcAccount hpcAccount, - final String resultFileName, final GeneralAlgorithmEditor editor) - throws ClientProtocolException, URISyntaxException, IOException, - Exception { - int trial = 10; - String txt = hpcJobManager.downloadAlgorithmResultFile( - hpcAccountManager, hpcAccount, resultFileName); - while (trial != 0 && txt.toLowerCase().contains("not found")) { - Thread.sleep(5000); - txt = hpcJobManager.downloadAlgorithmResultFile(hpcAccountManager, - hpcAccount, resultFileName); - trial--; - } + private String downloadAlgorithmResultFile(final HpcAccountManager hpcAccountManager, + final HpcJobManager hpcJobManager, final HpcAccount hpcAccount, final String resultFileName, + final GeneralAlgorithmEditor editor) + throws ClientProtocolException, URISyntaxException, IOException, Exception { + int trial = 10; + String txt = hpcJobManager.downloadAlgorithmResultFile(hpcAccountManager, hpcAccount, resultFileName); + while (trial != 0 && txt.toLowerCase().contains("not found")) { + Thread.sleep(5000); + txt = hpcJobManager.downloadAlgorithmResultFile(hpcAccountManager, hpcAccount, resultFileName); + trial--; + } - return txt; - } + return txt; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/PendingHpcJobUpdaterTask.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/PendingHpcJobUpdaterTask.java index 451d9707ef..ab821157d7 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/PendingHpcJobUpdaterTask.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/PendingHpcJobUpdaterTask.java @@ -25,216 +25,168 @@ */ public class PendingHpcJobUpdaterTask extends TimerTask { - private final HpcJobManager hpcJobManager; + private final HpcJobManager hpcJobManager; - private final HpcJobActivityEditor hpcJobActivityEditor; + private final HpcJobActivityEditor hpcJobActivityEditor; - public PendingHpcJobUpdaterTask(final HpcJobManager hpcJobManager, - final HpcJobActivityEditor hpcJobActivityEditor) { - this.hpcJobManager = hpcJobManager; - this.hpcJobActivityEditor = hpcJobActivityEditor; - } - - @Override - public void run() { - - if (hpcJobActivityEditor.selectedTabbedPaneIndex() != 0) - return; + public PendingHpcJobUpdaterTask(final HpcJobManager hpcJobManager, + final HpcJobActivityEditor hpcJobActivityEditor) { + this.hpcJobManager = hpcJobManager; + this.hpcJobActivityEditor = hpcJobActivityEditor; + } - final DefaultTableModel model = (DefaultTableModel) hpcJobActivityEditor - .getJobsTableModel(); + @Override + public void run() { - Set pendingDisplayHpcJobInfoSet = hpcJobActivityEditor - .getPendingDisplayHpcJobInfoSet(); + if (hpcJobActivityEditor.selectedTabbedPaneIndex() != 0) + return; - Set notPendingJobAnymoreSet = monitorDataUploadProgress( - pendingDisplayHpcJobInfoSet, model); + final DefaultTableModel model = (DefaultTableModel) hpcJobActivityEditor.getJobsTableModel(); - hpcJobActivityEditor - .addSubmittedDisplayHpcJobInfo(notPendingJobAnymoreSet); + Set pendingDisplayHpcJobInfoSet = hpcJobActivityEditor.getPendingDisplayHpcJobInfoSet(); - hpcJobActivityEditor - .removePendingDisplayHpcJobInfo(notPendingJobAnymoreSet); - } + Set notPendingJobAnymoreSet = monitorDataUploadProgress(pendingDisplayHpcJobInfoSet, model); - private synchronized Set monitorDataUploadProgress( - final Set pendingDisplayHpcJobInfoSet, - final DefaultTableModel model) { + hpcJobActivityEditor.addSubmittedDisplayHpcJobInfo(notPendingJobAnymoreSet); - Map rowMap = new HashMap<>(); - for (int row = 0; row < model.getRowCount(); row++) { - rowMap.put( - Long.valueOf(model.getValueAt(row, - HpcJobActivityEditor.ID_COLUMN).toString()), row); + hpcJobActivityEditor.removePendingDisplayHpcJobInfo(notPendingJobAnymoreSet); } - Set notPendingJobAnymoreSet = new HashSet<>(); + private synchronized Set monitorDataUploadProgress(final Set pendingDisplayHpcJobInfoSet, + final DefaultTableModel model) { - for (HpcJobInfo hpcJobInfo : pendingDisplayHpcJobInfoSet) { + Map rowMap = new HashMap<>(); + for (int row = 0; row < model.getRowCount(); row++) { + rowMap.put(Long.valueOf(model.getValueAt(row, HpcJobActivityEditor.ID_COLUMN).toString()), row); + } - int status = hpcJobInfo.getStatus(); + Set notPendingJobAnymoreSet = new HashSet<>(); - if (!rowMap.containsKey(hpcJobInfo.getId())) { - continue; - } + for (HpcJobInfo hpcJobInfo : pendingDisplayHpcJobInfoSet) { - int modelRow = rowMap.get(hpcJobInfo.getId()); + int status = hpcJobInfo.getStatus(); - // In case the job was accidentally added to the map OR the kill - // request was issued - if (status != -1) { - notPendingJobAnymoreSet.add(hpcJobInfo); - } else { + if (!rowMap.containsKey(hpcJobInfo.getId())) { + continue; + } - // Dataset uploading progress - AlgorithmParamRequest algorParamReq = hpcJobInfo - .getAlgorithmParamRequest(); - String datasetPath = algorParamReq.getDatasetPath(); + int modelRow = rowMap.get(hpcJobInfo.getId()); - int dataUploadProgress = hpcJobManager - .getUploadFileProgress(datasetPath); + // In case the job was accidentally added to the map OR the kill + // request was issued + if (status != -1) { + notPendingJobAnymoreSet.add(hpcJobInfo); + } else { - if (dataUploadProgress > -1 && dataUploadProgress < 100) { - model.setValueAt("" + dataUploadProgress + "%", modelRow, - HpcJobActivityEditor.DATA_UPLOAD_COLUMN); - } else if(dataUploadProgress == -1) { - model.setValueAt("Skipped", modelRow, - HpcJobActivityEditor.DATA_UPLOAD_COLUMN); - } else { - model.setValueAt("Done", modelRow, - HpcJobActivityEditor.DATA_UPLOAD_COLUMN); - } + // Dataset uploading progress + AlgorithmParamRequest algorParamReq = hpcJobInfo.getAlgorithmParamRequest(); + String datasetPath = algorParamReq.getDatasetPath(); - // Prior Knowledge uploading progress - String priorKnowledgePath = algorParamReq - .getPriorKnowledgePath(); - - int priorKnowledgeUploadProgress = -1; - if (priorKnowledgePath != null) { - - System.out.println("priorKnowledgePath: " - + priorKnowledgePath); - - priorKnowledgeUploadProgress = hpcJobManager - .getUploadFileProgress(priorKnowledgePath); - - if (priorKnowledgeUploadProgress > -1 - && priorKnowledgeUploadProgress < 100) { - model.setValueAt("" + priorKnowledgeUploadProgress - + "%", modelRow, - HpcJobActivityEditor.KNOWLEDGE_UPLOAD_COLUMN); - } else { - model.setValueAt("Done", modelRow, - HpcJobActivityEditor.KNOWLEDGE_UPLOAD_COLUMN); - } - } else { - model.setValueAt("Skipped", modelRow, - HpcJobActivityEditor.KNOWLEDGE_UPLOAD_COLUMN); - } + int dataUploadProgress = hpcJobManager.getUploadFileProgress(datasetPath); - if (dataUploadProgress == 100 - && (priorKnowledgeUploadProgress == -1 || priorKnowledgeUploadProgress == 100)) { + if (dataUploadProgress > -1 && dataUploadProgress < 100) { + model.setValueAt("" + dataUploadProgress + "%", modelRow, HpcJobActivityEditor.DATA_UPLOAD_COLUMN); + } else if (dataUploadProgress == -1) { + model.setValueAt("Skipped", modelRow, HpcJobActivityEditor.DATA_UPLOAD_COLUMN); + } else { + model.setValueAt("Done", modelRow, HpcJobActivityEditor.DATA_UPLOAD_COLUMN); + } - System.out.println("HpcJobInfo Id: " + hpcJobInfo.getId() - + " done with both uploading"); + // Prior Knowledge uploading progress + String priorKnowledgePath = algorParamReq.getPriorKnowledgePath(); - Map> pendingHpcJobInfoMap = hpcJobManager - .getPendingHpcJobInfoMap(); + int priorKnowledgeUploadProgress = -1; + if (priorKnowledgePath != null) { - Map> submittedHpcJobInfoMap = hpcJobManager - .getSubmittedHpcJobInfoMap(); + System.out.println("priorKnowledgePath: " + priorKnowledgePath); - if (pendingHpcJobInfoMap != null) { - Set pendingJobSet = pendingHpcJobInfoMap - .get(hpcJobInfo.getHpcAccount()); + priorKnowledgeUploadProgress = hpcJobManager.getUploadFileProgress(priorKnowledgePath); - // Is the job still stuck in the pre-processed schedule - // task? - long id = -1; - for (HpcJobInfo pendingJob : pendingJobSet) { - if (pendingJob.getId() == hpcJobInfo.getId()) { - id = pendingJob.getId(); - continue; - } - } + if (priorKnowledgeUploadProgress > -1 && priorKnowledgeUploadProgress < 100) { + model.setValueAt("" + priorKnowledgeUploadProgress + "%", modelRow, + HpcJobActivityEditor.KNOWLEDGE_UPLOAD_COLUMN); + } else { + model.setValueAt("Done", modelRow, HpcJobActivityEditor.KNOWLEDGE_UPLOAD_COLUMN); + } + } else { + model.setValueAt("Skipped", modelRow, HpcJobActivityEditor.KNOWLEDGE_UPLOAD_COLUMN); + } - // The job is not in the pre-processed schedule task - if (id == -1 && submittedHpcJobInfoMap != null) { - Set submittedJobSet = submittedHpcJobInfoMap - .get(hpcJobInfo.getHpcAccount()); - - // Is the job in the submitted schedule task? - for (HpcJobInfo submittedJob : submittedJobSet) { - if (submittedJob.getId() == hpcJobInfo.getId()) { - - // Status - switch (submittedJob.getStatus()) { - case -1: - model.setValueAt( - "Pending", - modelRow, - HpcJobActivityEditor.STATUS_COLUMN); - break; - case 0: - model.setValueAt( - "Submitted", - modelRow, - HpcJobActivityEditor.STATUS_COLUMN); - break; - case 1: - model.setValueAt( - "Running", - modelRow, - HpcJobActivityEditor.STATUS_COLUMN); - break; - case 2: - model.setValueAt( - "Kill Request", - modelRow, - HpcJobActivityEditor.STATUS_COLUMN); - break; - } - - // Submitted time - if (submittedJob.getSubmittedTime() != null) { - model.setValueAt( - FilePrint - .fileTimestamp(hpcJobInfo - .getSubmittedTime() - .getTime()), - modelRow, - HpcJobActivityEditor.ACTIVE_SUBMITTED_COLUMN); - } - - // Hpc Pid - model.setValueAt( - submittedJob.getPid(), - modelRow, - HpcJobActivityEditor.ACTIVE_HPC_JOB_ID_COLUMN); - - // last update - HpcJobLog hpcJobLog = hpcJobManager - .getHpcJobLog(submittedJob); - model.setValueAt( - FilePrint.fileTimestamp(hpcJobLog - .getLastUpdatedTime() - .getTime()), - modelRow, - HpcJobActivityEditor.ACTIVE_LAST_UPDATED_COLUMN); - - // Remove from the pending queue - notPendingJobAnymoreSet.add(submittedJob); - - continue; + if (dataUploadProgress == 100 + && (priorKnowledgeUploadProgress == -1 || priorKnowledgeUploadProgress == 100)) { + + System.out.println("HpcJobInfo Id: " + hpcJobInfo.getId() + " done with both uploading"); + + Map> pendingHpcJobInfoMap = hpcJobManager.getPendingHpcJobInfoMap(); + + Map> submittedHpcJobInfoMap = hpcJobManager.getSubmittedHpcJobInfoMap(); + + if (pendingHpcJobInfoMap != null) { + Set pendingJobSet = pendingHpcJobInfoMap.get(hpcJobInfo.getHpcAccount()); + + // Is the job still stuck in the pre-processed schedule + // task? + long id = -1; + for (HpcJobInfo pendingJob : pendingJobSet) { + if (pendingJob.getId() == hpcJobInfo.getId()) { + id = pendingJob.getId(); + continue; + } + } + + // The job is not in the pre-processed schedule task + if (id == -1 && submittedHpcJobInfoMap != null) { + Set submittedJobSet = submittedHpcJobInfoMap.get(hpcJobInfo.getHpcAccount()); + + // Is the job in the submitted schedule task? + for (HpcJobInfo submittedJob : submittedJobSet) { + if (submittedJob.getId() == hpcJobInfo.getId()) { + + // Status + switch (submittedJob.getStatus()) { + case -1: + model.setValueAt("Pending", modelRow, HpcJobActivityEditor.STATUS_COLUMN); + break; + case 0: + model.setValueAt("Submitted", modelRow, HpcJobActivityEditor.STATUS_COLUMN); + break; + case 1: + model.setValueAt("Running", modelRow, HpcJobActivityEditor.STATUS_COLUMN); + break; + case 2: + model.setValueAt("Kill Request", modelRow, HpcJobActivityEditor.STATUS_COLUMN); + break; + } + + // Submitted time + if (submittedJob.getSubmittedTime() != null) { + model.setValueAt( + FilePrint.fileTimestamp(hpcJobInfo.getSubmittedTime().getTime()), + modelRow, HpcJobActivityEditor.ACTIVE_SUBMITTED_COLUMN); + } + + // Hpc Pid + model.setValueAt(submittedJob.getPid(), modelRow, + HpcJobActivityEditor.ACTIVE_HPC_JOB_ID_COLUMN); + + // last update + HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(submittedJob); + model.setValueAt(FilePrint.fileTimestamp(hpcJobLog.getLastUpdatedTime().getTime()), + modelRow, HpcJobActivityEditor.ACTIVE_LAST_UPDATED_COLUMN); + + // Remove from the pending queue + notPendingJobAnymoreSet.add(submittedJob); + + continue; + } + } + } + } } - } } - } } - } - } - return notPendingJobAnymoreSet; - } + return notPendingJobAnymoreSet; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/SubmittedHpcJobUpdaterTask.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/SubmittedHpcJobUpdaterTask.java index 81ca9abc7a..098827003d 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/SubmittedHpcJobUpdaterTask.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/task/SubmittedHpcJobUpdaterTask.java @@ -24,113 +24,100 @@ */ public class SubmittedHpcJobUpdaterTask extends TimerTask { - private final HpcJobManager hpcJobManager; + private final HpcJobManager hpcJobManager; - private final HpcJobActivityEditor hpcJobActivityEditor; + private final HpcJobActivityEditor hpcJobActivityEditor; - public SubmittedHpcJobUpdaterTask(final HpcJobManager hpcJobManager, - final HpcJobActivityEditor hpcJobActivityEditor) { - this.hpcJobManager = hpcJobManager; - this.hpcJobActivityEditor = hpcJobActivityEditor; - } - - @Override - public void run() { - if (hpcJobActivityEditor.selectedTabbedPaneIndex() != 0) - return; - - final DefaultTableModel model = (DefaultTableModel) hpcJobActivityEditor - .getJobsTableModel(); + public SubmittedHpcJobUpdaterTask(final HpcJobManager hpcJobManager, + final HpcJobActivityEditor hpcJobActivityEditor) { + this.hpcJobManager = hpcJobManager; + this.hpcJobActivityEditor = hpcJobActivityEditor; + } - Set submittedDisplayHpcJobInfoSet = hpcJobActivityEditor - .getSubmittedDisplayHpcJobInfoSet(); + @Override + public void run() { + if (hpcJobActivityEditor.selectedTabbedPaneIndex() != 0) + return; - Set finishedJobSet = monitorSubmittedJobStatus( - submittedDisplayHpcJobInfoSet, model); + final DefaultTableModel model = (DefaultTableModel) hpcJobActivityEditor.getJobsTableModel(); - hpcJobActivityEditor.removeSubmittedDisplayHpcJobInfo(finishedJobSet); + Set submittedDisplayHpcJobInfoSet = hpcJobActivityEditor.getSubmittedDisplayHpcJobInfoSet(); - hpcJobActivityEditor - .removeSubmittedDisplayJobFromActiveTableModel(finishedJobSet); + Set finishedJobSet = monitorSubmittedJobStatus(submittedDisplayHpcJobInfoSet, model); - } + hpcJobActivityEditor.removeSubmittedDisplayHpcJobInfo(finishedJobSet); - private synchronized Set monitorSubmittedJobStatus( - final Set submittedDisplayHpcJobInfoSet, - final DefaultTableModel model) { + hpcJobActivityEditor.removeSubmittedDisplayJobFromActiveTableModel(finishedJobSet); - Map rowMap = new HashMap<>(); - for (int row = 0; row < model.getRowCount(); row++) { - rowMap.put( - Long.valueOf(model.getValueAt(row, - HpcJobActivityEditor.ID_COLUMN).toString()), row); } - Set finishedJobSet = new HashSet<>(); - - for (HpcJobInfo hpcJobInfo : submittedDisplayHpcJobInfoSet) { + private synchronized Set monitorSubmittedJobStatus(final Set submittedDisplayHpcJobInfoSet, + final DefaultTableModel model) { - Long id = hpcJobInfo.getId(); - - if (!rowMap.containsKey(id)) { - //System.out.println("hpcJobInfo not found in rowMap"); - continue; - } + Map rowMap = new HashMap<>(); + for (int row = 0; row < model.getRowCount(); row++) { + rowMap.put(Long.valueOf(model.getValueAt(row, HpcJobActivityEditor.ID_COLUMN).toString()), row); + } - int modelRow = rowMap.get(id); + Set finishedJobSet = new HashSet<>(); + + for (HpcJobInfo hpcJobInfo : submittedDisplayHpcJobInfoSet) { + + Long id = hpcJobInfo.getId(); + + if (!rowMap.containsKey(id)) { + // System.out.println("hpcJobInfo not found in rowMap"); + continue; + } + + int modelRow = rowMap.get(id); + + Map> submittedHpcJobInfoMap = hpcJobManager.getSubmittedHpcJobInfoMap(); + Set submittedJobSet = submittedHpcJobInfoMap.get(hpcJobInfo.getHpcAccount()); + if (submittedJobSet != null) { + for (HpcJobInfo submittedJob : submittedJobSet) { + if (submittedJob.getId() == hpcJobInfo.getId()) { + hpcJobInfo = submittedJob; + // System.out + // .println("Found submittedJob in the + // submittedHpcJobInfoMap id matched!"); + continue; + } + } + } + + int status = hpcJobInfo.getStatus(); + + // Status + switch (hpcJobInfo.getStatus()) { + case -1: + model.setValueAt("Pending", modelRow, HpcJobActivityEditor.STATUS_COLUMN); + break; + case 0: + model.setValueAt("Submitted", modelRow, HpcJobActivityEditor.STATUS_COLUMN); + break; + case 1: + model.setValueAt("Running", modelRow, HpcJobActivityEditor.STATUS_COLUMN); + break; + case 2: + model.setValueAt("Kill Request", modelRow, HpcJobActivityEditor.STATUS_COLUMN); + break; + } + + // last update + HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); + model.setValueAt(FilePrint.fileTimestamp(hpcJobLog.getLastUpdatedTime().getTime()), modelRow, + HpcJobActivityEditor.ACTIVE_LAST_UPDATED_COLUMN); + + // In case the job was accidentally added to the map OR the job + // was finished. + if (status > 2) { + finishedJobSet.add(hpcJobInfo); + } - Map> submittedHpcJobInfoMap = hpcJobManager - .getSubmittedHpcJobInfoMap(); - Set submittedJobSet = submittedHpcJobInfoMap - .get(hpcJobInfo.getHpcAccount()); - if (submittedJobSet != null) { - for (HpcJobInfo submittedJob : submittedJobSet) { - if (submittedJob.getId() == hpcJobInfo.getId()) { - hpcJobInfo = submittedJob; - //System.out - // .println("Found submittedJob in the submittedHpcJobInfoMap id matched!"); - continue; - } } - } - - int status = hpcJobInfo.getStatus(); - - // Status - switch (hpcJobInfo.getStatus()) { - case -1: - model.setValueAt("Pending", modelRow, - HpcJobActivityEditor.STATUS_COLUMN); - break; - case 0: - model.setValueAt("Submitted", modelRow, - HpcJobActivityEditor.STATUS_COLUMN); - break; - case 1: - model.setValueAt("Running", modelRow, - HpcJobActivityEditor.STATUS_COLUMN); - break; - case 2: - model.setValueAt("Kill Request", modelRow, - HpcJobActivityEditor.STATUS_COLUMN); - break; - } - - // last update - HpcJobLog hpcJobLog = hpcJobManager.getHpcJobLog(hpcJobInfo); - model.setValueAt(FilePrint.fileTimestamp(hpcJobLog - .getLastUpdatedTime().getTime()), modelRow, - HpcJobActivityEditor.ACTIVE_LAST_UPDATED_COLUMN); - - // In case the job was accidentally added to the map OR the job - // was finished. - if (status > 2) { - finishedJobSet.add(hpcJobInfo); - } + return finishedJobSet; } - return finishedJobSet; - } - } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/util/HpcAccountUtils.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/util/HpcAccountUtils.java index c001a2b8f7..056c6fdf17 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/util/HpcAccountUtils.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/hpc/util/HpcAccountUtils.java @@ -22,77 +22,68 @@ */ public class HpcAccountUtils { - public static boolean testConnection(final HpcAccountManager hpcAccountManager, final HpcAccount hpcAccount) { - try { - getJsonWebToken(hpcAccountManager, hpcAccount); + public static boolean testConnection(final HpcAccountManager hpcAccountManager, final HpcAccount hpcAccount) { + try { + getJsonWebToken(hpcAccountManager, hpcAccount); - return true; + return true; - } catch (ClientProtocolException e) { - } catch (URISyntaxException e) { - } catch (IOException e) { - } catch (Exception e) { - } + } catch (ClientProtocolException e) { + } catch (URISyntaxException e) { + } catch (IOException e) { + } catch (Exception e) { + } - return false; - } + return false; + } - public static DataFile getRemoteDataFile(final HpcAccountManager hpcAccountManager, - final RemoteDataFileService remoteDataService, - final HpcAccount hpcAccount, String md5) - throws ClientProtocolException, URISyntaxException, IOException, - Exception { - Set dataFiles = remoteDataService - .retrieveDataFileInfo(getJsonWebToken(hpcAccountManager, hpcAccount)); - for (DataFile dataFile : dataFiles) { - String remoteMd5 = dataFile.getMd5checkSum(); - if (md5.equalsIgnoreCase(remoteMd5)) { - return dataFile; - } + public static DataFile getRemoteDataFile(final HpcAccountManager hpcAccountManager, + final RemoteDataFileService remoteDataService, final HpcAccount hpcAccount, String md5) + throws ClientProtocolException, URISyntaxException, IOException, Exception { + Set dataFiles = remoteDataService + .retrieveDataFileInfo(getJsonWebToken(hpcAccountManager, hpcAccount)); + for (DataFile dataFile : dataFiles) { + String remoteMd5 = dataFile.getMd5checkSum(); + if (md5.equalsIgnoreCase(remoteMd5)) { + return dataFile; + } + } + return null; } - return null; - } - public static DataFile getRemotePriorKnowledgeFile(final HpcAccountManager hpcAccountManager, - final RemoteDataFileService remoteDataService, - final HpcAccount hpcAccount, String md5) - throws ClientProtocolException, URISyntaxException, IOException, - Exception { - Set priorFiles = remoteDataService - .retrievePriorKnowledgeFileInfo(getJsonWebToken(hpcAccountManager, hpcAccount)); - for (DataFile priorFile : priorFiles) { - String remoteMd5 = priorFile.getMd5checkSum(); - if (md5.equalsIgnoreCase(remoteMd5)) { - return priorFile; - } + public static DataFile getRemotePriorKnowledgeFile(final HpcAccountManager hpcAccountManager, + final RemoteDataFileService remoteDataService, final HpcAccount hpcAccount, String md5) + throws ClientProtocolException, URISyntaxException, IOException, Exception { + Set priorFiles = remoteDataService + .retrievePriorKnowledgeFileInfo(getJsonWebToken(hpcAccountManager, hpcAccount)); + for (DataFile priorFile : priorFiles) { + String remoteMd5 = priorFile.getMd5checkSum(); + if (md5.equalsIgnoreCase(remoteMd5)) { + return priorFile; + } + } + return null; } - return null; - } - - public static JsonWebToken getJsonWebToken( - final HpcAccountManager hpcAccountManager, final HpcAccount hpcAccount) - throws Exception { - return hpcAccountManager.getJsonWebTokenManager().getJsonWebToken( - hpcAccount); - } - public static void summarizeDataset( - final RemoteDataFileService remoteDataService, - final AlgorithmParamRequest algorParamReq, - final long datasetFileId, final JsonWebToken jsonWebToken) - throws ClientProtocolException, URISyntaxException, IOException { - String variableType = "continuous"; - if (algorParamReq.getVariableType().equalsIgnoreCase("discrete")) { - variableType = "discrete"; + public static JsonWebToken getJsonWebToken(final HpcAccountManager hpcAccountManager, final HpcAccount hpcAccount) + throws Exception { + return hpcAccountManager.getJsonWebTokenManager().getJsonWebToken(hpcAccount); } - String fileDelimiter = "tab"; - if (!algorParamReq.getFileDelimiter().equalsIgnoreCase("tab")) { - fileDelimiter = "comma"; + public static void summarizeDataset(final RemoteDataFileService remoteDataService, + final AlgorithmParamRequest algorParamReq, final long datasetFileId, final JsonWebToken jsonWebToken) + throws ClientProtocolException, URISyntaxException, IOException { + String variableType = "continuous"; + if (algorParamReq.getVariableType().equalsIgnoreCase("discrete")) { + variableType = "discrete"; + } + + String fileDelimiter = "tab"; + if (!algorParamReq.getFileDelimiter().equalsIgnoreCase("tab")) { + fileDelimiter = "comma"; + } + + remoteDataService.summarizeDataFile(datasetFileId, variableType, fileDelimiter, jsonWebToken); } - remoteDataService.summarizeDataFile(datasetFileId, variableType, - fileDelimiter, jsonWebToken); - } - } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/DagEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/DagEditor.java index 97ef27008d..b0871f964b 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/DagEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/DagEditor.java @@ -89,8 +89,6 @@ public void propertyChange(PropertyChangeEvent evt) { int numModels = dagWrapper.getNumModels(); - System.out.println("numModels = " + numModels); - if (numModels > 1) { final JComboBox comp = new JComboBox<>(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/EditorWindow.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/EditorWindow.java index f704d207d3..d9ccaa962f 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/EditorWindow.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/EditorWindow.java @@ -69,8 +69,6 @@ public EditorWindow(JPanel editor, String title, String buttonName, boolean cancellable, Component centeringComp) { super(title, true, true, true, false); - System.out.println("EditorWindow: " + title + " : " + buttonName); - if (editor == null) { throw new NullPointerException("Editor must not be null."); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FactorAnalysisEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FactorAnalysisEditor.java index b6cb9ca3c6..2abc4f3525 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FactorAnalysisEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FactorAnalysisEditor.java @@ -26,7 +26,7 @@ import edu.cmu.tetrad.graph.TimeLagGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetradapp.model.FactorAnalysisRunner; -import edu.cmu.tetradapp.model.IFgsRunner; +import edu.cmu.tetradapp.model.IFgesRunner; import edu.cmu.tetradapp.workbench.GraphWorkbench; import javax.swing.*; @@ -145,9 +145,9 @@ private JComponent getIndTestParamBox(Parameters params) { } if (params instanceof Parameters) { - if (getAlgorithmRunner() instanceof IFgsRunner) { - IFgsRunner gesRunner = ((IFgsRunner) getAlgorithmRunner()); - return new FgsIndTestParamsEditor(params, gesRunner.getType()); + if (getAlgorithmRunner() instanceof IFgesRunner) { + IFgesRunner gesRunner = ((IFgesRunner) getAlgorithmRunner()); + return new FgesIndTestParamsEditor(params, gesRunner.getType()); } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FciCcdSearchEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FciCcdSearchEditor.java index a910a53964..a374e28a68 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FciCcdSearchEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FciCcdSearchEditor.java @@ -398,8 +398,8 @@ private JComponent getIndTestParamBox(Parameters params) { } if (params instanceof Parameters) { - FgsRunner fgsRunner = ((FgsRunner) getAlgorithmRunner()); - return new FgsIndTestParamsEditor(params, fgsRunner.getType()); + FgesRunner fgesRunner = ((FgesRunner) getAlgorithmRunner()); + return new FgesIndTestParamsEditor(params, fgesRunner.getType()); } // if (params instanceof LagIndTestParams) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgsDisplay.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgesDisplay.java similarity index 97% rename from tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgsDisplay.java rename to tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgesDisplay.java index 8ec1b37757..d691e90882 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgsDisplay.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgesDisplay.java @@ -46,7 +46,7 @@ * * @author Joseph Ramsey */ -public class FgsDisplay extends JPanel implements GraphEditable { +public class FgesDisplay extends JPanel implements GraphEditable { private final Graph resultGraph; private final GraphWorkbench workbench; private List topGraphs; @@ -56,7 +56,7 @@ public class FgsDisplay extends JPanel implements GraphEditable { private final JLabel scoreLabel; private Indexable indexable; - public FgsDisplay(Graph resultGraph, final List topGraphs, Indexable indexable) { + public FgesDisplay(Graph resultGraph, final List topGraphs, Indexable indexable) { this.nf = NumberFormatUtil.getInstance().getNumberFormat(); this.indexable = indexable; this.topGraphs = topGraphs; diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgsIndTestParamsEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgesIndTestParamsEditor.java similarity index 86% rename from tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgsIndTestParamsEditor.java rename to tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgesIndTestParamsEditor.java index de2d70df69..30b13c8bf4 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgsIndTestParamsEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgesIndTestParamsEditor.java @@ -22,7 +22,7 @@ package edu.cmu.tetradapp.editor; import edu.cmu.tetrad.util.Parameters; -import edu.cmu.tetradapp.model.FgsRunner; +import edu.cmu.tetradapp.model.FgesRunner; import edu.cmu.tetradapp.util.DoubleTextField; import edu.cmu.tetradapp.util.IntTextField; @@ -39,8 +39,8 @@ * @author Ricardo Silva */ -class FgsIndTestParamsEditor extends JComponent { - private final FgsRunner.Type type; +class FgesIndTestParamsEditor extends JComponent { + private final FgesRunner.Type type; private final Parameters params; private DoubleTextField cellPriorField, structurePriorField; private JButton uniformStructurePrior; @@ -53,21 +53,21 @@ class FgsIndTestParamsEditor extends JComponent { */ private final JCheckBox faithfulnessAssumed; - public FgsIndTestParamsEditor(Parameters params, FgsRunner.Type type) { + public FgesIndTestParamsEditor(Parameters params, FgesRunner.Type type) { this.params = params; this.type = type; NumberFormat nf = new DecimalFormat("0.0####"); NumberFormat smallNf = new DecimalFormat("0.0E0"); - if (type == FgsRunner.Type.DISCRETE) { + if (type == FgesRunner.Type.DISCRETE) { this.cellPriorField = new DoubleTextField( - getFgsIndTestParams().getDouble("samplePrior", 1), 5, nf, smallNf, 1e-4); + getFgesIndTestParams().getDouble("samplePrior", 1), 5, nf, smallNf, 1e-4); this.cellPriorField.setFilter(new DoubleTextField.Filter() { public double filter(double value, double oldValue) { try { - getFgsIndTestParams().set("samplePrior", value); + getFgesIndTestParams().set("samplePrior", value); return value; } catch (IllegalArgumentException e) { @@ -77,11 +77,11 @@ public double filter(double value, double oldValue) { }); this.structurePriorField = new DoubleTextField( - getFgsIndTestParams().getDouble("structurePrior", 1), 5, nf); + getFgesIndTestParams().getDouble("structurePrior", 1), 5, nf); this.structurePriorField.setFilter(new DoubleTextField.Filter() { public double filter(double value, double oldValue) { try { - getFgsIndTestParams().set("structurePrior", value); + getFgesIndTestParams().set("structurePrior", value); return value; } catch (IllegalArgumentException e) { @@ -112,11 +112,11 @@ public void actionPerformed(ActionEvent e) { }); } else { this.penaltyDiscount = new DoubleTextField( - getFgsIndTestParams().getDouble("penaltyDiscount", 4), 5, nf); + getFgesIndTestParams().getDouble("penaltyDiscount", 4), 5, nf); this.penaltyDiscount.setFilter(new DoubleTextField.Filter() { public double filter(double value, double oldValue) { try { - getFgsIndTestParams().set("penaltyDiscount", value); + getFgesIndTestParams().set("penaltyDiscount", value); return value; } catch (IllegalArgumentException e) { @@ -127,11 +127,11 @@ public double filter(double value, double oldValue) { } this.numPatternsToSave = new IntTextField( - getFgsIndTestParams().getInt("numPatternsToSave", 1), 5); + getFgesIndTestParams().getInt("numPatternsToSave", 1), 5); this.numPatternsToSave.setFilter(new IntTextField.Filter() { public int filter(int value, int oldValue) { try { - getFgsIndTestParams().set("numPatternsToSave", value); + getFgesIndTestParams().set("numPatternsToSave", value); return value; } catch (IllegalArgumentException e) { @@ -140,11 +140,11 @@ public int filter(int value, int oldValue) { } }); - this.depth = new IntTextField(getFgsIndTestParams().getInt("depth", -1), 4); + this.depth = new IntTextField(getFgesIndTestParams().getInt("depth", -1), 4); this.depth.setFilter(new IntTextField.Filter() { public int filter(int value, int oldValue) { try { - getFgsIndTestParams().set("depth", value); + getFgesIndTestParams().set("depth", value); return value; } catch (IllegalArgumentException e) { @@ -154,11 +154,11 @@ public int filter(int value, int oldValue) { }); faithfulnessAssumed = new JCheckBox(); - faithfulnessAssumed.setSelected(getFgsIndTestParams().getBoolean("faithfulnessAssumed", true)); + faithfulnessAssumed.setSelected(getFgesIndTestParams().getBoolean("faithfulnessAssumed", true)); faithfulnessAssumed.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent actionEvent) { JCheckBox source = (JCheckBox) actionEvent.getSource(); - getFgsIndTestParams().set("faithfulnessAssumed", source.isSelected()); + getFgesIndTestParams().set("faithfulnessAssumed", source.isSelected()); } }); @@ -168,7 +168,7 @@ public void actionPerformed(ActionEvent actionEvent) { private void buildGui() { setLayout(new BoxLayout(this, BoxLayout.Y_AXIS)); - if (type == FgsRunner.Type.DISCRETE) { + if (type == FgesRunner.Type.DISCRETE) { Box b0 = Box.createHorizontalBox(); b0.add(new JLabel("BDeu:")); b0.add(Box.createHorizontalGlue()); @@ -213,7 +213,7 @@ private void buildGui() { b4b.add(Box.createHorizontalGlue()); b4b.add(depth); add(b4b); - } else if (type == FgsRunner.Type.CONTINUOUS || type == FgsRunner.Type.MIXED){ + } else if (type == FgesRunner.Type.CONTINUOUS || type == FgesRunner.Type.MIXED){ Box b7 = Box.createHorizontalBox(); b7.add(new JLabel("Penalty Discount")); b7.add(Box.createHorizontalGlue()); @@ -237,7 +237,7 @@ private void buildGui() { b4b.add(Box.createHorizontalGlue()); b4b.add(depth); add(b4b); - } else if (type == FgsRunner.Type.GRAPH) { + } else if (type == FgesRunner.Type.GRAPH) { Box b8 = Box.createHorizontalBox(); b8.add(new JLabel("Num Patterns to Save")); b8.add(Box.createHorizontalGlue()); @@ -261,7 +261,7 @@ private void buildGui() { } - private Parameters getFgsIndTestParams() { + private Parameters getFgesIndTestParams() { return params; } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgsSearchEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgesSearchEditor.java similarity index 94% rename from tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgsSearchEditor.java rename to tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgesSearchEditor.java index 1e751267b8..f78325d107 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgsSearchEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/FgesSearchEditor.java @@ -55,41 +55,41 @@ * * @author Joseph Ramsey */ -public class FgsSearchEditor extends AbstractSearchEditor +public class FgesSearchEditor extends AbstractSearchEditor implements KnowledgeEditable, LayoutEditable, Indexable, DoNotScroll { // private JTextArea bootstrapEdgeCountsScroll; private JTabbedPane tabbedPane; private boolean alreadyLaidOut = false; - private FgsDisplay gesDisplay; - private FgsIndTestParamsEditor paramsEditor; + private FgesDisplay gesDisplay; + private FgesIndTestParamsEditor paramsEditor; //=========================CONSTRUCTORS============================// /** - * Opens up an editor to let the user view the given FgsRunner. + * Opens up an editor to let the user view the given FgesRunner. */ - public FgsSearchEditor(FgsRunner runner) { + public FgesSearchEditor(FgesRunner runner) { super(runner, "Result forbid_latent_common_causes"); } - public FgsSearchEditor(WFgsRunner runner) { + public FgesSearchEditor(WFgesRunner runner) { super(runner, "Result forbid_latent_common_causes"); } - public FgsSearchEditor(FgsMbRunner runner) { + public FgesSearchEditor(FgesMbRunner runner) { super(runner, "Result forbid_latent_common_causes"); } - public FgsSearchEditor(ImagesRunner runner) { + public FgesSearchEditor(ImagesRunner runner) { super(runner, "Result forbid_latent_common_causes"); } - public FgsSearchEditor(TsFgsRunner runner) { + public FgesSearchEditor(TsFgesRunner runner) { super(runner, "Result forbid_latent_common_causes"); } - public FgsSearchEditor(TsImagesRunner runner) { + public FgesSearchEditor(TsImagesRunner runner) { super(runner, "Result forbid_latent_common_causes"); } @@ -157,7 +157,7 @@ public void watch() { try { storeLatestWorkbenchGraph(); getAlgorithmRunner().execute(); - IFgsRunner runner = (IFgsRunner) getAlgorithmRunner(); + IFgesRunner runner = (IFgesRunner) getAlgorithmRunner(); arrangeGraphs(); gesDisplay.resetGraphs(runner.getTopGraphs()); } catch (Exception e) { @@ -358,7 +358,7 @@ public void watch() { } protected void addSpecialMenus(JMenuBar menuBar) { - if (!(getAlgorithmRunner() instanceof FgsRunner)) { + if (!(getAlgorithmRunner() instanceof FgesRunner)) { JMenu test = new JMenu("Independence"); menuBar.add(test); @@ -431,9 +431,9 @@ public void watch() { return; } - if (runner instanceof FgsRunner) { - GraphScorer scorer = ((FgsRunner) runner).getGraphScorer(); - Graph _graph = ((FgsRunner) runner).getTopGraphs().get(getIndex()).getGraph(); + if (runner instanceof FgesRunner) { + GraphScorer scorer = ((FgesRunner) runner).getGraphScorer(); + Graph _graph = ((FgesRunner) runner).getTopGraphs().get(getIndex()).getGraph(); ScoredGraphsDisplay display = new ScoredGraphsDisplay(_graph, scorer); GraphWorkbench workbench = getWorkbench(); @@ -527,10 +527,10 @@ public void actionPerformed(ActionEvent e) { //==============================PRIVATE METHODS=============================// - private FgsDisplay gesDisplay() { + private FgesDisplay gesDisplay() { Graph resultGraph = resultGraph(); List topGraphs = arrangeGraphs(); - FgsDisplay display = new FgsDisplay(resultGraph, topGraphs, this); + FgesDisplay display = new FgesDisplay(resultGraph, topGraphs, this); this.gesDisplay = display; // Superfluous? @@ -555,7 +555,7 @@ public void mouseExited(MouseEvent e) { } private List arrangeGraphs() { - IFgsRunner runner = (IFgsRunner) getAlgorithmRunner(); + IFgesRunner runner = (IFgesRunner) getAlgorithmRunner(); List topGraphs = runner.getTopGraphs(); @@ -598,7 +598,7 @@ private Graph resultGraph() { private void calcStats() { - FgsRunner runner = (FgsRunner) getAlgorithmRunner(); + FgesRunner runner = (FgesRunner) getAlgorithmRunner(); if (runner.getTopGraphs().isEmpty()) { throw new IllegalArgumentException("No patterns were recorded. Please adjust the number of " + @@ -640,8 +640,8 @@ private void calcStats() { // throw new IllegalArgumentException(""); // } - String bayesFactorsReport = ((FgsRunner) getAlgorithmRunner()).getBayesFactorsReport(dag); -// String bootstrapEdgeCountsReport = ((ImaFgsRunner) getAlgorithmRunner()).getBootstrapEdgeCountsReport(25); + String bayesFactorsReport = ((FgesRunner) getAlgorithmRunner()).getBayesFactorsReport(dag); +// String bootstrapEdgeCountsReport = ((ImagesRunner) getAlgorithmRunner()).getBootstrapEdgeCountsReport(25); JScrollPane dagWorkbenchScroll = dagWorkbenchScroll(dag); @@ -699,7 +699,7 @@ public void actionPerformed(ActionEvent actionEvent) { new WatchedProcess(owner) { public void watch() { int n = numBootstraps.getValue(); -// String bootstrapEdgeCountsReport = ((ImaFgsRunner) getAlgorithmRunner()).getBootstrapEdgeCountsReport(n); +// String bootstrapEdgeCountsReport = ((ImagesRunner) getAlgorithmRunner()).getBootstrapEdgeCountsReport(n); // bootstrapEdgeCountsScroll.setText(bootstrapEdgeCountsReport); } }; @@ -879,14 +879,14 @@ private JComponent getIndTestParamBox(Parameters params) { AlgorithmRunner algorithmRunner = getAlgorithmRunner(); - if (algorithmRunner instanceof IFgsRunner) { - IFgsRunner fgsRunner = ((IFgsRunner) algorithmRunner); - return new FgsIndTestParamsEditor(params, fgsRunner.getType()); + if (algorithmRunner instanceof IFgesRunner) { + IFgesRunner fgesRunner = ((IFgesRunner) algorithmRunner); + return new FgesIndTestParamsEditor(params, fgesRunner.getType()); } - if (algorithmRunner instanceof FgsMbRunner) { - FgsMbRunner fgsRunner = ((FgsMbRunner) algorithmRunner); - return new FgsIndTestParamsEditor(params, fgsRunner.getType()); + if (algorithmRunner instanceof FgesMbRunner) { + FgesMbRunner fgesRunner = ((FgesMbRunner) algorithmRunner); + return new FgesIndTestParamsEditor(params, fgesRunner.getType()); } throw new IllegalArgumentException(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GeneralAlgorithmEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GeneralAlgorithmEditor.java index dda586ae19..07f7193604 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GeneralAlgorithmEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GeneralAlgorithmEditor.java @@ -28,8 +28,8 @@ import edu.cmu.tetrad.algcomparison.algorithm.continuous.dag.Lingam; import edu.cmu.tetrad.algcomparison.algorithm.mixed.Mgm; import edu.cmu.tetrad.algcomparison.algorithm.multi.ImagesBDeu; +import edu.cmu.tetrad.algcomparison.algorithm.multi.ImagesCcd; import edu.cmu.tetrad.algcomparison.algorithm.multi.ImagesSemBic; -import edu.cmu.tetrad.algcomparison.algorithm.multi.TsImagesSemBic; import edu.cmu.tetrad.algcomparison.algorithm.oracle.pag.*; import edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.*; import edu.cmu.tetrad.algcomparison.algorithm.other.Glasso; @@ -97,1417 +97,1317 @@ */ public class GeneralAlgorithmEditor extends JPanel implements FinalizingEditor { - private static final long serialVersionUID = -5719467682865706447L; - - private final HashMap mappedDescriptions; - private final GeneralAlgorithmRunner runner; - private final JButton searchButton1 = new JButton("Search"); - private final JButton searchButton2 = new JButton("Search"); - private final JTabbedPane pane; - private final JComboBox algTypesDropdown = new JComboBox<>(); - private final JComboBox algNamesDropdown = new JComboBox<>(); - private final JComboBox testDropdown = new JComboBox<>(); - private final JComboBox scoreDropdown = new JComboBox<>(); - private final GraphSelectionEditor graphEditor; - private final Parameters parameters; - private final HelpSet helpSet; - private final Dimension searchButton1Size; - private Box knowledgePanel; - private JLabel whatYouChose; - - private final TetradDesktop desktop; - private HpcJobInfo hpcJobInfo; - - private String jsonResult; - private String errorResult; - - // =========================CONSTRUCTORS============================// - - /** - * Opens up an editor to let the user view the given PcRunner. - */ - public GeneralAlgorithmEditor(final GeneralAlgorithmRunner runner) { - this.runner = runner; - - String helpHS = "/resources/javahelp/TetradHelp.hs"; - - try { - URL url = this.getClass().getResource(helpHS); - this.helpSet = new HelpSet(null, url); - } catch (Exception ee) { - System.out.println("HelpSet " + ee.getMessage()); - System.out.println("HelpSet " + helpHS + " not found"); - throw new IllegalArgumentException(); - } + private static final long serialVersionUID = -5719467682865706447L; + + private final HashMap mappedDescriptions; + private final GeneralAlgorithmRunner runner; + private final JButton searchButton1 = new JButton("Search"); + private final JButton searchButton2 = new JButton("Search"); + private final JTabbedPane pane; + private final JComboBox algTypesDropdown = new JComboBox<>(); + private final JComboBox algNamesDropdown = new JComboBox<>(); + private final JComboBox testDropdown = new JComboBox<>(); + private final JComboBox scoreDropdown = new JComboBox<>(); + private final GraphSelectionEditor graphEditor; + private final Parameters parameters; + private final HelpSet helpSet; + private final Dimension searchButton1Size; + private Box knowledgePanel; + private JLabel whatYouChose; + + private final TetradDesktop desktop; + private HpcJobInfo hpcJobInfo; + + private String jsonResult; + + // =========================CONSTRUCTORS============================// + + /** + * Opens up an editor to let the user view the given PcRunner. + */ + public GeneralAlgorithmEditor(final GeneralAlgorithmRunner runner) { + this.runner = runner; + + String helpHS = "/resources/javahelp/TetradHelp.hs"; + + try { + URL url = this.getClass().getResource(helpHS); + this.helpSet = new HelpSet(null, url); + } catch (Exception ee) { + System.out.println("HelpSet " + ee.getMessage()); + System.out.println("HelpSet " + helpHS + " not found"); + throw new IllegalArgumentException(); + } - algTypesDropdown.setFont(new Font("Dialog", Font.PLAIN, 13)); - algNamesDropdown.setFont(new Font("Dialog", Font.PLAIN, 13)); - testDropdown.setFont(new Font("Dialog", Font.PLAIN, 13)); - scoreDropdown.setFont(new Font("Dialog", Font.PLAIN, 13)); - - Dimension dim = searchButton1.getPreferredSize(); - searchButton1Size = new Dimension(dim.width + 5, dim.height + 5); - - List discreteTests = new ArrayList<>(); - discreteTests.add(TestType.ChiSquare); - discreteTests.add(TestType.GSquare); - discreteTests.add(TestType.Discrete_BIC_Test); - discreteTests.add(TestType.Conditional_Gaussian_LRT); - - List continuousTests = new ArrayList<>(); - continuousTests.add(TestType.Fisher_Z); - continuousTests.add(TestType.Correlation_T); - continuousTests.add(TestType.SEM_BIC); - continuousTests.add(TestType.Conditional_Correlation); - continuousTests.add(TestType.Conditional_Gaussian_LRT); - - List mixedTests = new ArrayList<>(); - mixedTests.add(TestType.Conditional_Gaussian_LRT); - - List dsepTests = new ArrayList<>(); - dsepTests.add(TestType.D_SEPARATION); - - List discreteScores = new ArrayList<>(); - discreteScores.add(ScoreType.BDeu); - discreteScores.add(ScoreType.Discrete_BIC); - discreteScores.add(ScoreType.Conditional_Gaussian_BIC); - - List continuousScores = new ArrayList<>(); - continuousScores.add(ScoreType.SEM_BIC); - continuousScores.add(ScoreType.Conditional_Gaussian_BIC); - - List mixedScores = new ArrayList<>(); - mixedScores.add(ScoreType.Conditional_Gaussian_BIC); - - List dsepScores = new ArrayList<>(); - dsepScores.add(ScoreType.D_SEPARATION); - - final List descriptions = new ArrayList<>(); - - descriptions.add(new AlgorithmDescription(AlgName.PC, - AlgType.forbid_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.CPC, - AlgType.forbid_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.PCStable, - AlgType.forbid_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.CPCStable, - AlgType.forbid_latent_common_causes, OracleType.Test)); - // descriptions.add(new AlgorithmDescription(AlgName.PcLocal, - // AlgType.forbid_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.PcMax, - AlgType.forbid_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.FGES, - AlgType.forbid_latent_common_causes, OracleType.Score)); - descriptions.add(new AlgorithmDescription(AlgName.IMaGES_BDeu, - AlgType.forbid_latent_common_causes, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.IMaGES_SEM_BIC, - AlgType.forbid_latent_common_causes, OracleType.None)); - // descriptions.add(new AlgorithmDescription(AlgName.PcMaxLocal, - // AlgType.forbid_latent_common_causes, OracleType.Test)); - // descriptions.add(new AlgorithmDescription(AlgName.JCPC, - // AlgType.forbid_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.CCD, - AlgType.forbid_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.GCCD, - AlgType.forbid_latent_common_causes, OracleType.Both)); - - descriptions.add(new AlgorithmDescription(AlgName.FCI, - AlgType.allow_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.RFCI, - AlgType.allow_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.CFCI, - AlgType.allow_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.GFCI, - AlgType.allow_latent_common_causes, OracleType.Both)); - descriptions.add(new AlgorithmDescription(AlgName.TsFCI, - AlgType.allow_latent_common_causes, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.TsGFCI, - AlgType.allow_latent_common_causes, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.TsImages, - AlgType.allow_latent_common_causes, OracleType.None)); - // descriptions.add(new AlgorithmDescription(AlgName.FgsMeasurement, - // AlgType.forbid_latent_common_causes, OracleType.Score)); - descriptions.add(new AlgorithmDescription(AlgName.TsFCI, - AlgType.allow_latent_common_causes, OracleType.Test)); - descriptions.add(new AlgorithmDescription(AlgName.TsGFCI, - AlgType.allow_latent_common_causes, OracleType.Both)); - - descriptions.add(new AlgorithmDescription(AlgName.FgsMb, - AlgType.search_for_Markov_blankets, OracleType.Score)); - descriptions.add(new AlgorithmDescription(AlgName.MBFS, - AlgType.search_for_Markov_blankets, OracleType.Score)); - // descriptions.add(new AlgorithmDescription(AlgName.Wfgs, - // AlgType.forbid_latent_common_causes, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.FAS, - AlgType.produce_undirected_graphs, OracleType.Test)); - - // descriptions.add(new AlgorithmDescription(AlgName.LiNGAM, - // AlgType.DAG, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.MGM, - AlgType.produce_undirected_graphs, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.GLASSO, - AlgType.produce_undirected_graphs, OracleType.None)); - - descriptions.add(new AlgorithmDescription(AlgName.Bpc, - AlgType.search_for_structure_over_latents, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.Fofc, - AlgType.search_for_structure_over_latents, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.Ftfc, - AlgType.search_for_structure_over_latents, OracleType.None)); - - descriptions.add(new AlgorithmDescription(AlgName.EB, - AlgType.orient_pairwise, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.R1, - AlgType.orient_pairwise, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.R2, - AlgType.orient_pairwise, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.R3, - AlgType.orient_pairwise, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.R4, - AlgType.orient_pairwise, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.RSkew, - AlgType.orient_pairwise, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.RSkewE, - AlgType.orient_pairwise, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.Skew, - AlgType.orient_pairwise, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.SkewE, - AlgType.orient_pairwise, OracleType.None)); - descriptions.add(new AlgorithmDescription(AlgName.Tahn, - AlgType.orient_pairwise, OracleType.None)); - - mappedDescriptions = new HashMap<>(); - - for (AlgorithmDescription description : descriptions) { - mappedDescriptions.put(description.getAlgName(), description); - } + algTypesDropdown.setFont(new Font("Dialog", Font.PLAIN, 13)); + algNamesDropdown.setFont(new Font("Dialog", Font.PLAIN, 13)); + testDropdown.setFont(new Font("Dialog", Font.PLAIN, 13)); + scoreDropdown.setFont(new Font("Dialog", Font.PLAIN, 13)); + + Dimension dim = searchButton1.getPreferredSize(); + searchButton1Size = new Dimension(dim.width + 5, dim.height + 5); + + List discreteTests = new ArrayList<>(); + discreteTests.add(TestType.ChiSquare); + discreteTests.add(TestType.GSquare); + discreteTests.add(TestType.Discrete_BIC_Test); + discreteTests.add(TestType.Conditional_Gaussian_LRT); + + List continuousTests = new ArrayList<>(); + continuousTests.add(TestType.Fisher_Z); + continuousTests.add(TestType.Correlation_T); + continuousTests.add(TestType.SEM_BIC); + continuousTests.add(TestType.Conditional_Correlation); + continuousTests.add(TestType.Conditional_Gaussian_LRT); + + List mixedTests = new ArrayList<>(); + mixedTests.add(TestType.Conditional_Gaussian_LRT); + + List dsepTests = new ArrayList<>(); + dsepTests.add(TestType.D_SEPARATION); + + List discreteScores = new ArrayList<>(); + discreteScores.add(ScoreType.BDeu); + discreteScores.add(ScoreType.Discrete_BIC); + discreteScores.add(ScoreType.Conditional_Gaussian_BIC); + + List continuousScores = new ArrayList<>(); + continuousScores.add(ScoreType.SEM_BIC); + continuousScores.add(ScoreType.Conditional_Gaussian_BIC); + + List mixedScores = new ArrayList<>(); + mixedScores.add(ScoreType.Conditional_Gaussian_BIC); + + List dsepScores = new ArrayList<>(); + dsepScores.add(ScoreType.D_SEPARATION); + + final List descriptions = new ArrayList<>(); + + descriptions.add(new AlgorithmDescription(AlgName.PC, AlgType.forbid_latent_common_causes, OracleType.Test)); + descriptions.add(new AlgorithmDescription(AlgName.CPC, AlgType.forbid_latent_common_causes, OracleType.Test)); + descriptions + .add(new AlgorithmDescription(AlgName.PCStable, AlgType.forbid_latent_common_causes, OracleType.Test)); + descriptions + .add(new AlgorithmDescription(AlgName.CPCStable, AlgType.forbid_latent_common_causes, OracleType.Test)); + descriptions.add(new AlgorithmDescription(AlgName.PcMax, AlgType.forbid_latent_common_causes, OracleType.Test)); + descriptions.add(new AlgorithmDescription(AlgName.FGES, AlgType.forbid_latent_common_causes, OracleType.Score)); + descriptions.add( + new AlgorithmDescription(AlgName.IMaGES_BDeu, AlgType.forbid_latent_common_causes, OracleType.None)); + descriptions.add( + new AlgorithmDescription(AlgName.IMaGES_SEM_BIC, AlgType.forbid_latent_common_causes, OracleType.None)); + descriptions.add( + new AlgorithmDescription(AlgName.IMaGES_CCD, AlgType.forbid_latent_common_causes, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.CCD, AlgType.forbid_latent_common_causes, OracleType.Test)); + descriptions + .add(new AlgorithmDescription(AlgName.CCD_MAX, AlgType.forbid_latent_common_causes, OracleType.Test)); + + descriptions.add(new AlgorithmDescription(AlgName.FCI, AlgType.allow_latent_common_causes, OracleType.Test)); + descriptions.add(new AlgorithmDescription(AlgName.RFCI, AlgType.allow_latent_common_causes, OracleType.Test)); + descriptions.add(new AlgorithmDescription(AlgName.CFCI, AlgType.allow_latent_common_causes, OracleType.Test)); + descriptions.add(new AlgorithmDescription(AlgName.GFCI, AlgType.allow_latent_common_causes, OracleType.Both)); + descriptions.add(new AlgorithmDescription(AlgName.TsFCI, AlgType.allow_latent_common_causes, OracleType.Test)); + descriptions.add(new AlgorithmDescription(AlgName.TsGFCI, AlgType.allow_latent_common_causes, OracleType.Both)); + descriptions + .add(new AlgorithmDescription(AlgName.TsImages, AlgType.allow_latent_common_causes, OracleType.Test)); + + descriptions + .add(new AlgorithmDescription(AlgName.FgesMb, AlgType.search_for_Markov_blankets, OracleType.Score)); + descriptions.add(new AlgorithmDescription(AlgName.MBFS, AlgType.search_for_Markov_blankets, OracleType.Score)); + descriptions.add(new AlgorithmDescription(AlgName.FAS, AlgType.produce_undirected_graphs, OracleType.Test)); + + // descriptions.add(new AlgorithmDescription(AlgName.LiNGAM, + // AlgType.DAG, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.MGM, AlgType.produce_undirected_graphs, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.GLASSO, AlgType.produce_undirected_graphs, OracleType.None)); + + descriptions + .add(new AlgorithmDescription(AlgName.Bpc, AlgType.search_for_structure_over_latents, OracleType.None)); + descriptions.add( + new AlgorithmDescription(AlgName.Fofc, AlgType.search_for_structure_over_latents, OracleType.None)); + descriptions.add( + new AlgorithmDescription(AlgName.Ftfc, AlgType.search_for_structure_over_latents, OracleType.None)); + + descriptions.add(new AlgorithmDescription(AlgName.EB, AlgType.orient_pairwise, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.R1, AlgType.orient_pairwise, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.R2, AlgType.orient_pairwise, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.R3, AlgType.orient_pairwise, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.R4, AlgType.orient_pairwise, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.RSkew, AlgType.orient_pairwise, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.RSkewE, AlgType.orient_pairwise, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.Skew, AlgType.orient_pairwise, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.SkewE, AlgType.orient_pairwise, OracleType.None)); + descriptions.add(new AlgorithmDescription(AlgName.Tahn, AlgType.orient_pairwise, OracleType.None)); + + mappedDescriptions = new HashMap<>(); - this.parameters = runner.getParameters(); - graphEditor = new GraphSelectionEditor(new GraphSelectionWrapper( - runner.getGraphs(), new Parameters())); - setLayout(new BorderLayout()); - - whatYouChose = new JLabel(); - - // if (runner.getDataModelList() == null) { - // throw new NullPointerException("No data has been provided."); - // } - - List tests; - - DataModelList dataModelList = runner.getDataModelList(); - - if ((dataModelList.isEmpty() && runner.getSourceGraph() != null)) { - tests = dsepTests; - } else if (!(dataModelList.isEmpty())) { - DataModel dataSet = dataModelList.get(0); - - if (dataSet.isContinuous()) { - tests = continuousTests; - } else if (dataSet.isDiscrete()) { - tests = discreteTests; - } else if (dataSet.isMixed()) { - tests = mixedTests; - } else { - throw new IllegalArgumentException(); - } - } else { - throw new IllegalArgumentException( - "You need either some data sets or a graph as input."); - } + for (AlgorithmDescription description : descriptions) { + mappedDescriptions.put(description.getAlgName(), description); + } - for (TestType item : tests) { - testDropdown.addItem(item); - } + this.parameters = runner.getParameters(); + graphEditor = new GraphSelectionEditor(new GraphSelectionWrapper(runner.getGraphs(), new Parameters())); + setLayout(new BorderLayout()); - List scores; - - if ((dataModelList.isEmpty() && runner.getSourceGraph() != null)) { - tests = dsepTests; - } else if (!dataModelList.isEmpty()) { - DataModel dataSet = dataModelList.get(0); - - if (dataSet.isContinuous()) { - tests = continuousTests; - } else if (dataSet.isDiscrete()) { - tests = discreteTests; - } else if (dataSet.isMixed()) { - tests = mixedTests; - } else { - throw new IllegalArgumentException(); - } - } else { - throw new IllegalArgumentException( - "You need either some data sets or a graph as input."); - } + whatYouChose = new JLabel(); - if (dataModelList.isEmpty() && runner.getGraphs() != null) { - scores = dsepScores; - } else if (!(dataModelList.isEmpty())) { - DataModel dataSet = dataModelList.get(0); - - if (dataSet.isContinuous()) { - scores = continuousScores; - } else if (dataSet.isDiscrete()) { - scores = discreteScores; - } else if (dataSet.isMixed()) { - scores = mixedScores; - } else { - throw new IllegalArgumentException(); - } - } else { - throw new IllegalArgumentException( - "You need either some data sets or a graph as input."); - } + // if (runner.getDataModelList() == null) { + // throw new NullPointerException("No data has been provided."); + // } - for (ScoreType item : scores) { - scoreDropdown.addItem(item); - } + List tests; - for (AlgType item : AlgType.values()) { - algTypesDropdown.addItem(item.toString().replace("_", " ")); - } + DataModelList dataModelList = runner.getDataModelList(); - for (AlgorithmDescription description : descriptions) { - if (description.getAlgType() == getAlgType() - || getAlgType() == AlgType.ALL) { - algNamesDropdown.addItem(description.getAlgName()); - } - } + if ((dataModelList.isEmpty() && runner.getSourceGraph() != null)) { + tests = dsepTests; + } else if (!(dataModelList.isEmpty())) { + DataModel dataSet = dataModelList.get(0); - algTypesDropdown.setSelectedItem(getAlgType().toString().replace("_", - " ")); - algNamesDropdown.setSelectedItem(getAlgName()); + if (dataSet.isContinuous()) { + tests = continuousTests; + } else if (dataSet.isDiscrete()) { + tests = discreteTests; + } else if (dataSet.isMixed()) { + tests = mixedTests; + } else { + throw new IllegalArgumentException(); + } + } else { + throw new IllegalArgumentException("You need either some data sets or a graph as input."); + } - if (tests.contains(getTestType())) { - testDropdown.setSelectedItem(getTestType()); - } + for (TestType item : tests) { + testDropdown.addItem(item); + } - if (scores.contains(getScoreType())) { - scoreDropdown.setSelectedItem(getScoreType()); - } + List scores; + + if ((dataModelList.isEmpty() && runner.getSourceGraph() != null)) { + tests = dsepTests; + } else if (!dataModelList.isEmpty()) { + DataModel dataSet = dataModelList.get(0); + + if (dataSet.isContinuous()) { + tests = continuousTests; + } else if (dataSet.isDiscrete()) { + tests = discreteTests; + } else if (dataSet.isMixed()) { + tests = mixedTests; + } else { + throw new IllegalArgumentException(); + } + } else { + throw new IllegalArgumentException("You need either some data sets or a graph as input."); + } - testDropdown.setEnabled(parameters.getBoolean("testEnabled", true)); - scoreDropdown.setEnabled(parameters.getBoolean("scoreEnabled", false)); + if (dataModelList.isEmpty() && runner.getGraphs() != null) { + scores = dsepScores; + } else if (!(dataModelList.isEmpty())) { + DataModel dataSet = dataModelList.get(0); + + if (dataSet.isContinuous()) { + scores = continuousScores; + } else if (dataSet.isDiscrete()) { + scores = discreteScores; + } else if (dataSet.isMixed()) { + scores = mixedScores; + } else { + throw new IllegalArgumentException(); + } + } else { + throw new IllegalArgumentException("You need either some data sets or a graph as input."); + } + + for (ScoreType item : scores) { + scoreDropdown.addItem(item); + } - algTypesDropdown.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - algNamesDropdown.removeAllItems(); + for (AlgType item : AlgType.values()) { + algTypesDropdown.addItem(item.toString().replace("_", " ")); + } for (AlgorithmDescription description : descriptions) { - AlgType selectedItem = AlgType - .valueOf(((String) algTypesDropdown - .getSelectedItem()).replace(" ", "_")); - if (description.getAlgType() == selectedItem - || selectedItem == AlgType.ALL) { - algNamesDropdown.addItem(description.getAlgName()); - } + if (description.getAlgType() == getAlgType() || getAlgType() == AlgType.ALL) { + algNamesDropdown.addItem(description.getAlgName()); + } } - } - }); - algNamesDropdown.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - setAlgorithm(); + algTypesDropdown.setSelectedItem(getAlgType().toString().replace("_", " ")); + algNamesDropdown.setSelectedItem(getAlgName()); - @SuppressWarnings("unchecked") - JComboBox box = (JComboBox) e.getSource(); - Object selectedItem = box.getSelectedItem(); + if (tests.contains(getTestType())) { + testDropdown.setSelectedItem(getTestType()); + } - if (selectedItem != null) { - helpSet.setHomeID(selectedItem.toString()); + if (scores.contains(getScoreType())) { + scoreDropdown.setSelectedItem(getScoreType()); } - } - }); - testDropdown.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - setAlgorithm(); - } - }); + testDropdown.setEnabled(parameters.getBoolean("testEnabled", true)); + scoreDropdown.setEnabled(parameters.getBoolean("scoreEnabled", false)); + + algTypesDropdown.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + algNamesDropdown.removeAllItems(); + + for (AlgorithmDescription description : descriptions) { + AlgType selectedItem = AlgType + .valueOf(((String) algTypesDropdown.getSelectedItem()).replace(" ", "_")); + if (description.getAlgType() == selectedItem || selectedItem == AlgType.ALL) { + algNamesDropdown.addItem(description.getAlgName()); + } + } + } + }); + + algNamesDropdown.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + setAlgorithm(); + + JComboBox box = (JComboBox) e.getSource(); + Object selectedItem = box.getSelectedItem(); + + if (selectedItem != null) { + helpSet.setHomeID(selectedItem.toString()); + } + } + }); + + testDropdown.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + setAlgorithm(); + } + }); + + scoreDropdown.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + setAlgorithm(); + } + }); + + pane = new JTabbedPane(); + pane.add("Algorithm", getParametersPane()); + getAlgorithmFromInterface(); + pane.add("Output Graphs", graphEditor); + add(pane, BorderLayout.CENTER); + + if (runner.getGraphs() != null && runner.getGraphs().size() > 0) { + pane.setSelectedComponent(graphEditor); + } + + searchButton1.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + doSearch(runner); + } + }); + + searchButton2.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + doSearch(runner); + } + }); - scoreDropdown.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { setAlgorithm(); - } - }); - pane = new JTabbedPane(); - pane.add("Algorithm", getParametersPane()); - getAlgorithmFromInterface(); - pane.add("Output Graphs", graphEditor); - add(pane, BorderLayout.CENTER); + this.desktop = (TetradDesktop) DesktopController.getInstance(); + } + + private Box getKnowledgePanel(GeneralAlgorithmRunner runner) { + class MyKnowledgeInput implements KnowledgeBoxInput { + + private static final long serialVersionUID = 1344090367098647696L; + + private String name; + private List variables; + private List varNames; + + public MyKnowledgeInput(List variables, List varNames) { + this.variables = variables; + this.varNames = varNames; + } + + @Override + public Graph getSourceGraph() { + return null; + } + + @Override + public Graph getResultGraph() { + return null; + } + + @Override + public void setName(String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } + + @Override + public List getVariables() { + return variables; + } + + @Override + public List getVariableNames() { + return varNames; + } + } + + List variables = null; + MyKnowledgeInput myKnowledgeInput; + + if (runner.getDataModel() != null) { + DataModelList dataModelList = runner.getDataModelList(); + if (dataModelList.size() > 0) { + variables = dataModelList.get(0).getVariables(); + } + } + + if ((variables == null || variables.isEmpty()) && runner.getSourceGraph() != null) { + variables = runner.getSourceGraph().getNodes(); + } - if (runner.getGraphs() != null && runner.getGraphs().size() > 0) { - pane.setSelectedComponent(graphEditor); + if (variables == null) { + throw new IllegalArgumentException("No source of variables!"); + } + + List varNames = new ArrayList<>(); + + for (Node node : variables) { + varNames.add(node.getName()); + } + + myKnowledgeInput = new MyKnowledgeInput(variables, varNames); + + JPanel knowledgePanel = new JPanel(); + knowledgePanel.setLayout(new BorderLayout()); + KnowledgeBoxModel knowledgeBoxModel = new KnowledgeBoxModel(new KnowledgeBoxInput[] { myKnowledgeInput }, + parameters); + knowledgeBoxModel.setKnowledge(runner.getKnowledge()); + KnowledgeBoxEditor knowledgeEditor = new KnowledgeBoxEditor(knowledgeBoxModel); + Box f = Box.createVerticalBox(); + f.add(knowledgeEditor); + Box g = Box.createHorizontalBox(); + g.add(Box.createHorizontalGlue()); + g.add(searchButton2); + g.add(Box.createHorizontalGlue()); + f.add(g); + return f; } - searchButton1.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - doSearch(runner); - } - }); - - searchButton2.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - doSearch(runner); - } - }); - - setAlgorithm(); - - this.desktop = (TetradDesktop) DesktopController.getInstance(); - } - - private Box getKnowledgePanel(GeneralAlgorithmRunner runner) { - class MyKnowledgeInput implements KnowledgeBoxInput { - - private static final long serialVersionUID = 1344090367098647696L; - - private String name; - private List variables; - private List varNames; - - public MyKnowledgeInput(List variables, List varNames) { - this.variables = variables; - this.varNames = varNames; - } - - @Override - public Graph getSourceGraph() { - return null; - } - - @Override - public Graph getResultGraph() { - return null; - } - - @Override - public void setName(String name) { - this.name = name; - } - - @Override - public String getName() { - return name; - } - - @Override - public List getVariables() { - return variables; - } - - @Override - public List getVariableNames() { - return varNames; - } + private void doSearch(final GeneralAlgorithmRunner runner) { + new WatchedProcess((Window) getTopLevelAncestor()) { + @Override + public void watch() { + HpcAccount hpcAccount = null; + + AlgName name = (AlgName) algNamesDropdown.getSelectedItem(); + switch (name) { + case FGES: + case GFCI: + hpcAccount = showRemoteComputingOptions(name); + break; + default: + } + + if (hpcAccount == null) { + runner.execute(); + graphEditor.replace(runner.getGraphs()); + graphEditor.validate(); + firePropertyChange("modelChanged", null, null); + pane.setSelectedComponent(graphEditor); + } else { + doRemoteCompute(runner, hpcAccount); + } + } + }; } - List variables = null; - MyKnowledgeInput myKnowledgeInput; + private HpcAccount showRemoteComputingOptions(AlgName name) { + List hpcAccounts = desktop.getHpcAccountManager().getHpcAccounts(); - if (runner.getDataModel() != null) { - DataModelList dataModelList = runner.getDataModelList(); - if (dataModelList.size() > 0) { - variables = dataModelList.get(0).getVariables(); - } - } + if (hpcAccounts == null || hpcAccounts.size() == 0) { + return null; + } + + String no_answer = "No, thanks"; + String yes_answer = "Please run it on "; + + Object[] options = new String[hpcAccounts.size() + 1]; + options[0] = no_answer; + for (int i = 0; i < hpcAccounts.size(); i++) { + String connName = hpcAccounts.get(i).getConnectionName(); + options[i + 1] = yes_answer + connName; + } - if ((variables == null || variables.isEmpty()) - && runner.getSourceGraph() != null) { - variables = runner.getSourceGraph().getNodes(); + int n = JOptionPane.showOptionDialog(this, "Would you like to execute a " + name + " search in the cloud?", + "A Silly Question", JOptionPane.YES_NO_OPTION, JOptionPane.QUESTION_MESSAGE, null, options, options[0]); + if (n == 0) + return null; + return hpcAccounts.get(n - 1); } - if (variables == null) { - throw new IllegalArgumentException("No source of variables!"); + private void doRemoteCompute(final GeneralAlgorithmRunner runner, final HpcAccount hpcAccount) { + + // ********************** + // Show progress panel * + // ********************** + + Frame ancestor = (Frame) JOptionUtils.centeringComp().getTopLevelAncestor(); + final JDialog progressDialog = new JDialog(ancestor, "HPC Job Submission's Progress...", false); + + Dimension progressDim = new Dimension(500, 150); + + JTextArea progressTextArea = new JTextArea(); + progressTextArea.setPreferredSize(progressDim); + progressTextArea.setEditable(false); + + JScrollPane progressScroller = new JScrollPane(progressTextArea); + progressScroller.setAlignmentX(LEFT_ALIGNMENT); + + progressDialog.setLayout(new BorderLayout()); + progressDialog.getContentPane().add(progressScroller, BorderLayout.CENTER); + progressDialog.pack(); + Dimension screenDim = Toolkit.getDefaultToolkit().getScreenSize(); + progressDialog.setLocation((screenDim.width - progressDim.width) / 2, + (screenDim.height - progressDim.height) / 2); + progressDialog.setVisible(true); + + int totalProcesses = 4; + String newline = "\n"; + String tab = "\t"; + int progressTextLength = 0; + + DataModel dataModel = runner.getDataModel(); + + // 1. Generate temp file + Path file = null; + Path prior = null; + try { + // **************************** + // Data Preparation Progress * + // **************************** + String dataMessage = String.format("1/%1$d Data Preparation", totalProcesses); + progressTextArea.append(dataMessage); + progressTextArea.append(tab); + + progressTextLength = progressTextArea.getText().length(); + + progressTextArea.append("Preparing..."); + progressTextArea.updateUI(); + + file = Files.createTempFile("Tetrad-data-", ".txt"); + // System.out.println(file.toAbsolutePath().toString()); + List tempLine = new ArrayList<>(); + + // Header + List variables = dataModel.getVariables(); + if ((variables == null || variables.isEmpty()) && runner.getSourceGraph() != null) { + variables = runner.getSourceGraph().getNodes(); + } + + String vars = StringUtils.join(variables.toArray(), tab); + tempLine.add(vars); + + // Data + DataSet dataSet = (DataSet) dataModel; + for (int i = 0; i < dataSet.getNumRows(); i++) { + String line = null; + for (int j = 0; j < dataSet.getNumColumns(); j++) { + String cell = null; + if (dataSet.isContinuous()) { + cell = String.valueOf(dataSet.getDouble(i, j)); + } else { + cell = String.valueOf(dataSet.getInt(i, j)); + } + if (line == null) { + line = cell; + } else { + line = line + "\t" + cell; + } + } + tempLine.add(line); + } + + // for (String line : tempLine) { + // System.out.println(line); + // } + + Files.write(file, tempLine); + + // Get file's MD5 hash and use it as its identifier + String datasetMd5 = MessageDigestHash.computeMD5Hash(file); + + progressTextArea.replaceRange("Done", progressTextLength, progressTextArea.getText().length()); + progressTextArea.append(newline); + progressTextArea.updateUI(); + + // *************************************** + // Prior Knowledge Preparation Progress * + // *************************************** + String priorMessage = String.format("2/%1$d Prior Knowledge Preparation", totalProcesses); + progressTextArea.append(priorMessage); + progressTextArea.append(tab); + + progressTextLength = progressTextArea.getText().length(); + + progressTextArea.append("Preparing..."); + progressTextArea.updateUI(); + + // 2. Generate temp prior knowledge file + Knowledge2 knowledge = (Knowledge2) dataModel.getKnowledge(); + if (knowledge != null && !knowledge.isEmpty()) { + prior = Files.createTempFile(file.getFileName().toString(), ".prior"); + knowledge.saveKnowledge(Files.newBufferedWriter(prior)); + + progressTextArea.replaceRange("Done", progressTextLength, progressTextArea.getText().length()); + progressTextArea.append(newline); + progressTextArea.updateUI(); + } else { + progressTextArea.replaceRange("Skipped", progressTextLength, progressTextArea.getText().length()); + progressTextArea.append(newline); + progressTextArea.updateUI(); + } + // Get knowledge file's MD5 hash and use it as its identifier + String priorKnowledgeMd5 = null; + if (prior != null) { + priorKnowledgeMd5 = MessageDigestHash.computeMD5Hash(prior); + } + + // ******************************************* + // Algorithm Parameter Preparation Progress * + // ******************************************* + String algorMessage = String.format("3/%1$d Algorithm Preparation", totalProcesses); + progressTextArea.append(algorMessage); + progressTextArea.append(tab); + + progressTextLength = progressTextArea.getText().length(); + + progressTextArea.append("Preparing..."); + progressTextArea.updateUI(); + + // 3.1 Algorithm name + String algorithmName = AbstractAlgorithmRequest.FGES; + Algorithm algorithm = runner.getAlgorithm(); + System.out.println("Algorithm: " + algorithm.getDescription()); + AlgName name = (AlgName) algNamesDropdown.getSelectedItem(); + switch (name) { + case FGES: + algorithmName = AbstractAlgorithmRequest.FGES; + if (dataModel.isDiscrete()) { + algorithmName = AbstractAlgorithmRequest.FGES_DISCRETE; + } + break; + case GFCI: + algorithmName = AbstractAlgorithmRequest.GFCI; + if (dataModel.isDiscrete()) { + algorithmName = AbstractAlgorithmRequest.GFCI_DISCRETE; + } + break; + default: + return; + } + + // 3.2 Parameters + AlgorithmParamRequest algorithmParamRequest = new AlgorithmParamRequest(); + + // Dataset and Prior paths + String datasetPath = file.toAbsolutePath().toString(); + System.out.println(datasetPath); + algorithmParamRequest.setDatasetPath(datasetPath); + algorithmParamRequest.setDatasetMd5(datasetMd5); + if (prior != null) { + String priorKnowledgePath = prior.toAbsolutePath().toString(); + System.out.println(priorKnowledgePath); + algorithmParamRequest.setPriorKnowledgePath(priorKnowledgePath); + algorithmParamRequest.setPriorKnowledgeMd5(priorKnowledgeMd5); + } + + // VariableType + if (dataModel.isContinuous()) { + algorithmParamRequest.setVariableType("continuous"); + } else { + algorithmParamRequest.setVariableType("discrete"); + } + + // FileDelimiter + String fileDelimiter = "tab"; // Pre-determined + algorithmParamRequest.setFileDelimiter(fileDelimiter); + + // Default Data Validation Parameters + DataValidation dataValidation = new DataValidation(); + dataValidation.setUniqueVarName(true); + if (dataModel.isContinuous()) { + dataValidation.setNonZeroVariance(true); + } else { + dataValidation.setCategoryLimit(true); + } + algorithmParamRequest.setDataValidation(dataValidation); + + List AlgorithmParameters = new ArrayList<>(); + + Parameters parameters = runner.getParameters(); + List parameterNames = runner.getAlgorithm().getParameters(); + for (String parameter : parameterNames) { + String value = parameters.get(parameter).toString(); + System.out.println("parameter: " + parameter + "\tvalue: " + value); + if (value != null) { + AlgorithmParameter algorParam = new AlgorithmParameter(); + algorParam.setParameter(parameter); + algorParam.setValue(value); + AlgorithmParameters.add(algorParam); + } + } + + algorithmParamRequest.setAlgorithmParameters(AlgorithmParameters); + + String maxHeapSize = null; + do { + maxHeapSize = JOptionPane.showInputDialog(progressDialog, "Enter Your Request Java Max Heap Size (GB):", + "5"); + } while (maxHeapSize != null && !StringUtils.isNumeric(maxHeapSize)); + + if (maxHeapSize != null) { + JvmOption jvmOption = new JvmOption(); + jvmOption.setParameter("maxHeapSize"); + jvmOption.setValue(maxHeapSize); + List jvmOptions = new ArrayList<>(); + jvmOptions.add(jvmOption); + algorithmParamRequest.setJvmOptions(jvmOptions); + } + + progressTextArea.replaceRange("Done", progressTextLength, progressTextArea.getText().length()); + progressTextArea.append(newline); + progressTextArea.updateUI(); + + // ******************************** + // Adding HPC Job Queue Progress * + // ******************************** + String dbMessage = String.format("4/%1$d HPC Job Queue Submission", totalProcesses); + progressTextArea.append(dbMessage); + progressTextArea.append(tab); + + progressTextLength = progressTextArea.getText().length(); + + progressTextArea.append("Preparing..."); + progressTextArea.updateUI(); + + HpcJobManager hpcJobManager = desktop.getHpcJobManager(); + + // 4.1 Save HpcJobInfo + hpcJobInfo = new HpcJobInfo(); + hpcJobInfo.setAlgorithmName(algorithmName); + hpcJobInfo.setAlgorithmParamRequest(algorithmParamRequest); + hpcJobInfo.setStatus(-1); + hpcJobInfo.setHpcAccount(hpcAccount); + hpcJobManager.submitNewHpcJobToQueue(hpcJobInfo, this); + + progressTextArea.replaceRange("Done", progressTextLength, progressTextArea.getText().length()); + progressTextArea.append(newline); + progressTextArea.updateUI(); + + this.jsonResult = null; + + JOptionPane.showMessageDialog(ancestor, "The " + hpcJobInfo.getAlgorithmName() + " job on the " + + hpcJobInfo.getHpcAccount().getConnectionName() + " node is in the queue successfully!"); + + } catch (IOException e1) { + e1.printStackTrace(); + } finally { + progressDialog.setVisible(false); + progressDialog.dispose(); + } + + (new HpcJobActivityAction("")).actionPerformed(null); + } - List varNames = new ArrayList<>(); + public void setAlgorithmResult(String jsonResult) { + this.jsonResult = jsonResult; - for (Node node : variables) { - varNames.add(node.getName()); + final Graph graph = JsonUtils.parseJSONObjectToTetradGraph(jsonResult); + final List graphs = new ArrayList<>(); + graphs.add(graph); + int size = runner.getGraphs().size(); + for (int index = 0; index < size; index++) { + runner.getGraphs().remove(index); + } + runner.getGraphs().add(graph); + graphEditor.replace(graphs); + graphEditor.validate(); + System.out.println("Remote graph result assigned to runner!"); + firePropertyChange("modelChanged", null, null); + pane.setSelectedComponent(graphEditor); } - myKnowledgeInput = new MyKnowledgeInput(variables, varNames); - - JPanel knowledgePanel = new JPanel(); - knowledgePanel.setLayout(new BorderLayout()); - KnowledgeBoxModel knowledgeBoxModel = new KnowledgeBoxModel( - new KnowledgeBoxInput[] { myKnowledgeInput }, parameters); - knowledgeBoxModel.setKnowledge(runner.getKnowledge()); - KnowledgeBoxEditor knowledgeEditor = new KnowledgeBoxEditor( - knowledgeBoxModel); - Box f = Box.createVerticalBox(); - f.add(knowledgeEditor); - Box g = Box.createHorizontalBox(); - g.add(Box.createHorizontalGlue()); - g.add(searchButton2); - g.add(Box.createHorizontalGlue()); - f.add(g); - return f; - } - - private void doSearch(final GeneralAlgorithmRunner runner) { - new WatchedProcess((Window) getTopLevelAncestor()) { - @Override - public void watch() { - HpcAccount hpcAccount = null; + public void setAlgorithmErrorResult(String errorResult) { + JOptionPane.showMessageDialog(desktop, jsonResult); + throw new IllegalArgumentException(errorResult); + } + public Algorithm getAlgorithmFromInterface() { AlgName name = (AlgName) algNamesDropdown.getSelectedItem(); - switch (name) { - case FGES: - case GFCI: - hpcAccount = showRemoteComputingOptions(name); - break; - default: + + if (name == null) { + throw new NullPointerException(); } - if (hpcAccount == null) { - runner.execute(); - graphEditor.replace(runner.getGraphs()); - graphEditor.validate(); - firePropertyChange("modelChanged", null, null); - pane.setSelectedComponent(graphEditor); + IndependenceWrapper independenceWrapper = getIndependenceWrapper(); + ScoreWrapper scoreWrapper = getScoreWrapper(); + + Algorithm algorithm = getAlgorithm(name, independenceWrapper, scoreWrapper); + + if (algorithm instanceof HasKnowledge) { + if (knowledgePanel == null) { + knowledgePanel = getKnowledgePanel(runner); + } + + pane.remove(graphEditor); + pane.add("Knowledge", knowledgePanel); + pane.add("Output Graphs", graphEditor); } else { - doRemoteCompute(runner, hpcAccount); + pane.remove(knowledgePanel); } - } - }; - } - - private HpcAccount showRemoteComputingOptions(AlgName name) { - List hpcAccounts = desktop.getHpcAccountManager() - .getHpcAccounts(); - if (hpcAccounts == null || hpcAccounts.size() == 0) { - return null; + return algorithm; } - String no_answer = "No, thanks"; - String yes_answer = "Please run it on "; + private Algorithm getAlgorithm(AlgName name, IndependenceWrapper independenceWrapper, ScoreWrapper scoreWrapper) { + Algorithm algorithm; - Object[] options = new String[hpcAccounts.size() + 1]; - options[0] = no_answer; - for (int i = 0; i < hpcAccounts.size(); i++) { - String connName = hpcAccounts.get(i).getConnectionName(); - options[i + 1] = yes_answer + connName; - } + switch (name) { + case FGES: + algorithm = new Fges(scoreWrapper); + +// if (runner.getSourceGraph() != null && !runner.getDataModelList().isEmpty()) { +// algorithm = new Fges(scoreWrapper, new SingleGraphAlg(runner.getSourceGraph())); +// } else { +// algorithm = new Fges(scoreWrapper); +// } + break; +// case FgesMeasurement: +// if (runner.getSourceGraph() != null && !runner.getDataModelList().isEmpty()) { +// algorithm = new FgesMeasurement(scoreWrapper, new SingleGraphAlg(runner.getSourceGraph())); +// } else { +// algorithm = new FgesMeasurement(scoreWrapper); +// } +// break; + case PC: + if (runner.getSourceGraph() != null && !runner.getDataModelList().isEmpty()) { + algorithm = new Pc(independenceWrapper, new SingleGraphAlg(runner.getSourceGraph())); + } else { + algorithm = new Pc(independenceWrapper); + } + break; + case CPC: + if (runner.getSourceGraph() != null && !runner.getDataModelList().isEmpty()) { + algorithm = new Cpc(independenceWrapper, new SingleGraphAlg(runner.getSourceGraph())); + } else { + algorithm = new Cpc(independenceWrapper); + } + break; + case CPCStable: + algorithm = new CpcStable(independenceWrapper); + break; + case PCStable: + if (runner.getSourceGraph() != null && !runner.getDataModelList().isEmpty()) { + algorithm = new PcStable(independenceWrapper, new SingleGraphAlg(runner.getSourceGraph())); + } else { + algorithm = new PcStable(independenceWrapper); + } + break; + case GFCI: + algorithm = new Gfci(independenceWrapper, scoreWrapper); + break; + case FCI: + algorithm = new Fci(independenceWrapper); +// if (runner.getSourceGraph() != null && !runner.getDataModelList().isEmpty()) { +// algorithm = new Fci(independenceWrapper, new SingleGraphAlg(runner.getSourceGraph())); +// } else { +// algorithm = new Fci(independenceWrapper); +// } + break; + case RFCI: + algorithm = new Rfci(independenceWrapper); + break; + case CFCI: + algorithm = new Cfci(independenceWrapper); + break; + case TsFCI: + algorithm = new TsFci(independenceWrapper); +// if (runner.getSourceGraph() != null && !runner.getDataModelList().isEmpty()) { +// algorithm = new TsFci(independenceWrapper, new SingleGraphAlg(runner.getSourceGraph())); +// } else { +// algorithm = new TsFci(independenceWrapper); +// } + break; + case TsGFCI: + algorithm = new TsGfci(independenceWrapper, scoreWrapper); + break; + case TsImages: + algorithm = new TsImages(scoreWrapper); + break; + case CCD: + algorithm = new Ccd(independenceWrapper); + break; + case CCD_MAX: + algorithm = new CcdMax(independenceWrapper); + break; + case FAS: + algorithm = new FAS(independenceWrapper); + break; + case FgesMb: + algorithm = new FgesMb(scoreWrapper); +// if (runner.getSourceGraph() != null && !runner.getDataModelList().isEmpty()) { +// algorithm = new FgesMb(scoreWrapper, new SingleGraphAlg(runner.getSourceGraph())); +// } else { +// algorithm = new FgesMb(scoreWrapper); +// } + break; + case MBFS: + algorithm = new MBFS(independenceWrapper); + break; + case PcMax: + algorithm = new PcMax(independenceWrapper); + break; + case JCPC: + algorithm = new Jcpc(independenceWrapper, scoreWrapper); + break; + case LiNGAM: + algorithm = new Lingam(); + break; + case MGM: + algorithm = new Mgm(); + break; + case IMaGES_BDeu: + algorithm = new ImagesBDeu(); + break; + case IMaGES_SEM_BIC: + algorithm = new ImagesSemBic(); + break; + case IMaGES_CCD: + algorithm = new ImagesCcd(); + break; + case GLASSO: + algorithm = new Glasso(); + break; + case Bpc: + algorithm = new Bpc(); + break; + case Fofc: + algorithm = new Fofc(); + break; + case Ftfc: + algorithm = new Ftfc(); + break; + + // LOFS algorithms. + case EB: + algorithm = new EB(new SingleGraphAlg(runner.getSourceGraph())); + break; + case R1: + algorithm = new R1(new SingleGraphAlg(runner.getSourceGraph())); + break; + case R2: + algorithm = new R2(new SingleGraphAlg(runner.getSourceGraph())); + break; + case R3: + algorithm = new R3(new SingleGraphAlg(runner.getSourceGraph())); + break; + case R4: + algorithm = new R4(new SingleGraphAlg(runner.getSourceGraph())); + break; + case RSkew: + algorithm = new RSkew(new SingleGraphAlg(runner.getSourceGraph())); + break; + case RSkewE: + algorithm = new RSkewE(new SingleGraphAlg(runner.getSourceGraph())); + break; + case Skew: + algorithm = new Skew(new SingleGraphAlg(runner.getSourceGraph())); + break; + case SkewE: + algorithm = new SkewE(new SingleGraphAlg(runner.getSourceGraph())); + break; + case Tahn: + algorithm = new Tanh(new SingleGraphAlg(runner.getSourceGraph())); + break; + + default: + throw new IllegalArgumentException("Please configure that algorithm: " + name); - int n = JOptionPane - .showOptionDialog(this, "Would you like to execute a " + name - + " search in the cloud?", "A Silly Question", - JOptionPane.YES_NO_OPTION, - JOptionPane.QUESTION_MESSAGE, null, options, options[0]); - if (n == 0) - return null; - return hpcAccounts.get(n - 1); - } - - private void doRemoteCompute(final GeneralAlgorithmRunner runner, - final HpcAccount hpcAccount) { - - // ********************** - // Show progress panel * - // ********************** - - Frame ancestor = (Frame) JOptionUtils.centeringComp() - .getTopLevelAncestor(); - final JDialog progressDialog = new JDialog(ancestor, - "HPC Job Submission's Progress...", false); - - Dimension progressDim = new Dimension(500, 150); - - JTextArea progressTextArea = new JTextArea(); - progressTextArea.setPreferredSize(progressDim); - progressTextArea.setEditable(false); - - JScrollPane progressScroller = new JScrollPane(progressTextArea); - progressScroller.setAlignmentX(LEFT_ALIGNMENT); - - progressDialog.setLayout(new BorderLayout()); - progressDialog.getContentPane().add(progressScroller, - BorderLayout.CENTER); - progressDialog.pack(); - Dimension screenDim = Toolkit.getDefaultToolkit().getScreenSize(); - progressDialog.setLocation((screenDim.width - progressDim.width) / 2, - (screenDim.height - progressDim.height) / 2); - progressDialog.setVisible(true); - - int totalProcesses = 4; - String newline = "\n"; - String tab = "\t"; - int progressTextLength = 0; - - DataModel dataModel = runner.getDataModel(); - - // 1. Generate temp file - Path file = null; - Path prior = null; - try { - // **************************** - // Data Preparation Progress * - // **************************** - String dataMessage = String.format("1/%1$d Data Preparation", - totalProcesses); - progressTextArea.append(dataMessage); - progressTextArea.append(tab); - - progressTextLength = progressTextArea.getText().length(); - - progressTextArea.append("Preparing..."); - progressTextArea.updateUI(); - - file = Files.createTempFile("Tetrad-data-", ".txt"); - // System.out.println(file.toAbsolutePath().toString()); - List tempLine = new ArrayList<>(); - - // Header - List variables = dataModel.getVariables(); - if ((variables == null || variables.isEmpty()) - && runner.getSourceGraph() != null) { - variables = runner.getSourceGraph().getNodes(); - } - - String vars = StringUtils.join(variables.toArray(), tab); - tempLine.add(vars); - - // Data - DataSet dataSet = (DataSet) dataModel; - for (int i = 0; i < dataSet.getNumRows(); i++) { - String line = null; - for (int j = 0; j < dataSet.getNumColumns(); j++) { - String cell = null; - if (dataSet.isContinuous()) { - cell = String.valueOf(dataSet.getDouble(i, j)); - } else { - cell = String.valueOf(dataSet.getInt(i, j)); - } - if (line == null) { - line = cell; - } else { - line = line + "\t" + cell; - } - } - tempLine.add(line); - } - - // for (String line : tempLine) { - // System.out.println(line); - // } - - Files.write(file, tempLine); - - // Get file's MD5 hash and use it as its identifier - String datasetMd5 = MessageDigestHash.computeMD5Hash(file); - - progressTextArea.replaceRange("Done", progressTextLength, - progressTextArea.getText().length()); - progressTextArea.append(newline); - progressTextArea.updateUI(); - - // *************************************** - // Prior Knowledge Preparation Progress * - // *************************************** - String priorMessage = String.format( - "2/%1$d Prior Knowledge Preparation", totalProcesses); - progressTextArea.append(priorMessage); - progressTextArea.append(tab); - - progressTextLength = progressTextArea.getText().length(); - - progressTextArea.append("Preparing..."); - progressTextArea.updateUI(); - - // 2. Generate temp prior knowledge file - Knowledge2 knowledge = (Knowledge2) dataModel.getKnowledge(); - if (knowledge != null && !knowledge.isEmpty()) { - prior = Files.createTempFile(file.getFileName().toString(), - ".prior"); - knowledge.saveKnowledge(Files.newBufferedWriter(prior)); - - progressTextArea.replaceRange("Done", progressTextLength, - progressTextArea.getText().length()); - progressTextArea.append(newline); - progressTextArea.updateUI(); - } else { - progressTextArea.replaceRange("Skipped", progressTextLength, - progressTextArea.getText().length()); - progressTextArea.append(newline); - progressTextArea.updateUI(); - } - // Get knowledge file's MD5 hash and use it as its identifier - String priorKnowledgeMd5 = null; - if (prior != null) { - priorKnowledgeMd5 = MessageDigestHash.computeMD5Hash(prior); - } - - // ******************************************* - // Algorithm Parameter Preparation Progress * - // ******************************************* - String algorMessage = String.format("3/%1$d Algorithm Preparation", - totalProcesses); - progressTextArea.append(algorMessage); - progressTextArea.append(tab); - - progressTextLength = progressTextArea.getText().length(); - - progressTextArea.append("Preparing..."); - progressTextArea.updateUI(); - - // 3.1 Algorithm name - String algorithmName = AbstractAlgorithmRequest.FGES; - Algorithm algorithm = runner.getAlgorithm(); - System.out.println("Algorithm: " + algorithm.getDescription()); - AlgName name = (AlgName) algNamesDropdown.getSelectedItem(); - switch (name) { - case FGES: - algorithmName = AbstractAlgorithmRequest.FGES; - if (dataModel.isDiscrete()) { - algorithmName = AbstractAlgorithmRequest.FGES_DISCRETE; - } - break; - case GFCI: - algorithmName = AbstractAlgorithmRequest.GFCI; - if (dataModel.isDiscrete()) { - algorithmName = AbstractAlgorithmRequest.GFCI_DISCRETE; - } - break; - default: - return; - } - - // 3.2 Parameters - AlgorithmParamRequest algorithmParamRequest = new AlgorithmParamRequest(); - - // Dataset and Prior paths - String datasetPath = file.toAbsolutePath().toString(); - System.out.println(datasetPath); - algorithmParamRequest.setDatasetPath(datasetPath); - algorithmParamRequest.setDatasetMd5(datasetMd5); - if (prior != null) { - String priorKnowledgePath = prior.toAbsolutePath().toString(); - System.out.println(priorKnowledgePath); - algorithmParamRequest.setPriorKnowledgePath(priorKnowledgePath); - algorithmParamRequest.setPriorKnowledgeMd5(priorKnowledgeMd5); - } - - // VariableType - if (dataModel.isContinuous()) { - algorithmParamRequest.setVariableType("continuous"); - } else { - algorithmParamRequest.setVariableType("discrete"); - } - - // FileDelimiter - String fileDelimiter = "tab"; // Pre-determined - algorithmParamRequest.setFileDelimiter(fileDelimiter); - - // Default Data Validation Parameters - DataValidation dataValidation = new DataValidation(); - dataValidation.setUniqueVarName(true); - if (dataModel.isContinuous()) { - dataValidation.setNonZeroVariance(true); - } else { - dataValidation.setCategoryLimit(true); - } - algorithmParamRequest.setDataValidation(dataValidation); - - List AlgorithmParameters = new ArrayList<>(); - - Parameters parameters = runner.getParameters(); - List parameterNames = runner.getAlgorithm().getParameters(); - for (String parameter : parameterNames) { - String value = parameters.get(parameter).toString(); - System.out.println("parameter: " + parameter + "\tvalue: " - + value); - if (value != null) { - AlgorithmParameter algorParam = new AlgorithmParameter(); - algorParam.setParameter(parameter); - algorParam.setValue(value); - AlgorithmParameters.add(algorParam); } - } - - algorithmParamRequest.setAlgorithmParameters(AlgorithmParameters); - - String maxHeapSize = null; - do { - maxHeapSize = JOptionPane.showInputDialog(progressDialog, - "Enter Your Request Java Max Heap Size (GB):", "5"); - } while (maxHeapSize != null && !StringUtils.isNumeric(maxHeapSize)); - - if (maxHeapSize != null) { - JvmOption jvmOption = new JvmOption(); - jvmOption.setParameter("maxHeapSize"); - jvmOption.setValue(maxHeapSize); - List jvmOptions = new ArrayList<>(); - jvmOptions.add(jvmOption); - algorithmParamRequest.setJvmOptions(jvmOptions); - } - - progressTextArea.replaceRange("Done", progressTextLength, - progressTextArea.getText().length()); - progressTextArea.append(newline); - progressTextArea.updateUI(); - - // ******************************** - // Adding HPC Job Queue Progress * - // ******************************** - String dbMessage = String.format("4/%1$d HPC Job Queue Submission", - totalProcesses); - progressTextArea.append(dbMessage); - progressTextArea.append(tab); - - progressTextLength = progressTextArea.getText().length(); - - progressTextArea.append("Preparing..."); - progressTextArea.updateUI(); - - HpcJobManager hpcJobManager = desktop.getHpcJobManager(); - - // 4.1 Save HpcJobInfo - hpcJobInfo = new HpcJobInfo(); - hpcJobInfo.setAlgorithmName(algorithmName); - hpcJobInfo.setAlgorithmParamRequest(algorithmParamRequest); - hpcJobInfo.setStatus(-1); - hpcJobInfo.setHpcAccount(hpcAccount); - hpcJobManager.submitNewHpcJobToQueue(hpcJobInfo, this); - - progressTextArea.replaceRange("Done", progressTextLength, - progressTextArea.getText().length()); - progressTextArea.append(newline); - progressTextArea.updateUI(); - - this.jsonResult = null; - this.errorResult = null; - - JOptionPane.showMessageDialog(this, - "The " + hpcJobInfo.getAlgorithmName() + " job on the " - + hpcJobInfo.getHpcAccount().getConnectionName() - + " node is in the queue successfully!"); - - } catch (IOException e1) { - e1.printStackTrace(); - } finally { - progressDialog.setVisible(false); - progressDialog.dispose(); + return algorithm; } - (new HpcJobActivityAction("")).actionPerformed(null); + private ScoreWrapper getScoreWrapper() { + ScoreType score = (ScoreType) scoreDropdown.getSelectedItem(); + ScoreWrapper scoreWrapper; + + switch (score) { + case BDeu: + scoreWrapper = new BdeuScore(); + break; + case Conditional_Gaussian_BIC: + scoreWrapper = new ConditionalGaussianBicScore(); + break; + case Discrete_BIC: + scoreWrapper = new DiscreteBicScore(); + break; + case SEM_BIC: + scoreWrapper = new SemBicScore(); + break; + case D_SEPARATION: + scoreWrapper = new DseparationScore(new SingleGraph(runner.getSourceGraph())); + break; + default: + throw new IllegalArgumentException("Please configure that score: " + score); + } + return scoreWrapper; + } - } + private IndependenceWrapper getIndependenceWrapper() { + TestType test = (TestType) testDropdown.getSelectedItem(); + + IndependenceWrapper independenceWrapper; + + switch (test) { + case ChiSquare: + independenceWrapper = new ChiSquare(); + break; + case Conditional_Correlation: + independenceWrapper = new ConditionalCorrelation(); + break; + case Conditional_Gaussian_LRT: + independenceWrapper = new ConditionalGaussianLRT(); + break; + case Fisher_Z: + independenceWrapper = new FisherZ(); + break; + case Correlation_T: + independenceWrapper = new CorrelationT(); + break; + case GSquare: + independenceWrapper = new GSquare(); + break; + case SEM_BIC: + independenceWrapper = new SemBicTest(); + break; + case D_SEPARATION: + independenceWrapper = new DSeparationTest(new SingleGraph(runner.getSourceGraph())); + break; + default: + throw new IllegalArgumentException("Please configure that test: " + test); + } - public void setAlgorithmResult(String jsonResult) { - this.jsonResult = jsonResult; + List tests = new ArrayList<>(); - // JOptionPane.showMessageDialog(null, jsonResult); + for (DataModel dataModel : runner.getDataModelList()) { + IndependenceTest _test = independenceWrapper.getTest(dataModel, parameters); + tests.add(_test); + } - final Graph graph = JsonUtils.parseJSONObjectToTetradGraph(jsonResult); - final List graphs = new ArrayList<>(); - graphs.add(graph); - int size = runner.getGraphs().size(); - for (int index = 0; index < size; index++) { - runner.getGraphs().remove(index); + runner.setIndependenceTests(tests); + return independenceWrapper; } - runner.getGraphs().add(graph); - graphEditor.replace(graphs); - graphEditor.validate(); - System.out.println("Remote graph result assigned to runner!"); - firePropertyChange("modelChanged", null, null); - pane.setSelectedComponent(graphEditor); - } - public void setAlgorithmErrorResult(String errorResult) { - this.errorResult = errorResult; + private void setAlgorithm() { + AlgName name = (AlgName) algNamesDropdown.getSelectedItem(); + AlgorithmDescription description = mappedDescriptions.get(name); - JOptionPane.showMessageDialog(null, jsonResult); + if (name == null) { + return; + } - throw new IllegalArgumentException(errorResult); - } + TestType test = (TestType) testDropdown.getSelectedItem(); + ScoreType score = (ScoreType) scoreDropdown.getSelectedItem(); + + Algorithm algorithm = getAlgorithmFromInterface(); + + OracleType oracle = description.getOracleType(); + + if (oracle == OracleType.None) { + testDropdown.setEnabled(false); + scoreDropdown.setEnabled(false); + } else if (oracle == OracleType.Score) { + testDropdown.setEnabled(false); + scoreDropdown.setEnabled(true); + } else if (oracle == OracleType.Test) { + testDropdown.setEnabled(true); + scoreDropdown.setEnabled(false); + } else if (oracle == OracleType.Both) { + testDropdown.setEnabled(true); + scoreDropdown.setEnabled(true); + } - public Algorithm getAlgorithmFromInterface() { - AlgName name = (AlgName) algNamesDropdown.getSelectedItem(); + parameters.set("testEnabled", testDropdown.isEnabled()); + parameters.set("scoreEnabled", scoreDropdown.isEnabled()); - if (name == null) { - throw new NullPointerException(); - } + runner.setAlgorithm(algorithm); - IndependenceWrapper independenceWrapper = getIndependenceWrapper(); - ScoreWrapper scoreWrapper = getScoreWrapper(); + setAlgName(name); + setTestType(test); + setScoreType(score); + setAlgType(((String) algTypesDropdown.getSelectedItem()).replace(" ", "_")); - Algorithm algorithm = getAlgorithm(name, independenceWrapper, - scoreWrapper); + if (whatYouChose != null) { + whatYouChose.setText("You chose: " + algorithm.getDescription()); + } - if (algorithm instanceof HasKnowledge) { - if (knowledgePanel == null) { - knowledgePanel = getKnowledgePanel(runner); - } + if (pane != null) { + pane.setComponentAt(0, getParametersPane()); + } - pane.remove(graphEditor); - pane.add("Knowledge", knowledgePanel); - pane.add("Output Graphs", graphEditor); - } else { - pane.remove(knowledgePanel); } - return algorithm; - } - - private Algorithm getAlgorithm(AlgName name, - IndependenceWrapper independenceWrapper, ScoreWrapper scoreWrapper) { - Algorithm algorithm; - - switch (name) { - case FGES: - if (runner.getSourceGraph() != null - && !runner.getDataModelList().isEmpty()) { - algorithm = new Fgs(scoreWrapper, new SingleGraphAlg( - runner.getSourceGraph())); - } else { - algorithm = new Fgs(scoreWrapper); - } - break; - // case FgsMeasurement: - // if (runner.getSourceGraph() != null && - // !runner.getDataModelList().isEmpty()) { - // algorithm = new FgsMeasurement(scoreWrapper, new - // SingleGraphAlg(runner.getSourceGraph())); - // } else { - // algorithm = new FgsMeasurement(scoreWrapper); - // } - // break; - case PC: - if (runner.getSourceGraph() != null - && !runner.getDataModelList().isEmpty()) { - algorithm = new Pc(independenceWrapper, new SingleGraphAlg( - runner.getSourceGraph())); - } else { - algorithm = new Pc(independenceWrapper); - } - break; - case CPC: - if (runner.getSourceGraph() != null - && !runner.getDataModelList().isEmpty()) { - algorithm = new Cpc(independenceWrapper, new SingleGraphAlg( - runner.getSourceGraph())); - } else { - algorithm = new Cpc(independenceWrapper); - } - break; - case CPCStable: - algorithm = new CpcStable(independenceWrapper); - break; - case PCStable: - if (runner.getSourceGraph() != null - && !runner.getDataModelList().isEmpty()) { - algorithm = new PcStable(independenceWrapper, - new SingleGraphAlg(runner.getSourceGraph())); - } else { - algorithm = new PcStable(independenceWrapper); - } - break; - case GFCI: - algorithm = new Gfci(independenceWrapper, scoreWrapper); - break; - case FCI: - if (runner.getSourceGraph() != null - && !runner.getDataModelList().isEmpty()) { - algorithm = new Fci(independenceWrapper, new SingleGraphAlg( - runner.getSourceGraph())); - } else { - algorithm = new Fci(independenceWrapper); - } - break; - case RFCI: - algorithm = new Rfci(independenceWrapper); - break; - case CFCI: - algorithm = new Cfci(independenceWrapper); - break; - case TsFCI: - if (runner.getSourceGraph() != null - && !runner.getDataModelList().isEmpty()) { - algorithm = new TsFci(independenceWrapper, new SingleGraphAlg( - runner.getSourceGraph())); - } else { - algorithm = new TsFci(independenceWrapper); - } - break; - case TsGFCI: - algorithm = new TsGfci(independenceWrapper, scoreWrapper); - break; - case TsImages: - algorithm = new TsImagesSemBic(); - break; - case CCD: - algorithm = new Ccd(independenceWrapper); - break; - case GCCD: - algorithm = new GCcd(independenceWrapper, scoreWrapper); - break; - case FAS: - algorithm = new FAS(independenceWrapper); - break; - case FgsMb: - if (runner.getSourceGraph() != null - && !runner.getDataModelList().isEmpty()) { - algorithm = new FgsMb(scoreWrapper, new SingleGraphAlg( - runner.getSourceGraph())); - } else { - algorithm = new FgsMb(scoreWrapper); - } - break; - case MBFS: - algorithm = new MBFS(independenceWrapper); - break; - // case PcLocal: - // algorithm = new PcLocal(independenceWrapper); - // break; - case PcMax: - algorithm = new PcMax(independenceWrapper); - break; - case JCPC: - algorithm = new Jcpc(independenceWrapper, scoreWrapper); - break; - case LiNGAM: - algorithm = new Lingam(); - break; - case MGM: - algorithm = new Mgm(); - break; - case IMaGES_BDeu: - algorithm = new ImagesBDeu(); - break; - case IMaGES_SEM_BIC: - algorithm = new ImagesSemBic(); - break; - case GLASSO: - algorithm = new Glasso(); - break; - case Bpc: - algorithm = new Bpc(); - break; - case Fofc: - algorithm = new Fofc(); - break; - case Ftfc: - algorithm = new Ftfc(); - break; - - // LOFS algorithms. - case EB: - algorithm = new EB(new SingleGraphAlg(runner.getSourceGraph())); - break; - case R1: - algorithm = new R1(new SingleGraphAlg(runner.getSourceGraph())); - break; - case R2: - algorithm = new R2(new SingleGraphAlg(runner.getSourceGraph())); - break; - case R3: - algorithm = new R3(new SingleGraphAlg(runner.getSourceGraph())); - break; - case R4: - algorithm = new R4(new SingleGraphAlg(runner.getSourceGraph())); - break; - case RSkew: - algorithm = new RSkew(new SingleGraphAlg(runner.getSourceGraph())); - break; - case RSkewE: - algorithm = new RSkewE(new SingleGraphAlg(runner.getSourceGraph())); - break; - case Skew: - algorithm = new Skew(new SingleGraphAlg(runner.getSourceGraph())); - break; - case SkewE: - algorithm = new SkewE(new SingleGraphAlg(runner.getSourceGraph())); - break; - case Tahn: - algorithm = new Tanh(new SingleGraphAlg(runner.getSourceGraph())); - break; - - default: - throw new IllegalArgumentException( - "Please configure that algorithm: " + name); + // =============================== Public Methods + // ==================================// + + private JPanel getParametersPane() { + JPanel panel = new JPanel(); + panel.setLayout(new BorderLayout()); + + helpSet.setHomeID("tetrad_overview"); + + ParameterPanel comp = new ParameterPanel(runner.getAlgorithm().getParameters(), getParameters()); + final JScrollPane scroll = new JScrollPane(comp); + scroll.setPreferredSize(new Dimension(800, 300)); + + JButton explain1 = new JButton(new ImageIcon(ImageUtils.getImage(this, "info.png"))); + JButton explain2 = new JButton(new ImageIcon(ImageUtils.getImage(this, "info.png"))); + JButton explain3 = new JButton(new ImageIcon(ImageUtils.getImage(this, "info.png"))); + JButton explain4 = new JButton(new ImageIcon(ImageUtils.getImage(this, "info.png"))); + + explain1.setBorder(new EmptyBorder(0, 0, 0, 0)); + explain2.setBorder(new EmptyBorder(0, 0, 0, 0)); + explain3.setBorder(new EmptyBorder(0, 0, 0, 0)); + explain4.setBorder(new EmptyBorder(0, 0, 0, 0)); + + explain1.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + // helpSet.setHomeID("types_of_algorithms"); + helpSet.setHomeID("under_construction"); + HelpBroker broker = helpSet.createHelpBroker(); + ActionListener listener = new CSH.DisplayHelpFromSource(broker); + listener.actionPerformed(e); + } + }); + + explain2.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + // JComboBox box = (JComboBox) algNamesDropdown; + // String name = box.getSelectedItem().toString(); + // helpSet.setHomeID(name.toLowerCase()); + helpSet.setHomeID("under_construction"); + HelpBroker broker = helpSet.createHelpBroker(); + ActionListener listener = new CSH.DisplayHelpFromSource(broker); + listener.actionPerformed(e); + } + }); + + explain3.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + // JComboBox box = (JComboBox) testDropdown; + // String name = box.getSelectedItem().toString(); + // helpSet.setHomeID(name.toLowerCase()); + helpSet.setHomeID("under_construction"); + HelpBroker broker = helpSet.createHelpBroker(); + ActionListener listener = new CSH.DisplayHelpFromSource(broker); + listener.actionPerformed(e); + } + }); + + explain4.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + // JComboBox box = (JComboBox) scoreDropdown; + // String name = box.getSelectedItem().toString(); + // helpSet.setHomeID(name.toLowerCase()); + helpSet.setHomeID("under_construction"); + HelpBroker broker = helpSet.createHelpBroker(); + ActionListener listener = new CSH.DisplayHelpFromSource(broker); + listener.actionPerformed(e); + } + }); + + searchButton1.setPreferredSize(searchButton1Size); + searchButton1.setMaximumSize(searchButton1Size); + + searchButton1.setFont(new Font("Dialog", Font.BOLD, 16)); + + Box d3 = Box.createHorizontalBox(); + JLabel label3 = new JLabel("List Algorithms that "); + label3.setFont(new Font("Dialog", Font.BOLD, 13)); + d3.add(label3); + algTypesDropdown.setMaximumSize(algTypesDropdown.getPreferredSize()); + d3.add(algTypesDropdown); + JLabel label4 = new JLabel(" : "); + label4.setFont(new Font("Dialog", Font.BOLD, 13)); + d3.add(label4); + algNamesDropdown.setMaximumSize(algNamesDropdown.getPreferredSize()); + d3.add(algNamesDropdown); + d3.add(explain2); + d3.add(new JLabel(" ")); + d3.add(searchButton1); + d3.add(Box.createHorizontalGlue()); + + Box d1 = Box.createHorizontalBox(); + JLabel label1 = new JLabel("Test if needed:"); + label1.setFont(new Font("Dialog", Font.BOLD, 13)); + d1.add(label1); + testDropdown.setMaximumSize(testDropdown.getPreferredSize()); + d1.add(testDropdown); + d1.add(explain3); + d1.add(Box.createHorizontalGlue()); + + Box d2 = Box.createHorizontalBox(); + JLabel label2 = new JLabel("Score if needed:"); + label2.setFont(new Font("Dialog", Font.BOLD, 13)); + d2.add(label2); + scoreDropdown.setMaximumSize(scoreDropdown.getPreferredSize()); + d2.add(scoreDropdown); + d2.add(explain4); + d2.add(Box.createHorizontalGlue()); + + Box d0 = Box.createHorizontalBox(); + JLabel label0 = new JLabel("Parameters:"); + label0.setFont(new Font("Dialog", Font.BOLD, 13)); + d0.add(label0); + d0.add(Box.createHorizontalGlue()); + + Box c = Box.createVerticalBox(); + c.add(d3); + c.add(d1); + c.add(d2); + // c.add(Box.createVerticalGlue()); + c.add(d0); + c.add(Box.createVerticalStrut(10)); + c.add(scroll); + + panel.add(c, BorderLayout.CENTER); + + Algorithm algorithm = getAlgorithmFromInterface(); + runner.setAlgorithm(algorithm); + + return panel; + } + private Parameters getParameters() { + return parameters; } - return algorithm; - } - - private ScoreWrapper getScoreWrapper() { - ScoreType score = (ScoreType) scoreDropdown.getSelectedItem(); - ScoreWrapper scoreWrapper; - - switch (score) { - case BDeu: - scoreWrapper = new BdeuScore(); - break; - case Conditional_Gaussian_BIC: - scoreWrapper = new ConditionalGaussianBicScore(); - break; - case Discrete_BIC: - scoreWrapper = new DiscreteBicScore(); - break; - case SEM_BIC: - scoreWrapper = new SemBicScore(); - break; - case D_SEPARATION: - scoreWrapper = new DseparationScore(new SingleGraph( - runner.getSourceGraph())); - break; - default: - throw new IllegalArgumentException("Please configure that score: " - + score); + + private AlgType getAlgType() { + return AlgType.valueOf(parameters.getString("algType", "ALL").replace(" ", "_")); } - return scoreWrapper; - } - - private IndependenceWrapper getIndependenceWrapper() { - TestType test = (TestType) testDropdown.getSelectedItem(); - - IndependenceWrapper independenceWrapper; - - switch (test) { - case ChiSquare: - independenceWrapper = new ChiSquare(); - break; - case Conditional_Correlation: - independenceWrapper = new ConditionalCorrelation(); - break; - case Conditional_Gaussian_LRT: - independenceWrapper = new ConditionalGaussianLRT(); - break; - case Fisher_Z: - independenceWrapper = new FisherZ(); - break; - case Correlation_T: - independenceWrapper = new CorrelationT(); - break; - case GSquare: - independenceWrapper = new GSquare(); - break; - case SEM_BIC: - independenceWrapper = new SemBicTest(); - break; - case D_SEPARATION: - independenceWrapper = new DSeparationTest(new SingleGraph( - runner.getSourceGraph())); - break; - default: - throw new IllegalArgumentException("Please configure that test: " - + test); + + private void setAlgType(String algType) { + parameters.set("algType", algType.replace(" ", "_")); } - List tests = new ArrayList<>(); + private AlgName getAlgName() { + return AlgName.valueOf(parameters.getString("algName", "PC")); + } - for (DataModel dataModel : runner.getDataModelList()) { - IndependenceTest _test = independenceWrapper.getTest(dataModel, - parameters); - tests.add(_test); + private void setAlgName(AlgName algName) { + parameters.set("algName", algName.toString()); } - runner.setIndependenceTests(tests); - return independenceWrapper; - } + private TestType getTestType() { + return TestType.valueOf(parameters.getString("testType", "ChiSquare")); + } - private void setAlgorithm() { - AlgName name = (AlgName) algNamesDropdown.getSelectedItem(); - AlgorithmDescription description = mappedDescriptions.get(name); + private void setTestType(TestType testType) { + parameters.set("testType", testType.toString()); + } - if (name == null) { - return; + private ScoreType getScoreType() { + return ScoreType.valueOf(parameters.getString("scoreType", "BDeu")); } - TestType test = (TestType) testDropdown.getSelectedItem(); - ScoreType score = (ScoreType) scoreDropdown.getSelectedItem(); - - Algorithm algorithm = getAlgorithmFromInterface(); - - OracleType oracle = description.getOracleType(); - - if (oracle == OracleType.None) { - testDropdown.setEnabled(false); - scoreDropdown.setEnabled(false); - } else if (oracle == OracleType.Score) { - testDropdown.setEnabled(false); - scoreDropdown.setEnabled(true); - } else if (oracle == OracleType.Test) { - testDropdown.setEnabled(true); - scoreDropdown.setEnabled(false); - } else if (oracle == OracleType.Both) { - testDropdown.setEnabled(true); - scoreDropdown.setEnabled(true); + private void setScoreType(ScoreType scoreType) { + parameters.set("scoreType", scoreType.toString()); } - parameters.set("testEnabled", testDropdown.isEnabled()); - parameters.set("scoreEnabled", scoreDropdown.isEnabled()); + @Override + public boolean finalizeEditor() { + List graphs = runner.getGraphs(); - runner.setAlgorithm(algorithm); + // Remote search mode - setAlgName(name); - setTestType(test); - setScoreType(score); - setAlgType(((String) algTypesDropdown.getSelectedItem()).replace(" ", - "_")); + if (hpcJobInfo == null && (graphs == null || graphs.isEmpty())) { + int option = JOptionPane.showConfirmDialog(this, "You have not performed a search. Close anyway?", "Close?", + JOptionPane.YES_NO_OPTION); + return option == JOptionPane.YES_OPTION; + } - if (whatYouChose != null) { - whatYouChose.setText("You chose: " + algorithm.getDescription()); + return true; } - if (pane != null) { - pane.setComponentAt(0, getParametersPane()); - } + private class AlgorithmDescription { + private AlgName algName; + private AlgType algType; + private OracleType oracleType; - } - - // =============================== Public Methods - // ==================================// - - private JPanel getParametersPane() { - JPanel panel = new JPanel(); - panel.setLayout(new BorderLayout()); - - helpSet.setHomeID("tetrad_overview"); - - ParameterPanel comp = new ParameterPanel(runner.getAlgorithm() - .getParameters(), getParameters()); - final JScrollPane scroll = new JScrollPane(comp); - scroll.setPreferredSize(new Dimension(800, 300)); - - JButton explain1 = new JButton(new ImageIcon(ImageUtils.getImage(this, - "info.png"))); - JButton explain2 = new JButton(new ImageIcon(ImageUtils.getImage(this, - "info.png"))); - JButton explain3 = new JButton(new ImageIcon(ImageUtils.getImage(this, - "info.png"))); - JButton explain4 = new JButton(new ImageIcon(ImageUtils.getImage(this, - "info.png"))); - - explain1.setBorder(new EmptyBorder(0, 0, 0, 0)); - explain2.setBorder(new EmptyBorder(0, 0, 0, 0)); - explain3.setBorder(new EmptyBorder(0, 0, 0, 0)); - explain4.setBorder(new EmptyBorder(0, 0, 0, 0)); - - explain1.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - // helpSet.setHomeID("types_of_algorithms"); - helpSet.setHomeID("under_construction"); - HelpBroker broker = helpSet.createHelpBroker(); - ActionListener listener = new CSH.DisplayHelpFromSource(broker); - listener.actionPerformed(e); - } - }); - - explain2.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - // JComboBox box = (JComboBox) algNamesDropdown; - // String name = box.getSelectedItem().toString(); - // helpSet.setHomeID(name.toLowerCase()); - helpSet.setHomeID("under_construction"); - HelpBroker broker = helpSet.createHelpBroker(); - ActionListener listener = new CSH.DisplayHelpFromSource(broker); - listener.actionPerformed(e); - } - }); - - explain3.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - // JComboBox box = (JComboBox) testDropdown; - // String name = box.getSelectedItem().toString(); - // helpSet.setHomeID(name.toLowerCase()); - helpSet.setHomeID("under_construction"); - HelpBroker broker = helpSet.createHelpBroker(); - ActionListener listener = new CSH.DisplayHelpFromSource(broker); - listener.actionPerformed(e); - } - }); - - explain4.addActionListener(new ActionListener() { - @Override - public void actionPerformed(ActionEvent e) { - // JComboBox box = (JComboBox) scoreDropdown; - // String name = box.getSelectedItem().toString(); - // helpSet.setHomeID(name.toLowerCase()); - helpSet.setHomeID("under_construction"); - HelpBroker broker = helpSet.createHelpBroker(); - ActionListener listener = new CSH.DisplayHelpFromSource(broker); - listener.actionPerformed(e); - } - }); - - searchButton1.setPreferredSize(searchButton1Size); - searchButton1.setMaximumSize(searchButton1Size); - - searchButton1.setFont(new Font("Dialog", Font.BOLD, 16)); - - Box d3 = Box.createHorizontalBox(); - JLabel label3 = new JLabel("List Algorithms that "); - label3.setFont(new Font("Dialog", Font.BOLD, 13)); - d3.add(label3); - algTypesDropdown.setMaximumSize(algTypesDropdown.getPreferredSize()); - d3.add(algTypesDropdown); - JLabel label4 = new JLabel(" : "); - label4.setFont(new Font("Dialog", Font.BOLD, 13)); - d3.add(label4); - algNamesDropdown.setMaximumSize(algNamesDropdown.getPreferredSize()); - d3.add(algNamesDropdown); - d3.add(explain2); - d3.add(new JLabel(" ")); - d3.add(searchButton1); - d3.add(Box.createHorizontalGlue()); - - Box d1 = Box.createHorizontalBox(); - JLabel label1 = new JLabel("Test if needed:"); - label1.setFont(new Font("Dialog", Font.BOLD, 13)); - d1.add(label1); - testDropdown.setMaximumSize(testDropdown.getPreferredSize()); - d1.add(testDropdown); - d1.add(explain3); - d1.add(Box.createHorizontalGlue()); - - Box d2 = Box.createHorizontalBox(); - JLabel label2 = new JLabel("Score if needed:"); - label2.setFont(new Font("Dialog", Font.BOLD, 13)); - d2.add(label2); - scoreDropdown.setMaximumSize(scoreDropdown.getPreferredSize()); - d2.add(scoreDropdown); - d2.add(explain4); - d2.add(Box.createHorizontalGlue()); - - Box d0 = Box.createHorizontalBox(); - JLabel label0 = new JLabel("Parameters:"); - label0.setFont(new Font("Dialog", Font.BOLD, 13)); - d0.add(label0); - d0.add(Box.createHorizontalGlue()); - - Box c = Box.createVerticalBox(); - c.add(d3); - c.add(d1); - c.add(d2); - // c.add(Box.createVerticalGlue()); - c.add(d0); - c.add(Box.createVerticalStrut(10)); - c.add(scroll); - - panel.add(c, BorderLayout.CENTER); - - Algorithm algorithm = getAlgorithmFromInterface(); - runner.setAlgorithm(algorithm); - - return panel; - } - - private Parameters getParameters() { - return parameters; - } - - private AlgType getAlgType() { - return AlgType.valueOf(parameters.getString("algType", "ALL").replace( - " ", "_")); - } - - private void setAlgType(String algType) { - parameters.set("algType", algType.replace(" ", "_")); - } - - private AlgName getAlgName() { - return AlgName.valueOf(parameters.getString("algName", "PC")); - } - - private void setAlgName(AlgName algName) { - parameters.set("algName", algName.toString()); - } - - private TestType getTestType() { - return TestType.valueOf(parameters.getString("testType", "ChiSquare")); - } - - private void setTestType(TestType testType) { - parameters.set("testType", testType.toString()); - } - - private ScoreType getScoreType() { - return ScoreType.valueOf(parameters.getString("scoreType", "BDeu")); - } - - private void setScoreType(ScoreType scoreType) { - parameters.set("scoreType", scoreType.toString()); - } - - @Override - public boolean finalizeEditor() { - List graphs = runner.getGraphs(); - - // Remote search mode - - if (hpcJobInfo == null && (graphs == null || graphs.isEmpty())) { - int option = JOptionPane.showConfirmDialog(this, - "You have not performed a search. Close anyway?", "Close?", - JOptionPane.YES_NO_OPTION); - return option == JOptionPane.YES_OPTION; - } + public AlgorithmDescription(AlgName name, AlgType algType, OracleType oracleType) { + this.algName = name; + this.algType = algType; + this.oracleType = oracleType; + } + + public AlgName getAlgName() { + return algName; + } + + public AlgType getAlgType() { + return algType; + } - return true; - } + public OracleType getOracleType() { + return oracleType; + } + } - private class AlgorithmDescription { - private AlgName algName; - private AlgType algType; - private OracleType oracleType; + private enum AlgName { + PC, PCStable, CPC, CPCStable, FGES, /*PcLocal,*/ PcMax, FAS, + FgesMb, MBFS, Wfges, JCPC, /*FgesMeasurement,*/ + FCI, RFCI, CFCI, GFCI, TsFCI, TsGFCI, TsImages, CCD, CCD_MAX, + LiNGAM, MGM, + IMaGES_BDeu, IMaGES_SEM_BIC, IMaGES_CCD, + Bpc, Fofc, Ftfc, + GLASSO, + EB, R1, R2, R3, R4, RSkew, RSkewE, Skew, SkewE, Tahn + } - public AlgorithmDescription(AlgName name, AlgType algType, - OracleType oracleType) { - this.algName = name; - this.algType = algType; - this.oracleType = oracleType; + private enum OracleType { + None, Test, Score, Both } - public AlgName getAlgName() { - return algName; + private enum AlgType { + ALL, forbid_latent_common_causes, allow_latent_common_causes, /* DAG, */ + search_for_Markov_blankets, produce_undirected_graphs, orient_pairwise, + search_for_structure_over_latents } - public AlgType getAlgType() { - return algType; + private enum TestType { + ChiSquare, Conditional_Correlation, Conditional_Gaussian_LRT, Fisher_Z, GSquare, + SEM_BIC, D_SEPARATION, Discrete_BIC_Test, Correlation_T } - public OracleType getOracleType() { - return oracleType; + public enum ScoreType { + BDeu, Conditional_Gaussian_BIC, Discrete_BIC, SEM_BIC, D_SEPARATION } - } - - private enum AlgName { - PC, PCStable, CPC, CPCStable, FGES, /* PcLocal, */PcMax, PcMaxLocal, FAS, FgsMb, MBFS, Wfgs, JCPC, /* - * FgsMeasurement - * , - */ - FCI, RFCI, CFCI, GFCI, TsFCI, TsGFCI, TsImages, CCD, GCCD, LiNGAM, MGM, IMaGES_BDeu, IMaGES_SEM_BIC, Bpc, Fofc, Ftfc, GLASSO, EB, R1, R2, R3, R4, RSkew, RSkewE, Skew, SkewE, Tahn - } - - private enum OracleType { - None, Test, Score, Both - } - - private enum AlgType { - ALL, forbid_latent_common_causes, allow_latent_common_causes, /* DAG, */ - search_for_Markov_blankets, produce_undirected_graphs, orient_pairwise, search_for_structure_over_latents - } - - private enum TestType { - ChiSquare, Conditional_Correlation, Conditional_Gaussian_LRT, Fisher_Z, GSquare, SEM_BIC, D_SEPARATION, Discrete_BIC_Test, Correlation_T - } - - public enum ScoreType { - BDeu, Conditional_Gaussian_BIC, Discrete_BIC, SEM_BIC, D_SEPARATION - } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphEditor.java index 2ac1e95ac3..964b1ddad7 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphEditor.java @@ -89,8 +89,6 @@ public void propertyChange(PropertyChangeEvent evt) { int numModels = graphEditable.getNumModels(); - System.out.println("numModels = " + numModels); - if (numModels > 1) { final JComboBox comp = new JComboBox<>(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphSelectionEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphSelectionEditor.java index 7412d8c3f5..f3c2340c36 100755 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphSelectionEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphSelectionEditor.java @@ -479,7 +479,7 @@ public Graph getGraph() { public void setGraph(Graph graph) { wrapper.setGraphs(Collections.singletonList(graph)); editorPanel.reset(); - getWorkbench().setGraph(new EdgeListGraphSingleConnections()); + getWorkbench().setGraph(graph); } public void replace(List graphs) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LofsSearchEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LofsSearchEditor.java index 479137fd08..2118bf23dd 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LofsSearchEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LofsSearchEditor.java @@ -533,9 +533,9 @@ private JComponent getIndTestParamBox(Parameters params) { } if (params instanceof Parameters) { - if (getAlgorithmRunner() instanceof IFgsRunner) { - IFgsRunner fgsRunner = ((IFgsRunner) getAlgorithmRunner()); - return new FgsIndTestParamsEditor(params, fgsRunner.getType()); + if (getAlgorithmRunner() instanceof IFgesRunner) { + IFgesRunner fgesRunner = ((IFgesRunner) getAlgorithmRunner()); + return new FgesIndTestParamsEditor(params, fgesRunner.getType()); } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LofsSearchEditorNew.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LofsSearchEditorNew.java index deab779813..c6d54b6424 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LofsSearchEditorNew.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LofsSearchEditorNew.java @@ -657,9 +657,9 @@ private JComponent getIndTestParamBox(Parameters params) { } if (params instanceof Parameters) { - if (getAlgorithmRunner() instanceof IFgsRunner) { - IFgsRunner gesRunner = ((IFgsRunner) getAlgorithmRunner()); - return new FgsIndTestParamsEditor(params, gesRunner.getType()); + if (getAlgorithmRunner() instanceof IFgesRunner) { + IFgesRunner gesRunner = ((IFgesRunner) getAlgorithmRunner()); + return new FgesIndTestParamsEditor(params, gesRunner.getType()); } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LogisticRegressionEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LogisticRegressionEditor.java index 31e9d7810b..28997ab639 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LogisticRegressionEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LogisticRegressionEditor.java @@ -108,8 +108,6 @@ public void actionPerformed(ActionEvent e) { int numModels = regressionRunner.getNumModels(); - System.out.println("numModels = " + numModels); - if (numModels > 1) { final JComboBox comp = new JComboBox<>(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MbSearchEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MbSearchEditor.java index 0b36ad4cc1..43c61362c7 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MbSearchEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MbSearchEditor.java @@ -61,7 +61,7 @@ public class MbSearchEditor extends AbstractSearchEditor public MbSearchEditor(MbfsRunner runner) { super(runner, "Result MB forbid_latent_common_causes"); } - public MbSearchEditor(FgsMbRunner runner) { + public MbSearchEditor(FgesMbRunner runner) { super(runner, "Result MB forbid_latent_common_causes"); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PValueImproverEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PValueImproverEditor.java index d1744294ad..91f7216731 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PValueImproverEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PValueImproverEditor.java @@ -195,7 +195,7 @@ public void actionPerformed(ActionEvent actionEvent) { gesRadioButton.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { - wrapper.setAlgorithmType(PValueImproverWrapper.AlgorithmType.FGS); + wrapper.setAlgorithmType(PValueImproverWrapper.AlgorithmType.FGES); alphaField.setEnabled(false); beamWidthField.setEnabled(false); } @@ -206,7 +206,7 @@ public void actionPerformed(ActionEvent e) { alphaField.setEnabled(true); beamWidthField.setEnabled(true); } - else if (getWrapper().getAlgorithmType() == PValueImproverWrapper.AlgorithmType.FGS) { + else if (getWrapper().getAlgorithmType() == PValueImproverWrapper.AlgorithmType.FGES) { gesRadioButton.setSelected(true); alphaField.setEnabled(false); beamWidthField.setEnabled(false); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ParameterPanel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ParameterPanel.java index 44302881b9..6a26c90113 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ParameterPanel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ParameterPanel.java @@ -33,6 +33,7 @@ import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.text.DecimalFormat; +import java.util.ArrayList; import java.util.List; /** @@ -45,6 +46,16 @@ public ParameterPanel(List parametersToEdit, Parameters parameters) { Box a = Box.createHorizontalBox(); Box b = Box.createVerticalBox(); + List removeDuplicates = new ArrayList<>(); + + for (String param : parametersToEdit) { + if (!removeDuplicates.contains(param)) { + removeDuplicates.add(param); + } + } + + parametersToEdit = removeDuplicates; + a.add(b); // Box d = Box.createHorizontalBox(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PcGesSearchEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PcGesSearchEditor.java index e82a713632..e8a34494de 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PcGesSearchEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PcGesSearchEditor.java @@ -775,8 +775,8 @@ private JComponent getIndTestParamBox(Parameters params) { } if (params instanceof Parameters) { - if (getAlgorithmRunner() instanceof IFgsRunner) { - return new FgsIndTestParamsEditor(params, ((IFgsRunner) getAlgorithmRunner()).getType()); + if (getAlgorithmRunner() instanceof IFgesRunner) { + return new FgesIndTestParamsEditor(params, ((IFgesRunner) getAlgorithmRunner()).getType()); } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PcLocalSearchEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PcLocalSearchEditor.java index 119023e7e1..05ee92f1e0 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PcLocalSearchEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PcLocalSearchEditor.java @@ -743,9 +743,9 @@ private JComponent getIndTestParamBox(Parameters params) { } if (params instanceof Parameters) { - if (getAlgorithmRunner() instanceof IFgsRunner) { - IFgsRunner fgsRunner = ((IFgsRunner) getAlgorithmRunner()); - return new FgsIndTestParamsEditor(params, fgsRunner.getType()); + if (getAlgorithmRunner() instanceof IFgesRunner) { + IFgesRunner fgesRunner = ((IFgesRunner) getAlgorithmRunner()); + return new FgesIndTestParamsEditor(params, fgesRunner.getType()); } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RegressionEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RegressionEditor.java index 9398a39e44..4cc473ed9d 100755 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RegressionEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RegressionEditor.java @@ -149,8 +149,6 @@ public void propertyChange(PropertyChangeEvent evt) { int numModels = runner.getNumModels(); - System.out.println("numModels = " + numModels); - if (numModels > 1) { final JComboBox comp = new JComboBox<>(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SemGraphEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SemGraphEditor.java index b7d342ba18..fa0251e940 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SemGraphEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SemGraphEditor.java @@ -101,8 +101,6 @@ public void propertyChange(PropertyChangeEvent evt) { int numModels = getSemGraphWrapper().getNumModels(); - System.out.println("numModels = " + numModels); - if (numModels > 1) { final JComboBox comp = new JComboBox<>(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SimulationEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SimulationEditor.java index 3d36436eda..61f9b597e9 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SimulationEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SimulationEditor.java @@ -23,6 +23,7 @@ import edu.cmu.tetrad.algcomparison.Comparison; import edu.cmu.tetrad.algcomparison.graph.*; +import edu.cmu.tetrad.algcomparison.independence.FisherZ; import edu.cmu.tetrad.algcomparison.simulation.*; import edu.cmu.tetrad.algcomparison.utils.TakesData; import edu.cmu.tetrad.data.DataModelList; @@ -173,8 +174,6 @@ public void watch() { List graphs = new ArrayList<>(); for (int i = 0; i < _simulation.getNumDataSets(); i++) { graphs.add(_simulation.getTrueGraph(i)); - - } graphEditor.replace(graphs); @@ -467,6 +466,10 @@ private String[] getSimulationItems(Simulation simulation) { simulationItems = new String[]{ "Structural Equation Model" }; + } else if (simulation.getSimulation() instanceof LinearFisherModel) { + simulationItems = new String[]{ + "Linear Fisher Model" + }; } else if (simulation.getSimulation() instanceof StandardizedSemSimulation) { simulationItems = new String[]{ "Standardized Structural Equation Model" @@ -484,11 +487,12 @@ private String[] getSimulationItems(Simulation simulation) { + simulation.getSimulation().getClass()); } } else { - if (simulation.getSimulation() instanceof TakesData) { - simulationItems = new String[]{ - "Linear Fisher Model", - }; - } else if (simulation.getSourceGraph() != null) { +// if (simulation.getSimulation() instanceof TakesData) { +// simulationItems = new String[]{ +// "Linear Fisher Model", +// }; +// } else + if (simulation.getSourceGraph() != null) { simulationItems = new String[]{ "Bayes net", "Structural Equation Model", diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/TabularComparisonEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/TabularComparisonEditor.java index 5e8bc8a19c..bddea32e4f 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/TabularComparisonEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/TabularComparisonEditor.java @@ -90,8 +90,8 @@ private Box getTableDisplay() { private TextTable getTextTable(DataSet dataSet, NumberFormat nf) { TextTable table = new TextTable(dataSet.getNumRows() + 2, dataSet.getNumColumns() + 1); - - table.setToken(0, 0, "Run #"); + table.setTabDelimited(true); + table.setToken(0, 0, "Run"); for (int j = 0; j < dataSet.getNumColumns(); j++) { table.setToken(0, j + 1, dataSet.getVariable(j).getName()); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeDisplayEdge.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeDisplayEdge.java index a708086112..75a72a1878 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeDisplayEdge.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeDisplayEdge.java @@ -207,6 +207,7 @@ public class KnowledgeDisplayEdge extends JComponent implements IDisplayEdge { */ private final PropertyChangeHandler propertyChangeHandler = new PropertyChangeHandler(); + private boolean dashed = false; //==========================CONSTRUCTORS============================// @@ -981,7 +982,17 @@ public Color getLineColor() { * @throws UnsupportedOperationException */ public void setLineColor(Color lineColor) { - throw new UnsupportedOperationException(); +// throw new UnsupportedOperationException(); + } + + @Override + public boolean getDashed() { + return dashed; + } + + @Override + public void setDashed(boolean dashed) { + this.dashed = dashed; } /** diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeGraph.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeGraph.java index cd997e3d82..c5c26f6a59 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeGraph.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeGraph.java @@ -53,6 +53,8 @@ public class KnowledgeGraph implements Graph, TetradSerializableExcluded { * @serial */ private IKnowledge knowledge; + private boolean pag; + private boolean pattern; //============================CONSTRUCTORS=============================// @@ -310,12 +312,12 @@ else if (_edge.getType() == KnowledgeModelEdge.FORBIDDEN_BY_TIERS) { } else if(_edge.getType() == KnowledgeModelEdge.FORBIDDEN_BY_GROUPS){ if(!this.knowledge.isForbiddenByGroups(from, to)){ throw new IllegalArgumentException("Edge " + from + "-->" + to + - " is not forbidden by groups."); + " is not forbidden by groups."); } } else if(_edge.getType() == KnowledgeModelEdge.REQUIRED_BY_GROUPS){ if(!this.knowledge.isRequiredByGroups(from, to)){ throw new IllegalArgumentException("Edge " + from + "-->" + to + - " is not required by groups."); + " is not required by groups."); } } @@ -392,7 +394,7 @@ else if (_edge.getType() == KnowledgeModelEdge.FORBIDDEN_BY_TIERS) { "remove edges forbidden by groups."); } else if(_edge.getType() == KnowledgeModelEdge.REQUIRED_BY_GROUPS){ throw new IllegalArgumentException("Please use the Other Groups interface to " + - "remove edges required by groups."); + "remove edges required by groups."); } return getGraph().removeEdge(edge); @@ -511,7 +513,7 @@ public boolean defNonDescendent(Node node1, Node node2) { } public boolean isDConnectedTo(Node node1, Node node2, - List conditioningNodes) { + List conditioningNodes) { return getGraph().isDConnectedTo(node1, node2, conditioningNodes); } @@ -564,6 +566,26 @@ public List getTriplesClassificationTypes() { public List> getTriplesLists(Node node) { return null; } + + @Override + public boolean isPag() { + return pag; + } + + @Override + public void setPag(boolean pag) { + this.pag = pag; + } + + @Override + public boolean isPattern() { + return pattern; + } + + @Override + public void setPattern(boolean pattern) { + this.pattern = pattern; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/AbstractAlgorithmRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/AbstractAlgorithmRunner.java index 6460a277b6..e2f8649164 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/AbstractAlgorithmRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/AbstractAlgorithmRunner.java @@ -43,7 +43,7 @@ * @author Joseph Ramsey */ public abstract class AbstractAlgorithmRunner - implements AlgorithmRunner, ParamsResettable, Unmarshallable { + implements AlgorithmRunner, ParamsResettable, Unmarshallable, MultipleGraphSource { static final long serialVersionUID = 23L; private DataWrapper dataWrapper; diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java index 151d8c586d..2d8f9c5773 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java @@ -207,7 +207,7 @@ else if (params.getString("bayesPmInitializationMode", "automatic").equals("auto } public BayesPmWrapper(GraphWrapper graphWrapper, - BayesPmWrapper oldBayesPmWrapper, Parameters params) { + BayesPmWrapper oldBayesPmWrapper, Parameters params) { try { if (graphWrapper == null) { throw new NullPointerException("Graph must not be null."); @@ -294,7 +294,7 @@ public BayesPmWrapper(Graph graph, DataWrapper dataWrapper) { } public BayesPmWrapper(GraphWrapper graphWrapper, - Simulation simulation) { + Simulation simulation) { this(graphWrapper, (DataWrapper) simulation); } @@ -343,10 +343,10 @@ public BayesPmWrapper(DagWrapper dagWrapper, Parameters params) { int lowerBound, upperBound; - if (params.getString("initializationMode", "manualRetain").equals("manual")) { + if (params.getString("bayesPmInitializationMode", "manualRetain").equals("manual")) { lowerBound = upperBound = 2; } - else if (params.getString("initializationMode", "manualRetain").equals("automatic")) { + else if (params.getString("bayesPmInitializationMode", "manualRetain").equals("automatic")) { lowerBound = params.getInt("lowerBoundNumVals", 2); upperBound = params.getInt("upperBoundNumVals", 2); } @@ -359,7 +359,7 @@ else if (params.getString("initializationMode", "manualRetain").equals("automati } public BayesPmWrapper(DagWrapper dagWrapper, - BayesPmWrapper oldBayesPmWrapper, Parameters params) { + BayesPmWrapper oldBayesPmWrapper, Parameters params) { try { if (dagWrapper == null) { throw new NullPointerException("Graph must not be null."); @@ -369,17 +369,18 @@ public BayesPmWrapper(DagWrapper dagWrapper, throw new NullPointerException("BayesPm must not be null"); } - Dag graph = new Dag(dagWrapper.getDag()); + Graph graph = dagWrapper.getDag(); int lowerBound, upperBound; - if (params.getString("initializationMode", "manualRetain").equals("manual")) { + String string = params.getString("bayesPmInitializationMode", "manual"); + + if (string.equals("manual")) { lowerBound = upperBound = 2; setBayesPm(new BayesPm(graph, oldBayesPmWrapper.getBayesPm(), lowerBound, upperBound)); } - else - if (params.getString("initializationMode", "manualRetain").equals("automatic")) { + else if (string.equals("automatic")) { lowerBound = params.getInt("lowerBoundNumVals", 2); upperBound = params.getInt("upperBoundNumVals", 2); setBayesPm(graph, lowerBound, upperBound); @@ -489,21 +490,21 @@ private void log(BayesPm pm){ } - public Graph getSourceGraph() { - return getGraph(); - } + public Graph getSourceGraph() { + return getGraph(); + } public Graph getResultGraph() { return getGraph(); } public List getVariableNames() { - return getGraph().getNodeNames(); - } + return getGraph().getNodeNames(); + } - public List getVariables() { - return getGraph().getNodes(); - } + public List getVariables() { + return getGraph().getNodes(); + } public int getNumModels() { return numModels; diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CcdRunner2.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CcdRunner2.java index e40d9c141b..f751a00408 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CcdRunner2.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CcdRunner2.java @@ -41,7 +41,7 @@ public class CcdRunner2 extends AbstractAlgorithmRunner static final long serialVersionUID = 23L; private transient List listeners; - private transient GCcd ccd; + private transient CcdMax ccd; //=========================CONSTRUCTORS================================// @@ -171,8 +171,7 @@ public void execute() { if (model instanceof Graph) { IndependenceTest test = new IndTestDSep((Graph) model); Score score = new GraphScore((Graph) model); - ccd = new GCcd(test, score); - ccd.setVerbose(true); + ccd = new CcdMax(test); } else { if (model instanceof DataSet) { @@ -187,7 +186,7 @@ public void execute() { gesScore.setPenaltyDiscount(penaltyDiscount); System.out.println("Score done"); - ccd = new GCcd(test, gesScore); + ccd = new CcdMax(test); } // else if (dataSet.isDiscrete()) { // double samplePrior = ((Parameters) getParameters()).getSamplePrior(); @@ -205,7 +204,7 @@ public void execute() { gesScore.setPenaltyDiscount(penaltyDiscount); gesScore.setPenaltyDiscount(penaltyDiscount); IndependenceTest test = new IndTestScore(gesScore); - ccd = new GCcd(test, gesScore); + ccd = new CcdMax(test); } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; @@ -217,7 +216,7 @@ public void execute() { } if (list.size() != 1) { - throw new IllegalArgumentException("FGS takes exactly one data set, covariance matrix, or initialGraph " + + throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or initialGraph " + "as input. For multiple data sets as input, use IMaGES."); } @@ -227,20 +226,20 @@ public void execute() { if (allContinuous(list)) { double penalty = 4;//params.getPenaltyDiscount(); - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); - IndependenceTest test = new IndTestScore(fgsScore); - ccd = new GCcd(test, fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); + IndependenceTest test = new IndTestScore(fgesScore); + ccd = new CcdMax(test); } // else if (allDiscrete(list)) { // double structurePrior = ((Parameters) getParameters()).getStructurePrior(); // double samplePrior = ((Parameters) getParameters()).getSamplePrior(); // -// BdeuScoreImages fgsScore = new BdeuScoreImages(list); -// fgsScore.setSamplePrior(samplePrior); -// fgsScore.setStructurePrior(structurePrior); +// BdeuScoreImages fgesScore = new BdeuScoreImages(list); +// fgesScore.setSamplePrior(samplePrior); +// fgesScore.setStructurePrior(structurePrior); // -// gfci = new GFci(fgsScore); +// gfci = new GFci(fgesScore); // } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); @@ -253,7 +252,6 @@ public void execute() { // gfci.setInitialGraph(initialGraph); // gfci.setKnowledge(getParameters().getKnowledge()); // gfci.setNumPatternsToStore(params.getNumPatternsToSave()); - ccd.setVerbose(true); // gfci.setHeuristicSpeedup(true); // gfci.setMaxIndegree(3); // ccd.setHeuristicSpeedup(params.isFaithfulnessAssumed()); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/EdgewiseComparisonModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/EdgewiseComparisonModel.java index 5e240f0c3a..7df9fc8134 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/EdgewiseComparisonModel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/EdgewiseComparisonModel.java @@ -25,6 +25,7 @@ import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Dag; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.search.SearchGraphUtils; import edu.cmu.tetrad.session.DoNotAddOldModel; import edu.cmu.tetrad.session.SessionModel; @@ -33,6 +34,7 @@ import java.io.IOException; import java.io.ObjectInputStream; +import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -56,16 +58,16 @@ public final class EdgewiseComparisonModel implements SessionModel, DoNotAddOldM //=============================CONSTRUCTORS==========================// - public EdgewiseComparisonModel(GeneralAlgorithmRunner model, Parameters params) { - this(model, model.getDataWrapper(), params); - } +// public EdgewiseComparisonModel(GeneralAlgorithmRunner model, Parameters params) { +// this(model, model.getDataWrapper(), params); +// } /** * Compares the results of a PC to a reference workbench by counting errors * of omission and commission. The counts can be retrieved using the methods * countOmissionErrors and countCommissionErrors. */ - public EdgewiseComparisonModel(SessionModel model1, SessionModel model2, + public EdgewiseComparisonModel(MultipleGraphSource model1, MultipleGraphSource model2, Parameters params) { if (params == null) { throw new NullPointerException("Parameters must not be null"); @@ -73,18 +75,18 @@ public EdgewiseComparisonModel(SessionModel model1, SessionModel model2, // Need to be able to construct this object even if the models are // null. Otherwise the interface is annoying. - if (model2 == null) { - model2 = new DagWrapper(new Dag()); - } +// if (model2 == null) { +// model2 = new DagWrapper(new Dag()); +// } +// +// if (model1 == null) { +// model1 = new DagWrapper(new Dag()); +// } - if (model1 == null) { - model1 = new DagWrapper(new Dag()); - } - - if (!(model1 instanceof MultipleGraphSource) || - !(model2 instanceof MultipleGraphSource)) { - throw new IllegalArgumentException("Must be graph sources."); - } +// if (!(model1 instanceof MultipleGraphSource) || +// !(model2 instanceof MultipleGraphSource)) { +// throw new IllegalArgumentException("Must be graph sources."); +// } if (model1 instanceof GeneralAlgorithmRunner && model2 instanceof GeneralAlgorithmRunner) { throw new IllegalArgumentException("Both parents can't be general algorithm runners."); @@ -102,49 +104,82 @@ public EdgewiseComparisonModel(SessionModel model1, SessionModel model2, String referenceName = this.params.getString("referenceGraphName", null); - if (referenceName == null) { - throw new IllegalArgumentException("Must specify a reference graph."); - } else { - Object model11 = model1; - Object model21 = model2; + if (referenceName.equals(model1.getName())) { + if (model1 instanceof Simulation && model2 instanceof GeneralAlgorithmRunner) { + this.referenceGraphs = ((GeneralAlgorithmRunner) model2).getCompareGraphs(((Simulation) model1).getGraphs()); + } else if (model1 instanceof MultipleGraphSource) { + this.referenceGraphs = ((MultipleGraphSource) model1).getGraphs(); + } - if (referenceName.equals(model1.getName())) { - if (model11 instanceof MultipleGraphSource) { - this.referenceGraphs = ((MultipleGraphSource) model11).getGraphs(); - } + if (model2 instanceof MultipleGraphSource) { + this.targetGraphs = ((MultipleGraphSource) model2).getGraphs(); + } - if (model21 instanceof MultipleGraphSource) { - this.targetGraphs = ((MultipleGraphSource) model21).getGraphs(); + if (referenceGraphs.size() == 1 && targetGraphs.size() > 1) { + Graph graph = referenceGraphs.get(0); + referenceGraphs = new ArrayList<>(); + for (Graph _graph : targetGraphs) { + referenceGraphs.add(_graph); } + } - if (referenceGraphs == null) { - this.referenceGraphs = Collections.singletonList(((GraphSource) model11).getGraph()); + if (targetGraphs.size() == 1 && referenceGraphs.size() > 1) { + Graph graph = targetGraphs.get(0); + targetGraphs = new ArrayList<>(); + for (Graph _graph : referenceGraphs) { + targetGraphs.add(graph); } + } - if (targetGraphs == null) { - this.targetGraphs = Collections.singletonList(((GraphSource) model21).getGraph()); - } - } else if (referenceName.equals(model2.getName())) { - if (model21 instanceof MultipleGraphSource) { - this.referenceGraphs = ((MultipleGraphSource) model21).getGraphs(); - } + if (referenceGraphs == null) { + this.referenceGraphs = Collections.singletonList(((GraphSource) model1).getGraph()); + } - if (model11 instanceof MultipleGraphSource) { - this.targetGraphs = ((MultipleGraphSource) model11).getGraphs(); - } + if (targetGraphs == null) { + this.targetGraphs = Collections.singletonList(((GraphSource) model2).getGraph()); + } + } else if (referenceName.equals(model2.getName())) { + if (model2 instanceof Simulation && model1 instanceof GeneralAlgorithmRunner) { + this.referenceGraphs = ((GeneralAlgorithmRunner) model1).getCompareGraphs(((Simulation) model2).getGraphs()); + } else if (model1 instanceof MultipleGraphSource) { + this.referenceGraphs = ((MultipleGraphSource) model2).getGraphs(); + } - if (referenceGraphs == null) { - this.referenceGraphs = Collections.singletonList(((GraphSource) model21).getGraph()); + if (model1 instanceof MultipleGraphSource) { + this.targetGraphs = ((MultipleGraphSource) model1).getGraphs(); + } + + if (referenceGraphs.size() == 1 && targetGraphs.size() > 1) { + Graph graph = referenceGraphs.get(0); + referenceGraphs = new ArrayList<>(); + for (Graph _graph : targetGraphs) { + referenceGraphs.add(_graph); } + } - if (targetGraphs == null) { - this.targetGraphs = Collections.singletonList(((GraphSource) model11).getGraph()); + if (targetGraphs.size() == 1 && referenceGraphs.size() > 1) { + Graph graph = targetGraphs.get(0); + targetGraphs = new ArrayList<>(); + for (Graph _graph : referenceGraphs) { + targetGraphs.add(graph); } - } else { - throw new IllegalArgumentException( - "Neither of the supplied session models is named '" + - referenceName + "'."); } + + if (referenceGraphs == null) { + this.referenceGraphs = Collections.singletonList(((GraphSource) model2).getGraph()); + } + + if (targetGraphs == null) { + this.targetGraphs = Collections.singletonList(((GraphSource) model1).getGraph()); + } + } else { + throw new IllegalArgumentException( + "Neither of the supplied session models is named '" + + referenceName + "'."); + } + + for (int i = 0; i < targetGraphs.size(); i++) { + targetGraphs.set(i, GraphUtils.replaceNodes(targetGraphs.get(i), referenceGraphs.get(i).getNodes())); } if (algorithm != null) { @@ -165,31 +200,31 @@ public EdgewiseComparisonModel(SessionModel model1, SessionModel model2, } } - public EdgewiseComparisonModel(GraphWrapper referenceGraph, - AbstractAlgorithmRunner algorithmRunner, - Parameters params) { - this(referenceGraph, (SessionModel) algorithmRunner, - params); - } - - public EdgewiseComparisonModel(GraphWrapper referenceWrapper, - GraphWrapper targetWrapper, Parameters params) { - this(referenceWrapper, (SessionModel) targetWrapper, - params); - } - - public EdgewiseComparisonModel(DagWrapper referenceGraph, - AbstractAlgorithmRunner algorithmRunner, - Parameters params) { - this(referenceGraph, (SessionModel) algorithmRunner, - params); - } +// public EdgewiseComparisonModel(GraphWrapper referenceGraph, +// AbstractAlgorithmRunner algorithmRunner, +// Parameters params) { +// this(referenceGraph, (SessionModel) algorithmRunner, +// params); +// } +// +// public EdgewiseComparisonModel(GraphWrapper referenceWrapper, +// GraphWrapper targetWrapper, Parameters params) { +// this(referenceWrapper, (SessionModel) targetWrapper, +// params); +// } +// +// public EdgewiseComparisonModel(DagWrapper referenceGraph, +// AbstractAlgorithmRunner algorithmRunner, +// Parameters params) { +// this(referenceGraph, (SessionModel) algorithmRunner, +// params); +// } - public EdgewiseComparisonModel(DagWrapper referenceWrapper, - GraphWrapper targetWrapper, Parameters params) { - this(referenceWrapper, (SessionModel) targetWrapper, - params); - } +// public EdgewiseComparisonModel(DagWrapper referenceWrapper, +// GraphWrapper targetWrapper, Parameters params) { +// this(referenceWrapper, (SessionModel) targetWrapper, +// params); +// } //==============================PUBLIC METHODS========================// diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgsMbRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesMbRunner.java similarity index 80% rename from tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgsMbRunner.java rename to tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesMbRunner.java index 1b5b6f95e8..bc5f23a46a 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgsMbRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesMbRunner.java @@ -41,11 +41,11 @@ * * @author Frank Wimberly after Joe Ramsey's PcRunner */ -public class FgsMbRunner extends AbstractAlgorithmRunner implements - IndTestProducer, GraphSource, IFgsRunner, Indexable { +public class FgesMbRunner extends AbstractAlgorithmRunner implements + IndTestProducer, GraphSource, IFgesRunner, Indexable { static final long serialVersionUID = 23L; - private transient FgsMb2 fgs; + private transient FgesMb2 fges; private int index; private ArrayList topGraphs; @@ -57,8 +57,8 @@ public class FgsMbRunner extends AbstractAlgorithmRunner implements * contain a DataSet that is either a DataSet or a DataSet or a DataList * containing either a DataSet or a DataSet as its selected model. */ - public FgsMbRunner(DataWrapper dataWrapper, Parameters params, - KnowledgeBoxModel knowledgeBoxModel) { + public FgesMbRunner(DataWrapper dataWrapper, Parameters params, + KnowledgeBoxModel knowledgeBoxModel) { super(dataWrapper, params, knowledgeBoxModel); } @@ -67,58 +67,58 @@ public FgsMbRunner(DataWrapper dataWrapper, Parameters params, * contain a DataSet that is either a DataSet or a DataSet or a DataList * containing either a DataSet or a DataSet as its selected model. */ - public FgsMbRunner(DataWrapper dataWrapper, Parameters params) { + public FgesMbRunner(DataWrapper dataWrapper, Parameters params) { super(dataWrapper, params, null); } /** * Constucts a wrapper for the given EdgeListGraph. */ - public FgsMbRunner(Graph graph, Parameters params) { + public FgesMbRunner(Graph graph, Parameters params) { super(graph, params); } /** * Constucts a wrapper for the given EdgeListGraph. */ - public FgsMbRunner(GraphWrapper dagWrapper, Parameters params) { + public FgesMbRunner(GraphWrapper dagWrapper, Parameters params) { super(dagWrapper.getGraph(), params); } /** * Constucts a wrapper for the given EdgeListGraph. */ - public FgsMbRunner(GraphWrapper dagWrapper, KnowledgeBoxModel knowledgeBoxModel, Parameters params) { + public FgesMbRunner(GraphWrapper dagWrapper, KnowledgeBoxModel knowledgeBoxModel, Parameters params) { super(dagWrapper.getGraph(), params, knowledgeBoxModel); } /** * Constucts a wrapper for the given EdgeListGraph. */ - public FgsMbRunner(DagWrapper dagWrapper, Parameters params) { + public FgesMbRunner(DagWrapper dagWrapper, Parameters params) { super(dagWrapper.getDag(), params); } /** * Constructs a wrapper for the given EdgeListGraph. */ - public FgsMbRunner(DagWrapper dagWrapper, KnowledgeBoxModel knowledgeBoxModel, Parameters params) { + public FgesMbRunner(DagWrapper dagWrapper, KnowledgeBoxModel knowledgeBoxModel, Parameters params) { super(dagWrapper.getDag(), params, knowledgeBoxModel); } - public FgsMbRunner(SemGraphWrapper dagWrapper, Parameters params) { + public FgesMbRunner(SemGraphWrapper dagWrapper, Parameters params) { super(dagWrapper.getGraph(), params); } - public FgsMbRunner(SemGraphWrapper dagWrapper, KnowledgeBoxModel knowledgeBoxModel, Parameters params) { + public FgesMbRunner(SemGraphWrapper dagWrapper, KnowledgeBoxModel knowledgeBoxModel, Parameters params) { super(dagWrapper.getGraph(), params, knowledgeBoxModel); } - public FgsMbRunner(IndependenceFactsModel model, Parameters params) { + public FgesMbRunner(IndependenceFactsModel model, Parameters params) { super(model, params, null); } - public FgsMbRunner(IndependenceFactsModel model, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { + public FgesMbRunner(IndependenceFactsModel model, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { super(model, params, knowledgeBoxModel); } @@ -161,10 +161,10 @@ public void execute() { if (model instanceof Graph) { GraphScore gesScore = new GraphScore((Graph) model); target = gesScore.getVariable(targetName); - fgs = new FgsMb2(gesScore); - fgs.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); - fgs.setNumPatternsToStore(params.getInt("numPatternsToSave", 1)); - fgs.setVerbose(true); + fges = new FgesMb2(gesScore); + fges.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); + fges.setNumPatternsToStore(params.getInt("numPatternsToSave", 1)); + fges.setVerbose(true); } else if (model instanceof DataSet) { DataSet dataSet = (DataSet) model; @@ -172,7 +172,7 @@ public void execute() { SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly((DataSet) model)); target = score.getVariable(targetName); score.setPenaltyDiscount(params.getDouble("penaltyDiscount", 4)); - fgs = new FgsMb2(score); + fges = new FgesMb2(score); } else if (dataSet.isDiscrete()) { double samplePrior = 1;//((Parameters) getParameters()).getSamplePrior(); double structurePrior = 1;//((Parameters) getParameters()).getStructurePrior(); @@ -180,7 +180,7 @@ public void execute() { score.setSamplePrior(samplePrior); score.setStructurePrior(structurePrior); target = score.getVariable(targetName); - fgs = new FgsMb2(score); + fges = new FgesMb2(score); } else { throw new IllegalStateException("Data set must either be continuous or discrete."); } @@ -189,7 +189,7 @@ public void execute() { gesScore.setPenaltyDiscount(params.getDouble("alpha", 0.001)); gesScore.setPenaltyDiscount(params.getDouble("penaltyDiscount", 4)); target = gesScore.getVariable(targetName); - fgs = new FgsMb2(gesScore); + fges = new FgesMb2(gesScore); } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; @@ -202,7 +202,7 @@ else if (model instanceof DataModelList) { } // if (list.size() != 1) { -// throw new IllegalArgumentException("FGS takes exactly one data set, covariance matrix, or initialGraph " + +// throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or initialGraph " + // "as input. For multiple data sets as input, use IMaGES."); // } @@ -210,29 +210,29 @@ else if (model instanceof DataModelList) { double penalty = getParams().getDouble("penaltyDiscount", 4); if (params.getBoolean("firstNontriangular", false)) { - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); - target = fgsScore.getVariable(targetName); - fgs = new FgsMb2(fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); + target = fgesScore.getVariable(targetName); + fges = new FgesMb2(fgesScore); } else { - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); - target = fgsScore.getVariable(targetName); - fgs = new FgsMb2(fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); + target = fgesScore.getVariable(targetName); + fges = new FgesMb2(fgesScore); } } else if (allDiscrete(list)) { double structurePrior = getParams().getDouble("structurePrior", 1); double samplePrior = getParams().getDouble("samplePrior", 1); - BdeuScoreImages fgsScore = new BdeuScoreImages(list); - fgsScore.setSamplePrior(samplePrior); - fgsScore.setStructurePrior(structurePrior); - target = fgsScore.getVariable(targetName); + BdeuScoreImages fgesScore = new BdeuScoreImages(list); + fgesScore.setSamplePrior(samplePrior); + fgesScore.setStructurePrior(structurePrior); + target = fgesScore.getVariable(targetName); if (params.getBoolean("firstNontriangular", false)) { - fgs = new FgsMb2(fgsScore); + fges = new FgesMb2(fgesScore); } else { - fgs = new FgsMb2(fgsScore); + fges = new FgesMb2(fgesScore); } } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); @@ -248,7 +248,7 @@ else if (model instanceof DataModelList) { // DataModel dataModel = getDataModelList().getSelectedModel(); // ICovarianceMatrix cov; // Node target; -// FgsMb fgs; +// FgesMb fges; // // if (dataModel instanceof DataSet) { // DataSet dataSet = (DataSet) dataModel; @@ -257,14 +257,14 @@ else if (model instanceof DataModelList) { // if (dataSet.isContinuous()) { // SemBicScore gesScore = new SemBicScore(new CovarianceMatrixOnTheFly((DataSet) dataModel), // getParameters().getAlpha()); -// fgs = new FgsMb(gesScore, target); +// fges = new FgesMb(gesScore, target); // } else if (dataSet.isDiscrete()) { // double structurePrior = 1; // double samplePrior = getParameters().getAlpha(); // BDeuScore score = new BDeuScore(dataSet); // score.setSamplePrior(samplePrior); // score.setStructurePrior(structurePrior); -// fgs = new FgsMb(score, target); +// fges = new FgesMb(score, target); // } else { // throw new IllegalStateException("Data set must either be continuous or discrete."); // } @@ -273,25 +273,25 @@ else if (model instanceof DataModelList) { // SemBicScore score = new SemBicScore(cov, // getParameters().getAlpha()); // target = cov.getVariable(targetName); -// fgs = new FgsMb(score, target); +// fges = new FgesMb(score, target); // } else { // throw new IllegalArgumentException("Expecting a data set or a covariance matrix."); // } // -// fgs.setVerbose(true); -// fgs.setHeuristicSpeedup(true); -// searchGraph = fgs.search(); +// fges.setVerbose(true); +// fges.setHeuristicSpeedup(true); +// searchGraph = fges.search(); // } else { // Node target = getIndependenceTest().getVariable(targetName); // System.out.println("Target = " + target); // -// int depth = getParameters().getMaxIndegree(); +// int depth = getParameters().getMaxDegree(); // -// ScoredIndTest fgsScore = new ScoredIndTest(getIndependenceTest()); -// fgsScore.setParameter1(getParameters().getAlpha()); -// FgsMb search = new FgsMb(fgsScore, target); +// ScoredIndTest fgesScore = new ScoredIndTest(getIndependenceTest()); +// fgesScore.setParameter1(getParameters().getAlpha()); +// FgesMb search = new FgesMb(fgesScore, target); // search.setKnowledge(knowledge); -// search.setMaxIndegree(depth); +// search.setMaxDegree(depth); // search.setVerbose(true); // search.setHeuristicSpeedup(true); // searchGraph = search.search(); @@ -305,13 +305,13 @@ else if (model instanceof DataModelList) { // GraphUtils.circleLayout(searchGraph, 200, 200, 150); // } -// fgs.setInitialGraph(initialGraph); - fgs.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); - fgs.setNumPatternsToStore(params.getInt("numPatternsToSave", 1)); - fgs.setVerbose(true); -// fgs.setHeuristicSpeedup(((Parameters) params.getIndTestParams()).isFaithfulnessAssumed()); - fgs.setMaxIndegree(params.getInt("depth", -1)); - Graph graph = fgs.search(target); +// fges.setInitialGraph(initialGraph); + fges.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); + fges.setNumPatternsToStore(params.getInt("numPatternsToSave", 1)); + fges.setVerbose(true); +// fges.setHeuristicSpeedup(((Parameters) params.getIndTestParams()).isFaithfulnessAssumed()); + fges.setMaxIndegree(params.getInt("depth", -1)); + Graph graph = fges.search(target); if (getSourceGraph() != null) { GraphUtils.arrangeBySourceGraph(graph, getSourceGraph()); @@ -321,7 +321,7 @@ else if (model instanceof DataModelList) { GraphUtils.circleLayout(graph, 200, 200, 150); } - this.topGraphs = new ArrayList<>(fgs.getTopGraphs()); + this.topGraphs = new ArrayList<>(fges.getTopGraphs()); if (topGraphs.isEmpty()) { topGraphs.add(new ScoredGraph(getResultGraph(), Double.NaN)); @@ -427,7 +427,7 @@ public String getAlgorithmName() { * Executes the algorithm, producing (at least) a result workbench. Must be * implemented in the extending class. */ - public FgsRunner.Type getType() { + public FgesRunner.Type getType() { Object model = getDataModel(); if (model == null && getSourceGraph() != null) { @@ -441,29 +441,29 @@ public FgsRunner.Type getType() { "file when you save the session. It can, however, be recreated from the saved seed."); } - FgsRunner.Type type; + FgesRunner.Type type; if (model instanceof Graph) { - type = FgsRunner.Type.GRAPH; + type = FgesRunner.Type.GRAPH; } else if (model instanceof DataSet) { DataSet dataSet = (DataSet) model; if (dataSet.isContinuous()) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (dataSet.isDiscrete()) { - type = FgsRunner.Type.DISCRETE; + type = FgesRunner.Type.DISCRETE; } else { throw new IllegalStateException("Data set must either be continuous or discrete."); } } else if (model instanceof ICovarianceMatrix) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; if (allContinuous(list)) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (allDiscrete(list)) { - type = FgsRunner.Type.DISCRETE; + type = FgesRunner.Type.DISCRETE; } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgsRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java similarity index 85% rename from tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgsRunner.java rename to tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java index 4fca82c36c..4dce373e1e 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgsRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java @@ -45,7 +45,7 @@ * @author Ricardo Silva */ -public class FgsRunner extends AbstractAlgorithmRunner implements IFgsRunner, GraphSource, +public class FgesRunner extends AbstractAlgorithmRunner implements IFgesRunner, GraphSource, PropertyChangeListener, IGesRunner, Indexable, DoNotAddOldModel, Unmarshallable { static final long serialVersionUID = 23L; @@ -54,36 +54,36 @@ public enum Type {CONTINUOUS, DISCRETE, MIXED, GRAPH} private transient List listeners; private List topGraphs; private int index; - private transient Fgs fgs; + private transient Fges fges; private transient Graph initialGraph; //============================CONSTRUCTORS============================// - public FgsRunner(DataWrapper[] dataWrappers, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { + public FgesRunner(DataWrapper[] dataWrappers, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { super(new MergeDatasetsWrapper(dataWrappers, params), params, knowledgeBoxModel); } - public FgsRunner(DataWrapper[] dataWrappers, Parameters params) { + public FgesRunner(DataWrapper[] dataWrappers, Parameters params) { super(new MergeDatasetsWrapper(dataWrappers, params), params, null); } - public FgsRunner(DataWrapper[] dataWrappers, GraphSource graph, Parameters params) { + public FgesRunner(DataWrapper[] dataWrappers, GraphSource graph, Parameters params) { super(new MergeDatasetsWrapper(dataWrappers, params), params, null); if (graph == this) throw new IllegalArgumentException(); this.initialGraph = graph.getGraph(); } - public FgsRunner(DataWrapper[] dataWrappers, GraphSource graph, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { + public FgesRunner(DataWrapper[] dataWrappers, GraphSource graph, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { super(new MergeDatasetsWrapper(dataWrappers, params), params, knowledgeBoxModel); if (graph == this) throw new IllegalArgumentException(); this.initialGraph = graph.getGraph(); } - public FgsRunner(GraphWrapper graphWrapper, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { + public FgesRunner(GraphWrapper graphWrapper, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { super(graphWrapper.getGraph(), params, knowledgeBoxModel); } - public FgsRunner(GraphWrapper graphWrapper, Parameters params) { + public FgesRunner(GraphWrapper graphWrapper, Parameters params) { super(graphWrapper.getGraph(), params, null); } @@ -122,9 +122,9 @@ public void execute() { if (model instanceof Graph) { GraphScore gesScore = new GraphScore((Graph) model); - fgs = new Fgs(gesScore); - fgs.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); - fgs.setVerbose(true); + fges = new Fges(gesScore); + fges.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); + fges.setVerbose(true); } else { double penaltyDiscount = params.getDouble("penaltyDiscount", 4); @@ -138,24 +138,24 @@ public void execute() { // SvrScore gesScore = new SvrScore((DataSet) model); gesScore.setPenaltyDiscount(penaltyDiscount); System.out.println("Score done"); - fgs = new Fgs(gesScore); + fges = new Fges(gesScore); } else if (dataSet.isDiscrete()) { double samplePrior = getParams().getDouble("samplePrior", 1); double structurePrior = getParams().getDouble("structurePrior", 1); BDeuScore score = new BDeuScore(dataSet); score.setSamplePrior(samplePrior); score.setStructurePrior(structurePrior); - fgs = new Fgs(score); + fges = new Fges(score); } else { MixedBicScore gesScore = new MixedBicScore(dataSet); gesScore.setPenaltyDiscount(penaltyDiscount); - fgs = new Fgs(gesScore); + fges = new Fges(gesScore); } } else if (model instanceof ICovarianceMatrix) { SemBicScore gesScore = new SemBicScore((ICovarianceMatrix) model); gesScore.setPenaltyDiscount(penaltyDiscount); gesScore.setPenaltyDiscount(penaltyDiscount); - fgs = new Fgs(gesScore); + fges = new Fges(gesScore); } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; @@ -167,7 +167,7 @@ public void execute() { } if (list.size() != 1) { - throw new IllegalArgumentException("FGS takes exactly one data set, covariance matrix, or initialGraph " + + throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or initialGraph " + "as input. For multiple data sets as input, use IMaGES."); } @@ -175,26 +175,26 @@ public void execute() { double penalty = getParams().getDouble("penaltyDiscount", 4); if (params.getBoolean("firstNontriangular", false)) { - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); - fgs = new Fgs(fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); + fges = new Fges(fgesScore); } else { - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); - fgs = new Fgs(fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); + fges = new Fges(fgesScore); } } else if (allDiscrete(list)) { double structurePrior = getParams().getDouble("structurePrior", 1); double samplePrior = getParams().getDouble("samplePrior", 1); - BdeuScoreImages fgsScore = new BdeuScoreImages(list); - fgsScore.setSamplePrior(samplePrior); - fgsScore.setStructurePrior(structurePrior); + BdeuScoreImages fgesScore = new BdeuScoreImages(list); + fgesScore.setSamplePrior(samplePrior); + fgesScore.setStructurePrior(structurePrior); if (params.getBoolean("firstNontriangular", false)) { - fgs = new Fgs(fgsScore); + fges = new Fges(fgesScore); } else { - fgs = new Fgs(fgsScore); + fges = new Fges(fgesScore); } } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); @@ -204,12 +204,12 @@ public void execute() { } } - fgs.setInitialGraph(initialGraph); - fgs.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); - fgs.setNumPatternsToStore(params.getInt("numPatternsToSave", 1)); - fgs.setVerbose(true); - fgs.setFaithfulnessAssumed(params.getBoolean("faithfulnessAssumed", true)); - Graph graph = fgs.search(); + fges.setInitialGraph(initialGraph); + fges.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); + fges.setNumPatternsToStore(params.getInt("numPatternsToSave", 1)); + fges.setVerbose(true); + fges.setFaithfulnessAssumed(params.getBoolean("faithfulnessAssumed", true)); + Graph graph = fges.search(); if (getSourceGraph() != null) { GraphUtils.arrangeBySourceGraph(graph, getSourceGraph()); @@ -221,7 +221,7 @@ public void execute() { setResultGraph(graph); - this.topGraphs = new ArrayList<>(fgs.getTopGraphs()); + this.topGraphs = new ArrayList<>(fges.getTopGraphs()); if (topGraphs.isEmpty()) { @@ -363,7 +363,7 @@ public Map getParamSettings() { @Override public String getAlgorithmName() { - return "FGS"; + return "FGES"; } public void propertyChange(PropertyChangeEvent evt) { @@ -392,15 +392,15 @@ public List getTopGraphs() { } public String getBayesFactorsReport(Graph dag) { - if (fgs == null) { + if (fges == null) { return "Please re-run IMaGES."; } else { - return fgs.logEdgeBayesFactorsString(dag); + return fges.logEdgeBayesFactorsString(dag); } } public GraphScorer getGraphScorer() { - return fgs; + return fges; } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GFciRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GFciRunner.java index 17f810ca59..95c91a47d3 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GFciRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GFciRunner.java @@ -248,7 +248,7 @@ public void execute() { } if (list.size() != 1) { - throw new IllegalArgumentException("FGS takes exactly one data set, covariance matrix, or initialGraph " + + throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or initialGraph " + "as input. For multiple data sets as input, use IMaGES."); } @@ -258,20 +258,20 @@ public void execute() { if (allContinuous(list)) { double penalty = params.getDouble("penaltyDiscount", 4); - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - IndTestScore test = new IndTestScore(fgsScore); - fgsScore.setPenaltyDiscount(penalty); - gfci = new GFci(test, fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + IndTestScore test = new IndTestScore(fgesScore); + fgesScore.setPenaltyDiscount(penalty); + gfci = new GFci(test, fgesScore); } // else if (allDiscrete(list)) { // double structurePrior = ((Parameters) getParameters()).getStructurePrior(); // double samplePrior = ((Parameters) getParameters()).getSamplePrior(); // -// BdeuScoreImages fgsScore = new BdeuScoreImages(list); -// fgsScore.setSamplePrior(samplePrior); -// fgsScore.setStructurePrior(structurePrior); +// BdeuScoreImages fgesScore = new BdeuScoreImages(list); +// fgesScore.setSamplePrior(samplePrior); +// fgesScore.setStructurePrior(structurePrior); // -// gfci = new GFci(fgsScore); +// gfci = new GFci(fgesScore); // } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GeneralAlgorithmRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GeneralAlgorithmRunner.java index 7aea3aa627..6b629e2e41 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GeneralAlgorithmRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GeneralAlgorithmRunner.java @@ -24,7 +24,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm; import edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm; -import edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fgs; +import edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges; import edu.cmu.tetrad.algcomparison.score.BdeuScore; import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.data.*; @@ -60,7 +60,7 @@ public class GeneralAlgorithmRunner implements AlgorithmRunner, ParamsResettable private DataWrapper dataWrapper; private String name; - private Algorithm algorithm = new Fgs(new BdeuScore()); + private Algorithm algorithm = new Fges(new BdeuScore()); private Parameters parameters; private Graph sourceGraph; private Graph initialGraph; @@ -249,7 +249,8 @@ public void execute() { } if (dataSets.size() < parameters.getInt("randomSelectionSize")) { - throw new IllegalArgumentException("The random selection size is greater than the number of data sets."); + throw new IllegalArgumentException("Sorry, the 'random selection size' is greater than " + + "the number of data sets."); } Collections.shuffle(dataSets); @@ -264,20 +265,19 @@ public void execute() { } } else if (getAlgorithm() instanceof ClusterAlgorithm) { for (int k = 0; k < parameters.getInt("numRandomSelections"); k++) { - List dataSets = new ArrayList<>(); - for (DataModel dataModel : getDataModelList()) { DataSet dataSet = (DataSet) dataModel; if (!dataSet.isContinuous()) { throw new IllegalArgumentException("Sorry, you need a continuous dataset for a cluster algorithm."); } + + graphList.add(algorithm.search(dataSet, parameters)); } } } else { for (DataModel data : getDataModelList()) { System.out.println("Analyzing data set # " + (++i)); - DataModel dataSet = data; //(DataSet) data; Algorithm algorithm = getAlgorithm(); if (algorithm instanceof HasKnowledge) { @@ -286,12 +286,15 @@ public void execute() { DataType algDataType = algorithm.getDataType(); - if (dataSet.isContinuous() && (algDataType == DataType.Continuous || algDataType == DataType.Mixed)) { - graphList.add(algorithm.search(dataSet, parameters)); - } else if (dataSet.isDiscrete() && (algDataType == DataType.Discrete || algDataType == DataType.Mixed) && dataSet.isDiscrete()) { - graphList.add(algorithm.search(dataSet, parameters)); - } else if (((DataSet) data).isMixed() && algDataType == DataType.Mixed) { - graphList.add(algorithm.search(dataSet, parameters)); + System.out.println("data type = " + algDataType); + System.out.println("Continuous = " + data.isContinuous()); + + if (data.isContinuous() && (algDataType == DataType.Continuous || algDataType == DataType.Mixed)) { + graphList.add(algorithm.search(data, parameters)); + } else if (data.isDiscrete() && (algDataType == DataType.Discrete || algDataType == DataType.Mixed)) { + graphList.add(algorithm.search(data, parameters)); + } else if (data.isMixed() && algDataType == DataType.Mixed) { + graphList.add(algorithm.search(data, parameters)); } else { throw new IllegalArgumentException("The type of data changed; try opening up the search editor and " + "running the algorithm there."); @@ -504,16 +507,6 @@ public List getVariableNames() { return null; } - public List getCompareGraph() { - List compareGraphs = new ArrayList<>(); - - for (Graph graph : getGraphs()) { - compareGraphs.add(algorithm.getComparisonGraph(graph)); - } - - return compareGraphs; - } - public List getCompareGraphs(List graphs) { if (graphs == null) throw new NullPointerException(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphSelectionWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphSelectionWrapper.java index d3e00d2149..6c19fbe3ea 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphSelectionWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphSelectionWrapper.java @@ -69,15 +69,6 @@ public GraphSelectionWrapper(List graphs, Parameters params) { } this.params = params; - - List oldGraphs = getGraphs(); - - if (oldGraphs != null) { - for (int i = 0; i < graphs.size(); i++) { - graphs.set(i, GraphUtils.replaceNodes(graphs.get(i), oldGraphs.get(0).getNodes())); - } - } - init(params, graphs); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphWrapper.java index 1770871cfe..1059fca192 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphWrapper.java @@ -158,7 +158,7 @@ public Graph getGraph() { public void setGraph(Graph graph) { graphs = new ArrayList<>(); - graphs.add(graph); + graphs.add(new EdgeListGraph(graph)); log(); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/IFgsRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/IFgesRunner.java similarity index 86% rename from tetrad-gui/src/main/java/edu/cmu/tetradapp/model/IFgsRunner.java rename to tetrad-gui/src/main/java/edu/cmu/tetradapp/model/IFgesRunner.java index 9144cc9c98..837fad563a 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/IFgsRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/IFgesRunner.java @@ -10,8 +10,8 @@ /** * Created by jdramsey on 2/22/16. */ -public interface IFgsRunner { - FgsRunner.Type getType(); +public interface IFgesRunner { + FgesRunner.Type getType(); List getTopGraphs(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/ImagesRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/ImagesRunner.java index 3d111cc10a..c0b2b7fdec 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/ImagesRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/ImagesRunner.java @@ -41,20 +41,20 @@ * @author Ricardo Silva */ -public class ImagesRunner extends AbstractAlgorithmRunner implements IFgsRunner, GraphSource, +public class ImagesRunner extends AbstractAlgorithmRunner implements IFgesRunner, GraphSource, PropertyChangeListener, IGesRunner, Indexable { static final long serialVersionUID = 23L; - public FgsRunner.Type getType() { + public FgesRunner.Type getType() { return type; } private transient List listeners; private List topGraphs; private int index; - private transient Fgs fgs; + private transient Fges fges; private Graph graph; - private FgsRunner.Type type; + private FgesRunner.Type type; //============================CONSTRUCTORS============================// @@ -123,28 +123,28 @@ public void execute() { if (model instanceof Graph) { GraphScore gesScore = new GraphScore((Graph) model); - fgs = new Fgs(gesScore); + fges = new Fges(gesScore); } else if (model instanceof DataSet) { DataSet dataSet = (DataSet) model; if (dataSet.isContinuous()) { SemBicScore gesScore = new SemBicScore(new CovarianceMatrixOnTheFly((DataSet) model)); gesScore.setPenaltyDiscount(params.getDouble("penaltyDiscount", 4)); - fgs = new Fgs(gesScore); + fges = new Fges(gesScore); } else if (dataSet.isDiscrete()) { double samplePrior = getParams().getDouble("samplePrior", 1); double structurePrior = getParams().getDouble("structurePrior", 1); BDeuScore score = new BDeuScore(dataSet); score.setSamplePrior(samplePrior); score.setStructurePrior(structurePrior); - fgs = new Fgs(score); + fges = new Fges(score); } else { throw new IllegalStateException("Data set must either be continuous or discrete."); } } else if (model instanceof ICovarianceMatrix) { SemBicScore gesScore = new SemBicScore((ICovarianceMatrix) model); gesScore.setPenaltyDiscount(params.getDouble("penaltyDiscount", 4)); - fgs = new Fgs(gesScore); + fges = new Fges(gesScore); } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; @@ -156,7 +156,7 @@ public void execute() { } // if (list.size() != 1) { -// throw new IllegalArgumentException("FGS takes exactly one data set, covariance matrix, or graph " + +// throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or graph " + // "as input. For multiple data sets as input, use IMaGES."); // } @@ -164,26 +164,26 @@ public void execute() { double penalty = getParams().getDouble("penaltyDiscount", 4); if (params.getBoolean("firstNontriangular", false)) { - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); - fgs = new Fgs(fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); + fges = new Fges(fgesScore); } else { - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); - fgs = new Fgs(fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); + fges = new Fges(fgesScore); } } else if (allDiscrete(list)) { double structurePrior = getParams().getDouble("structurePrior", 1); double samplePrior = getParams().getDouble("samplePrior", 1); - BdeuScoreImages fgsScore = new BdeuScoreImages(list); - fgsScore.setSamplePrior(samplePrior); - fgsScore.setStructurePrior(structurePrior); + BdeuScoreImages fgesScore = new BdeuScoreImages(list); + fgesScore.setSamplePrior(samplePrior); + fgesScore.setStructurePrior(structurePrior); if (params.getBoolean("firstNontriangular", false)) { - fgs = new Fgs(fgsScore); + fges = new Fges(fgesScore); } else { - fgs = new Fgs(fgsScore); + fges = new Fges(fgesScore); } } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); @@ -192,11 +192,11 @@ public void execute() { System.out.println("No viable input."); } - fgs.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); - fgs.setNumPatternsToStore(params.getInt("numPatternsToSave", 1)); - fgs.setFaithfulnessAssumed(params.getBoolean("faithfulnessAssumed", true)); - fgs.setVerbose(true); - Graph graph = fgs.search(); + fges.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); + fges.setNumPatternsToStore(params.getInt("numPatternsToSave", 1)); + fges.setFaithfulnessAssumed(params.getBoolean("faithfulnessAssumed", true)); + fges.setVerbose(true); + Graph graph = fges.search(); if (getSourceGraph() != null) { GraphUtils.arrangeBySourceGraph(graph, getSourceGraph()); @@ -208,13 +208,13 @@ public void execute() { setResultGraph(graph); - this.topGraphs = new ArrayList<>(fgs.getTopGraphs()); + this.topGraphs = new ArrayList<>(fges.getTopGraphs()); if (topGraphs.isEmpty()) { topGraphs.add(new ScoredGraph(getResultGraph(), Double.NaN)); } - this.topGraphs = new ArrayList<>(fgs.getTopGraphs()); + this.topGraphs = new ArrayList<>(fges.getTopGraphs()); if (this.topGraphs.isEmpty()) { this.topGraphs.add(new ScoredGraph(getResultGraph(), Double.NaN)); @@ -227,7 +227,7 @@ public void execute() { * Executes the algorithm, producing (at least) a result workbench. Must be * implemented in the extending class. */ - private FgsRunner.Type computeType() { + private FgesRunner.Type computeType() { Object model = getDataModel(); if (model == null && getSourceGraph() != null) { @@ -242,26 +242,26 @@ private FgsRunner.Type computeType() { } if (model instanceof Graph) { - type = FgsRunner.Type.GRAPH; + type = FgesRunner.Type.GRAPH; } else if (model instanceof DataSet) { DataSet dataSet = (DataSet) model; if (dataSet.isContinuous()) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (dataSet.isDiscrete()) { - type = FgsRunner.Type.DISCRETE; + type = FgesRunner.Type.DISCRETE; } else { throw new IllegalStateException("Data set must either be continuous or discrete."); } } else if (model instanceof ICovarianceMatrix) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; if (allContinuous(list)) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (allDiscrete(list)) { - type = FgsRunner.Type.DISCRETE; + type = FgesRunner.Type.DISCRETE; } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); } @@ -366,15 +366,15 @@ public List getTopGraphs() { } public String getBayesFactorsReport(Graph dag) { - if (fgs == null) { + if (fges == null) { return "Please re-run IMaGES."; } else { - return fgs.logEdgeBayesFactorsString(dag); + return fges.logEdgeBayesFactorsString(dag); } } public GraphScorer getGraphScorer() { - return fgs; + return fges; } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MbfsRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MbfsRunner.java index a404ee9f15..1bfc12cedd 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MbfsRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MbfsRunner.java @@ -160,7 +160,7 @@ public void execute() { SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet)); score.setPenaltyDiscount(getParams().getDouble("alpha", 0.001)); - FgsMb search = new FgsMb(score); + FgesMb search = new FgesMb(score); search.setFaithfulnessAssumed(true); Graph searchGraph = search.search(dataSet.getVariable(targetName)); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Misclassifications.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Misclassifications.java index d396126b3a..fb75ca1ff9 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Misclassifications.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Misclassifications.java @@ -87,16 +87,16 @@ public final class Misclassifications implements SessionModel, DoNotAddOldModel //=============================CONSTRUCTORS==========================// - public Misclassifications(GeneralAlgorithmRunner model, Parameters params) { - this(model, model.getDataWrapper(), params); - } +// public Misclassifications(GeneralAlgorithmRunner model, Parameters params) { +// this(model, model.getDataWrapper(), params); +// } /** * Compares the results of a PC to a reference workbench by counting errors * of omission and commission. The counts can be retrieved using the methods * countOmissionErrors and countCommissionErrors. */ - public Misclassifications(SessionModel model1, SessionModel model2, + public Misclassifications(MultipleGraphSource model1, MultipleGraphSource model2, Parameters params) { if (params == null) { throw new NullPointerException("Parameters must not be null"); @@ -185,18 +185,18 @@ public Misclassifications(SessionModel model1, SessionModel model2, // Need to be able to construct this object even if the models are // null. Otherwise the interface is annoying. - if (model2 == null) { - model2 = new DagWrapper(new Dag()); - } - - if (model1 == null) { - model1 = new DagWrapper(new Dag()); - } +// if (model2 == null) { +// model2 = new DagWrapper(new Dag()); +// } +// +// if (model1 == null) { +// model1 = new DagWrapper(new Dag()); +// } - if (!(model1 instanceof MultipleGraphSource) || - !(model2 instanceof MultipleGraphSource)) { - throw new IllegalArgumentException("Must be graph sources."); - } +// if (!(model1 instanceof MultipleGraphSource) || +// !(model2 instanceof MultipleGraphSource)) { +// throw new IllegalArgumentException("Must be graph sources."); +// } this.params = params; @@ -217,6 +217,22 @@ public Misclassifications(SessionModel model1, SessionModel model2, this.targetGraphs = ((MultipleGraphSource) model2).getGraphs(); } + if (referenceGraphs.size() == 1 && targetGraphs.size() > 1) { + Graph graph = referenceGraphs.get(0); + referenceGraphs = new ArrayList<>(); + for (Graph _graph : targetGraphs) { + referenceGraphs.add(_graph); + } + } + + if (targetGraphs.size() == 1 && referenceGraphs.size() > 1) { + Graph graph = targetGraphs.get(0); + targetGraphs = new ArrayList<>(); + for (Graph _graph : referenceGraphs) { + targetGraphs.add(graph); + } + } + if (referenceGraphs == null) { this.referenceGraphs = Collections.singletonList(((GraphSource) model1).getGraph()); } @@ -235,6 +251,22 @@ public Misclassifications(SessionModel model1, SessionModel model2, this.targetGraphs = ((MultipleGraphSource) model1).getGraphs(); } + if (referenceGraphs.size() == 1 && targetGraphs.size() > 1) { + Graph graph = referenceGraphs.get(0); + referenceGraphs = new ArrayList<>(); + for (Graph _graph : targetGraphs) { + referenceGraphs.add(_graph); + } + } + + if (targetGraphs.size() == 1 && referenceGraphs.size() > 1) { + Graph graph = targetGraphs.get(0); + targetGraphs = new ArrayList<>(); + for (Graph _graph : referenceGraphs) { + targetGraphs.add(graph); + } + } + if (referenceGraphs == null) { this.referenceGraphs = Collections.singletonList(((GraphSource) model2).getGraph()); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MultipleGraphSource.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MultipleGraphSource.java index d1bc2b40d2..f9db4c63e3 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MultipleGraphSource.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MultipleGraphSource.java @@ -32,6 +32,8 @@ */ interface MultipleGraphSource { List getGraphs(); + + String getName(); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PValueImproverWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PValueImproverWrapper.java index b4ec71665a..2d12dd9b7d 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PValueImproverWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PValueImproverWrapper.java @@ -45,7 +45,7 @@ public class PValueImproverWrapper extends AbstractAlgorithmRunner implements Gr static final long serialVersionUID = 23L; public enum AlgorithmType { - BEAM, FGS + BEAM, FGES } private AlgorithmType algorithmType = AlgorithmType.BEAM; @@ -223,7 +223,7 @@ public void execute() { if (getAlgorithmType() == AlgorithmType.BEAM) { search = new BffBeam(graph2, dataSet, knowledge); - } else if (getAlgorithmType() == AlgorithmType.FGS) { + } else if (getAlgorithmType() == AlgorithmType.FGES) { search = new BffGes(graph2, dataSet); search.setKnowledge(knowledge); } else { @@ -235,7 +235,7 @@ else if (dataModel instanceof CovarianceMatrix) { if (getAlgorithmType() == AlgorithmType.BEAM) { search = new BffBeam(graph2, covarianceMatrix, knowledge); - } else if (getAlgorithmType() == AlgorithmType.FGS) { + } else if (getAlgorithmType() == AlgorithmType.FGES) { throw new IllegalArgumentException("GES method requires a dataset; a covariance matrix was provided."); // search = new BffGes(graph2, covarianceMatrix); // search.setKnowledge(knowledge); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PatternFromDagGraphWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PatternFromDagGraphWrapper.java index 2cb03ca0db..5cc2ff3447 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PatternFromDagGraphWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PatternFromDagGraphWrapper.java @@ -51,7 +51,7 @@ public PatternFromDagGraphWrapper(Graph graph) { throw new IllegalArgumentException("The source graph is not a DAG."); } - Graph pattern = getPattern(new Dag(graph)); + Graph pattern = getPattern(new EdgeListGraph(graph)); setGraph(pattern); TetradLogger.getInstance().log("info", "\nGenerating pattern from DAG."); @@ -65,8 +65,8 @@ public static PatternFromDagGraphWrapper serializableInstance() { //======================== Private Method ======================// - private static Graph getPattern(Dag dag) { - return SearchGraphUtils.patternFromDag(dag); + private static Graph getPattern(Graph graph) { + return SearchGraphUtils.patternFromDag(graph); } @Override diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/RandomMixedRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/RandomMixedRunner.java index 9cce69d4bd..1a0f89844f 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/RandomMixedRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/RandomMixedRunner.java @@ -137,17 +137,17 @@ public void execute() { DataSet ds = (DataSet) getDataModelList().get(0); -// WGfci fgs = new WGfci(ds); -// fgs.setAlpha(4); -// Graph graph = fgs.search(); +// WGfci fges = new WGfci(ds); +// fges.setAlpha(4); +// Graph graph = fges.search(); - WFgs fgs = new WFgs(ds); - fgs.setPenaltyDiscount(12); - Graph graph = fgs.search(); + WFges fges = new WFges(ds); + fges.setPenaltyDiscount(12); + Graph graph = fges.search(); -// WFgs fgs = new WFgs(ds); -// fgs.setAlpha(4); -// Graph g = fgs.search(); +// WFges fges = new WFges(ds); +// fges.setAlpha(4); +// Graph g = fges.search(); // IndependenceTest test = new IndTestMixedLrt(ds, .001); // Cpc pc = new Cpc(test); // pc.setInitialGraph(g); @@ -162,24 +162,24 @@ public void execute() { // pcs.setVerbose(false); // Graph graph = pcs.search(); -// WFgs fgs = new WFgs(ds); -// fgs.setMaxIndegree(-1); -// fgs.setAlpha(4); -// Graph graph = fgs.search(); +// WFges fges = new WFges(ds); +// fges.setMaxIndegree(-1); +// fges.setAlpha(4); +// Graph graph = fges.search(); -// WFgs fgs = new WFgs(ds); -// fgs.setMaxIndegree(5); -// fgs.setAlpha(8); -// Graph g = fgs.search(); +// WFges fges = new WFges(ds); +// fges.setMaxIndegree(5); +// fges.setAlpha(8); +// Graph g = fges.search(); // IndependenceTest test = new IndTestMixedLrt(ds, .001); // Cpc pc = new Cpc(test); // pc.setInitialGraph(g); // Graph graph = pc.search(); // ConditionalGaussianScore score = new ConditionalGaussianScore(ds); -// Fgs fgs = new Fgs(score); -// fgs.setMaxIndegree(-1); -// Graph graph = fgs.search(); +// Fges fges = new Fges(score); +// fges.setMaxIndegree(-1); +// Graph graph = fges.search(); GraphUtils.arrangeBySourceGraph(graph, getSourceGraph()); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/ScoredGraphsWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/ScoredGraphsWrapper.java index c892e0053a..1b25174e64 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/ScoredGraphsWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/ScoredGraphsWrapper.java @@ -98,7 +98,7 @@ public ScoredGraphsWrapper(Graph graph, GraphScorer scorer) { log(); } - public ScoredGraphsWrapper(FgsRunner runner, Parameters parameters) { + public ScoredGraphsWrapper(FgesRunner runner, Parameters parameters) { this(runner.getTopGraphs().get(runner.getIndex()).getGraph(), runner.getGraphScorer()); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SemImWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SemImWrapper.java index 50ca475f5f..edc7974090 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SemImWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SemImWrapper.java @@ -26,6 +26,7 @@ import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.sem.SemIm; +import edu.cmu.tetrad.sem.SemPm; import edu.cmu.tetrad.session.SessionModel; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.TetradLogger; @@ -33,6 +34,7 @@ import java.io.IOException; import java.io.ObjectInputStream; +import java.rmi.MarshalledObject; import java.util.ArrayList; import java.util.List; @@ -60,21 +62,21 @@ public class SemImWrapper implements SessionModel, GraphSource { public SemImWrapper(SemIm semIm) { setSemIm(semIm); } -// -// public SemImWrapper(SemEstimatorWrapper semEstWrapper) { -// if (semEstWrapper == null) { -// throw new NullPointerException(); -// } -// -// SemIm oldSemIm = semEstWrapper.getSemEstimator().getEstimatedSem(); -// -// try { -// setSemIm((SemIm) new MarshalledObject(oldSemIm).get()); -// } -// catch (Exception e) { -// throw new RuntimeException("SemIm could not be deep cloned.", e); -// } -// } + + public SemImWrapper(SemEstimatorWrapper semEstWrapper) { + if (semEstWrapper == null) { + throw new NullPointerException(); + } + + SemIm oldSemIm = semEstWrapper.getSemEstimator().getEstimatedSem(); + + try { + setSemIm((SemIm) new MarshalledObject(oldSemIm).get()); + } + catch (Exception e) { + throw new RuntimeException("SemIm could not be deep cloned.", e); + } + } public SemImWrapper(Simulation simulation) { if (simulation == null) { @@ -115,33 +117,33 @@ public SemImWrapper(SemPmWrapper semPmWrapper, Parameters params) { setSemIm(new SemIm(semPmWrapper.getSemPms().get(semPmWrapper.getModelIndex()), params)); } -// public SemImWrapper(SemPmWrapper semPmWrapper, SemImWrapper oldSemImWrapper, -// Parameters params) { -// if (semPmWrapper == null) { -// throw new NullPointerException("SemPmWrapper must not be null."); -// } -// -// if (params == null) { -// throw new NullPointerException("Parameters must not be null."); -// } -// -// SemPm semPm = new SemPm(semPmWrapper.getSemPm()); -// SemIm oldSemIm = oldSemImWrapper.getSemIm(); -// -// if (!params.getBoolean("retainPreviousValues", false)) { -// setSemIm(new SemIm(semPm, params)); -// } else { -// setSemIm(new SemIm(semPm, oldSemIm, params)); -// } -// } - -// public SemImWrapper(SemUpdaterWrapper semUpdaterWrapper) { -// if (semUpdaterWrapper == null) { -// throw new NullPointerException("SemPmWrapper must not be null."); -// } -// -// setSemIm(semUpdaterWrapper.getSemUpdater().getUpdatedSemIm()); -// } + public SemImWrapper(SemPmWrapper semPmWrapper, SemImWrapper oldSemImWrapper, + Parameters params) { + if (semPmWrapper == null) { + throw new NullPointerException("SemPmWrapper must not be null."); + } + + if (params == null) { + throw new NullPointerException("Parameters must not be null."); + } + + SemPm semPm = new SemPm(semPmWrapper.getSemPm()); + SemIm oldSemIm = oldSemImWrapper.getSemIm(); + + if (!params.getBoolean("retainPreviousValues", false)) { + setSemIm(new SemIm(semPm, params)); + } else { + setSemIm(new SemIm(semPm, oldSemIm, params)); + } + } + + public SemImWrapper(SemUpdaterWrapper semUpdaterWrapper) { + if (semUpdaterWrapper == null) { + throw new NullPointerException("SemPmWrapper must not be null."); + } + + setSemIm(semUpdaterWrapper.getSemUpdater().getUpdatedSemIm()); + } private void setSemIm(SemIm updatedSemIm) { semIms = new ArrayList<>(); @@ -152,13 +154,13 @@ private void setSemIm(SemIm updatedSemIm) { } } -// public SemImWrapper(SemImWrapper semImWrapper) { -// if (semImWrapper == null) { -// throw new NullPointerException("SemPmWrapper must not be null."); -// } -// -// setSemIm(semImWrapper.getSemIm()); -// } + public SemImWrapper(SemImWrapper semImWrapper) { + if (semImWrapper == null) { + throw new NullPointerException("SemPmWrapper must not be null."); + } + + setSemIm(semImWrapper.getSemIm()); + } public SemImWrapper(PValueImproverWrapper wrapper) { SemIm oldSemIm = wrapper.getNewSemIm(); @@ -222,21 +224,21 @@ private void readObject(ObjectInputStream s) s.defaultReadObject(); } - public Graph getSourceGraph() { - return getGraph(); - } + public Graph getSourceGraph() { + return getGraph(); + } public Graph getResultGraph() { return getGraph(); } public List getVariableNames() { - return getGraph().getNodeNames(); - } + return getGraph().getNodeNames(); + } - public List getVariables() { - return getGraph().getNodes(); - } + public List getVariables() { + return getGraph().getNodes(); + } public int getNumModels() { return numModels; diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SessionWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SessionWrapper.java index ba975af737..59099fc441 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SessionWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SessionWrapper.java @@ -47,7 +47,7 @@ * @see edu.cmu.tetrad.session.Session * @see edu.cmu.tetrad.graph.Graph */ -public class SessionWrapper implements Graph, SessionWrapperIndirectRef { +public class SessionWrapper extends EdgeListGraph implements SessionWrapperIndirectRef { static final long serialVersionUID = 23L; /** @@ -82,6 +82,8 @@ public class SessionWrapper implements Graph, SessionWrapperIndirectRef { */ private transient SessionHandler sessionHandler; private boolean highlighted = false; + private boolean pag; + private boolean pattern; //==========================CONSTRUCTORS=======================// @@ -260,7 +262,7 @@ private SessionHandler getSessionHandler() { * @param deltaY the shift in y. */ private void adjustNameAndPosition(SessionNodeWrapper wrapper, - List sessionNodeWrappers, int deltaX, int deltaY) { + List sessionNodeWrappers, int deltaX, int deltaY) { String originalName = wrapper.getSessionName(); String base = extractBase(originalName); String uniqueName = nextUniqueName(base, sessionNodeWrappers); @@ -362,7 +364,7 @@ public Edge getEdge(Node node1, Node node2) { } public Edge getDirectedEdge(Node node1, Node node2) { - return null; + return null; } /** @@ -613,15 +615,15 @@ public void transferNodesAndEdges(Graph graph) } public Set getAmbiguousTriples() { - throw new UnsupportedOperationException(); + return new HashSet<>(); } public Set getUnderLines() { - throw new UnsupportedOperationException(); + return new HashSet<>(); } public Set getDottedUnderlines() { - throw new UnsupportedOperationException(); + return new HashSet<>(); } /** @@ -680,7 +682,7 @@ public void setUnderLineTriples(Set triples) { public void setDottedUnderLineTriples(Set triples) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException(); } public List getCausalOrdering() { @@ -748,6 +750,26 @@ public List> getTriplesLists(Node node) { return null; } + @Override + public boolean isPag() { + return pag; + } + + @Override + public void setPag(boolean pag) { + this.pag = pag; + } + + @Override + public boolean isPattern() { + return pattern; + } + + @Override + public void setPattern(boolean pattern) { + this.pattern = pattern; + } + /** * Handles SessionEvents. Hides the handling of these from the * API. @@ -781,320 +803,320 @@ public void addingEdge(SessionEvent event) { public List getEdges(Node node1, Node node2) { throw new UnsupportedOperationException(); } - - // Unused methods from Graph - - /** - * Adds a directed edge --> to the graph. - */ - public boolean addDirectedEdge(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * Adds an undirected edge --- to the graph. - */ - public boolean addUndirectedEdge(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * Adds an nondirected edges o-o to the graph. - */ - public boolean addNondirectedEdge(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * Adds a bidirected edges <-> to the graph. - */ - public boolean addBidirectedEdge(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * Adds a partially oriented edge o-> to the graph. - */ - public boolean addPartiallyOrientedEdge(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * @return true iff there is a directed cycle in the graph. - */ - public boolean existsDirectedCycle() { - throw new UnsupportedOperationException(); - } - - public boolean isDirectedFromTo(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - public boolean isUndirectedFromTo(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - public boolean defVisible(Edge edge) { - throw new UnsupportedOperationException(); - } - - /** - * @return true iff there is a directed path from node1 to node2 in the - * graph. - */ - public boolean existsDirectedPathFromTo(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - public boolean existsUndirectedPathFromTo(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - public boolean existsSemiDirectedPathFromTo(Node node1, Set nodes2) { - throw new UnsupportedOperationException(); - } - - public boolean existsSemiDirectedPathFromTo(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * @return true iff a trek exists between two nodes in the graph. A trek - * exists if there is a directed path between the two nodes or else, for - * some third node in the graph, there is a path to each of the two nodes in - * question. - */ - public boolean existsTrek(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * @return the list of ancestors for the given nodes. - */ - public List getAncestors(List nodes) { - throw new UnsupportedOperationException(); - } - - /** - * @return the Collection of children for a node. - */ - public List getChildren(Node node) { - throw new UnsupportedOperationException(); - } - - public int getConnectivity() { - throw new UnsupportedOperationException(); - } - - public List getDescendants(List nodes) { - throw new UnsupportedOperationException(); - } - - /** - * @return a matrix of endpoints for the nodes in this graph, with nodes in - * the same order as getNodes(). - */ - public Endpoint[][] getEndpointMatrix() { - throw new UnsupportedOperationException(); - } - - /** - * @return the list of nodes adjacent to the given node. - */ - public List getAdjacentNodes(Node node) { - throw new UnsupportedOperationException(); - } - - /** - * @return the number of arrow endpoint adjacent to an edge. - */ - public int getIndegree(Node node) { - throw new UnsupportedOperationException(); - } - - @Override - public int getDegree(Node node) { - throw new UnsupportedOperationException(); - } - - /** - * @return the number of null endpoints adjacent to an edge. - */ - public int getOutdegree(Node node) { - throw new UnsupportedOperationException(); - } - +// +// // Unused methods from Graph +// +// /** +// * Adds a directed edge --> to the graph. +// */ +// public boolean addDirectedEdge(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Adds an undirected edge --- to the graph. +// */ +// public boolean addUndirectedEdge(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Adds an nondirected edges o-o to the graph. +// */ +// public boolean addNondirectedEdge(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Adds a bidirected edges <-> to the graph. +// */ +// public boolean addBidirectedEdge(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Adds a partially oriented edge o-> to the graph. +// */ +// public boolean addPartiallyOrientedEdge(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return true iff there is a directed cycle in the graph. +// */ +// public boolean existsDirectedCycle() { +// throw new UnsupportedOperationException(); +// } +// +// public boolean isDirectedFromTo(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// public boolean isUndirectedFromTo(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// public boolean defVisible(Edge edge) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return true iff there is a directed path from node1 to node2 in the +// * graph. +// */ +// public boolean existsDirectedPathFromTo(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// public boolean existsUndirectedPathFromTo(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// public boolean existsSemiDirectedPathFromTo(Node node1, Set nodes2) { +// throw new UnsupportedOperationException(); +// } +// +// public boolean existsSemiDirectedPathFromTo(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return true iff a trek exists between two nodes in the graph. A trek +// * exists if there is a directed path between the two nodes or else, for +// * some third node in the graph, there is a path to each of the two nodes in +// * question. +// */ +// public boolean existsTrek(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return the list of ancestors for the given nodes. +// */ +// public List getAncestors(List nodes) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return the Collection of children for a node. +// */ +// public List getChildren(Node node) { +// throw new UnsupportedOperationException(); +// } +// +// public int getConnectivity() { +// throw new UnsupportedOperationException(); +// } +// +// public List getDescendants(List nodes) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return a matrix of endpoints for the nodes in this graph, with nodes in +// * the same order as getNodes(). +// */ +// public Endpoint[][] getEndpointMatrix() { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return the list of nodes adjacent to the given node. +// */ +// public List getAdjacentNodes(Node node) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return the number of arrow endpoint adjacent to an edge. +// */ +// public int getIndegree(Node node) { +// throw new UnsupportedOperationException(); +// } +// +// @Override +// public int getDegree(Node node) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return the number of null endpoints adjacent to an edge. +// */ +// public int getOutdegree(Node node) { +// throw new UnsupportedOperationException(); +// } +// /** * @return the list of parents for a node. */ public List getParents(Node node) { - throw new UnsupportedOperationException(); - } - - /** - * Determines whether one node is an ancestor of another. - */ - public boolean isAncestorOf(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - public boolean possibleAncestor(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * @return true iff node1 is adjacent to node2 in the graph. - */ - public boolean isAdjacentTo(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * @return true iff node1 is a child of node2 in the graph. - */ - public boolean isChildOf(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * @return true iff node1 is a (non-proper) descendant of node2. - */ - public boolean isDescendentOf(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - public boolean defNonDescendent(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - public boolean isDefNoncollider(Node node1, Node node2, Node node3) { - throw new UnsupportedOperationException(); - } - - public boolean isDefCollider(Node node1, Node node2, Node node3) { - throw new UnsupportedOperationException(); - } - - /** - * Determines whether one node is d-separated from another. According to - * Spirtes, Richardson & Meek, two nodes are d- connected given some - * conditioning set Z if there is an acyclic undirected path U between them, - * such that every collider on U is an ancestor of some element in Z and - * every non-collider on U is not in Z. Two elements are d-separated just - * in case they are not d-separated. A collider is a node which two edges - * hold in common for which the endpoints leading into the node are both - * arrow endpoints. - */ - public boolean isDConnectedTo(Node node1, Node node2, List z) { - throw new UnsupportedOperationException(); - } - - /** - * Determines whether one node is d-separated from another. According to - * Spirtes, Richardson & Meek, two nodes are d- connected given some - * conditioning set Z if there is an acyclic undirected path U between them, - * such that every collider on U is an ancestor of some element in Z and - * every non-collider on U is not in Z. Two elements are d-separated just - * in case they are not d-separated. A collider is a node which two edges - * hold in common for which the endpoints leading into the node are both - * arrow endpoints. - */ - public boolean isDSeparatedFrom(Node node1, Node node2, List z) { - throw new UnsupportedOperationException(); - } - - public boolean possDConnectedTo(Node node1, Node node2, List z) { - throw new UnsupportedOperationException(); - } - - /** - * @return true iff the given node is exogenous in the graph. - */ - public boolean isExogenous(Node node) { - throw new UnsupportedOperationException(); - } - - /** - * Determines whether one node is a parent of another. - */ - public boolean isParentOf(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * Determines whether one node is a proper ancestor of another. - */ - public boolean isProperAncestorOf(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * Determines whether one node is a proper decendent of another. - */ - public boolean isProperDescendentOf(Node node1, Node node2) { - throw new UnsupportedOperationException(); - } - - /** - * Nodes adjacent to the given node with the given proximal endpoint. - */ - public List getNodesInTo(Node node, Endpoint n) { - throw new UnsupportedOperationException(); - } - - /** - * Nodes adjacent to the given node with the given distal endpoint. - */ - public List getNodesOutTo(Node node, Endpoint n) { - throw new UnsupportedOperationException(); - } - - /** - * Removes all edges from the graph and fully connects it using #-# edges, - * where # is the given endpoint. - */ - public void fullyConnect(Endpoint endpoint) { - throw new UnsupportedOperationException(); - } - - public void reorientAllWith(Endpoint endpoint) { - throw new UnsupportedOperationException(); - } - - public void setHighlighted(Edge edge, boolean highlighted) { - this.highlighted = highlighted; - } - - public boolean isHighlighted(Edge edge) { - return highlighted; - } - - public boolean isParameterizable(Node node) { - return false; - } - - public boolean isTimeLagModel() { - return false; - } - - public TimeLagGraph getTimeLagGraph() { - return null; - } - - @Override - public void removeTriplesNotInGraph() { - throw new UnsupportedOperationException(); - } - - @Override - public List getSepset(Node n1, Node n2) { - throw new UnsupportedOperationException(); - } - - @Override - public void setNodes(List nodes) { - throw new UnsupportedOperationException("Sorry, you cannot replace the variables for a time lag graph."); - } + return new ArrayList(((SessionNode) node).getParents()); + } +// +// /** +// * Determines whether one node is an ancestor of another. +// */ +// public boolean isAncestorOf(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// public boolean possibleAncestor(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return true iff node1 is adjacent to node2 in the graph. +// */ +// public boolean isAdjacentTo(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return true iff node1 is a child of node2 in the graph. +// */ +// public boolean isChildOf(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return true iff node1 is a (non-proper) descendant of node2. +// */ +// public boolean isDescendentOf(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// public boolean defNonDescendent(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// public boolean isDefNoncollider(Node node1, Node node2, Node node3) { +// throw new UnsupportedOperationException(); +// } +// +// public boolean isDefCollider(Node node1, Node node2, Node node3) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Determines whether one node is d-separated from another. According to +// * Spirtes, Richardson & Meek, two nodes are d- connected given some +// * conditioning set Z if there is an acyclic undirected path U between them, +// * such that every collider on U is an ancestor of some element in Z and +// * every non-collider on U is not in Z. Two elements are d-separated just +// * in case they are not d-separated. A collider is a node which two edges +// * hold in common for which the endpoints leading into the node are both +// * arrow endpoints. +// */ +// public boolean isDConnectedTo(Node node1, Node node2, List z) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Determines whether one node is d-separated from another. According to +// * Spirtes, Richardson & Meek, two nodes are d- connected given some +// * conditioning set Z if there is an acyclic undirected path U between them, +// * such that every collider on U is an ancestor of some element in Z and +// * every non-collider on U is not in Z. Two elements are d-separated just +// * in case they are not d-separated. A collider is a node which two edges +// * hold in common for which the endpoints leading into the node are both +// * arrow endpoints. +// */ +// public boolean isDSeparatedFrom(Node node1, Node node2, List z) { +// throw new UnsupportedOperationException(); +// } +// +// public boolean possDConnectedTo(Node node1, Node node2, List z) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * @return true iff the given node is exogenous in the graph. +// */ +// public boolean isExogenous(Node node) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Determines whether one node is a parent of another. +// */ +// public boolean isParentOf(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Determines whether one node is a proper ancestor of another. +// */ +// public boolean isProperAncestorOf(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Determines whether one node is a proper decendent of another. +// */ +// public boolean isProperDescendentOf(Node node1, Node node2) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Nodes adjacent to the given node with the given proximal endpoint. +// */ +// public List getNodesInTo(Node node, Endpoint n) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Nodes adjacent to the given node with the given distal endpoint. +// */ +// public List getNodesOutTo(Node node, Endpoint n) { +// throw new UnsupportedOperationException(); +// } +// +// /** +// * Removes all edges from the graph and fully connects it using #-# edges, +// * where # is the given endpoint. +// */ +// public void fullyConnect(Endpoint endpoint) { +// throw new UnsupportedOperationException(); +// } +// +// public void reorientAllWith(Endpoint endpoint) { +// throw new UnsupportedOperationException(); +// } +// +// public void setHighlighted(Edge edge, boolean highlighted) { +// this.highlighted = highlighted; +// } +// +// public boolean isHighlighted(Edge edge) { +// return highlighted; +// } +// +// public boolean isParameterizable(Node node) { +// return false; +// } +// +// public boolean isTimeLagModel() { +// return false; +// } +// +// public TimeLagGraph getTimeLagGraph() { +// return null; +// } +// +// @Override +// public void removeTriplesNotInGraph() { +// throw new UnsupportedOperationException(); +// } +// +// @Override +// public List getSepset(Node n1, Node n2) { +// throw new UnsupportedOperationException(); +// } +// +// @Override +// public void setNodes(List nodes) { +// throw new UnsupportedOperationException("Sorry, you cannot replace the variables for a time lag graph."); +// } public boolean isSessionChanged() { return this.session.isSessionChanged(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Simulation.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Simulation.java index aa945fea08..f171a14bd7 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Simulation.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Simulation.java @@ -27,6 +27,7 @@ import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.data.*; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.session.DoNotAddOldModel; import edu.cmu.tetrad.session.SessionModel; import edu.cmu.tetrad.session.SimulationParamsSource; import edu.cmu.tetrad.util.Parameters; diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TabularComparison.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TabularComparison.java index 835308a1c9..b8cd07c216 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TabularComparison.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TabularComparison.java @@ -68,102 +68,115 @@ public final class TabularComparison implements SessionModel, SimulationParamsSo //=============================CONSTRUCTORS==========================// - public TabularComparison(GeneralAlgorithmRunner model, Parameters params) { - this(model, model.getDataWrapper(), params); - } +// public TabularComparison(GeneralAlgorithmRunner model, Parameters params) { +// this(model, model.getDataWrapper(), params); +// } /** * Compares the results of a PC to a reference workbench by counting errors * of omission and commission. The counts can be retrieved using the methods * countOmissionErrors and countCommissionErrors. */ - public TabularComparison(SessionModel model1, SessionModel model2, + public TabularComparison(MultipleGraphSource model1, MultipleGraphSource model2, Parameters params) { if (params == null) { throw new NullPointerException("Parameters must not be null"); } - // Need to be able to construct this object even if the models are - // null. Otherwise the interface is annoying. - if (model2 == null) { - model2 = new DagWrapper(new Dag()); - } - - if (model1 == null) { - model1 = new DagWrapper(new Dag()); - } - - if (!(model1 instanceof MultipleGraphSource) || !(model2 instanceof MultipleGraphSource)) { - throw new IllegalArgumentException("Must be graph sources."); - } - if (model1 instanceof GeneralAlgorithmRunner && model2 instanceof GeneralAlgorithmRunner) { throw new IllegalArgumentException("Both parents can't be general algorithm runners."); } - if (model1 instanceof GeneralAlgorithmRunner) { + if (model1 instanceof GeneralAlgorithmRunner && model2 instanceof Simulation) { GeneralAlgorithmRunner generalAlgorithmRunner = (GeneralAlgorithmRunner) model1; this.algorithm = generalAlgorithmRunner.getAlgorithm(); - } else if (model2 instanceof GeneralAlgorithmRunner) { + } else if (model2 instanceof GeneralAlgorithmRunner && model1 instanceof Simulation) { GeneralAlgorithmRunner generalAlgorithmRunner = (GeneralAlgorithmRunner) model2; this.algorithm = generalAlgorithmRunner.getAlgorithm(); } String referenceName = params.getString("referenceGraphName", null); - if (referenceName == null) { - throw new IllegalArgumentException("Must specify a reference graph."); - } else { - MultipleGraphSource model11 = (MultipleGraphSource) model1; - Object model21 = model2; + if (referenceName.equals(model1.getName())) { + if (model1 instanceof Simulation && model2 instanceof GeneralAlgorithmRunner) { + this.referenceGraphs = ((GeneralAlgorithmRunner) model2).getCompareGraphs(((Simulation) model1).getGraphs()); + } else if (model1 instanceof MultipleGraphSource) { + this.referenceGraphs = ((MultipleGraphSource) model1).getGraphs(); + } - if (referenceName.equals(model1.getName())) { - if (model11 instanceof MultipleGraphSource) { - this.referenceGraphs = ((MultipleGraphSource) model11).getGraphs(); - } + if (model2 instanceof MultipleGraphSource) { + this.targetGraphs = ((MultipleGraphSource) model2).getGraphs(); + } - if (model21 instanceof MultipleGraphSource) { - this.targetGraphs = ((MultipleGraphSource) model21).getGraphs(); + if (referenceGraphs.size() == 1 && targetGraphs.size() > 1) { + Graph graph = referenceGraphs.get(0); + referenceGraphs = new ArrayList<>(); + for (Graph _graph : targetGraphs) { + referenceGraphs.add(_graph); } + } - if (referenceGraphs == null) { - this.referenceGraphs = Collections.singletonList(((GraphSource) model11).getGraph()); + if (targetGraphs.size() == 1 && referenceGraphs.size() > 1) { + Graph graph = targetGraphs.get(0); + targetGraphs = new ArrayList<>(); + for (Graph _graph : referenceGraphs) { + targetGraphs.add(graph); } + } - if (targetGraphs == null) { - this.targetGraphs = Collections.singletonList(((GraphSource) model21).getGraph()); - } + if (referenceGraphs == null) { + this.referenceGraphs = Collections.singletonList(((GraphSource) model1).getGraph()); + } - this.targetName = ((SessionModel) model21).getName(); - this.referenceName = ((SessionModel) model11).getName(); - } else if (referenceName.equals(model2.getName())) { - if (model21 instanceof MultipleGraphSource) { - this.referenceGraphs = ((MultipleGraphSource) model21).getGraphs(); - } + if (targetGraphs == null) { + this.targetGraphs = Collections.singletonList(((GraphSource) model2).getGraph()); + } + } else if (referenceName.equals(model2.getName())) { + if (model2 instanceof Simulation && model1 instanceof GeneralAlgorithmRunner) { + this.referenceGraphs = ((GeneralAlgorithmRunner) model1).getCompareGraphs(((Simulation) model2).getGraphs()); + } else if (model1 instanceof MultipleGraphSource) { + this.referenceGraphs = ((MultipleGraphSource) model2).getGraphs(); + } - if (model11 instanceof MultipleGraphSource) { - this.targetGraphs = ((MultipleGraphSource) model11).getGraphs(); - } -// - if (referenceGraphs == null) { - this.referenceGraphs = Collections.singletonList(((GraphSource) model21).getGraph()); + if (model1 instanceof MultipleGraphSource) { + this.targetGraphs = ((MultipleGraphSource) model1).getGraphs(); + } + + if (referenceGraphs.size() == 1 && targetGraphs.size() > 1) { + Graph graph = referenceGraphs.get(0); + referenceGraphs = new ArrayList<>(); + for (Graph _graph : targetGraphs) { + referenceGraphs.add(graph); } + } - if (targetGraphs == null) { - this.targetGraphs = Collections.singletonList(((GraphSource) model11).getGraph()); + if (targetGraphs.size() == 1 && referenceGraphs.size() > 1) { + Graph graph = targetGraphs.get(0); + targetGraphs = new ArrayList<>(); + for (Graph _graph : referenceGraphs) { + targetGraphs.add(graph); } + } - this.targetName = ((SessionModel) model11).getName(); - this.referenceName = ((SessionModel) model21).getName(); - } else { - throw new IllegalArgumentException( - "Neither of the supplied session models is named '" + - referenceName + "'."); + if (referenceGraphs == null) { + this.referenceGraphs = Collections.singletonList(((GraphSource) model2).getGraph()); } + + if (targetGraphs == null) { + this.targetGraphs = Collections.singletonList(((GraphSource) model1).getGraph()); + } + } else { + throw new IllegalArgumentException( + "Neither of the supplied session models is named '" + + referenceName + "'."); + } + + for (int i = 0; i < targetGraphs.size(); i++) { + targetGraphs.set(i, GraphUtils.replaceNodes(targetGraphs.get(i), referenceGraphs.get(i).getNodes())); } if (referenceGraphs.size() != targetGraphs.size()) { - throw new IllegalArgumentException("I was expecting the same number of graph in each parent."); + throw new IllegalArgumentException("I was expecting the same number of graphs in each parent."); } if (algorithm != null) { for (int i = 0; i < referenceGraphs.size(); i++) { @@ -190,12 +203,14 @@ private void newExecution() { statistics.add(new AdjacencyRecall()); statistics.add(new ArrowheadPrecision()); statistics.add(new ArrowheadRecall()); + statistics.add(new TwoCyclePrecision()); + statistics.add(new TwoCycleRecall()); // statistics.add(new ElapsedTime()); - statistics.add(new F1Adj()); - statistics.add(new F1Arrow()); - statistics.add(new MathewsCorrAdj()); - statistics.add(new MathewsCorrArrow()); - statistics.add(new SHD()); +// statistics.add(new F1Adj()); +// statistics.add(new F1Arrow()); +// statistics.add(new MathewsCorrAdj()); +// statistics.add(new MathewsCorrArrow()); +// statistics.add(new SHD()); List variables = new ArrayList<>(); @@ -217,38 +232,16 @@ private void addRecord(int i) { } } - public TabularComparison(GraphWrapper referenceGraph, - AbstractAlgorithmRunner algorithmRunner, - Parameters params) { - this(referenceGraph, (SessionModel) algorithmRunner, params); - } - - public TabularComparison(GraphWrapper referenceWrapper, - GraphWrapper targetWrapper, Parameters params) { - this(referenceWrapper, (SessionModel) targetWrapper, params); - } - - public TabularComparison(DagWrapper referenceGraph, - AbstractAlgorithmRunner algorithmRunner, - Parameters params) { - this(referenceGraph, (SessionModel) algorithmRunner, params); - } - - public TabularComparison(DagWrapper referenceWrapper, - GraphWrapper targetWrapper, Parameters params) { - this(referenceWrapper, (SessionModel) targetWrapper, params); - } - /** * Generates a simple exemplar of this class to test serialization. * * @see TetradSerializableUtils */ - public static TabularComparison serializableInstance() { - return new TabularComparison(DagWrapper.serializableInstance(), - DagWrapper.serializableInstance(), - new Parameters()); - } +// public static TabularComparison serializableInstance() { +// return new TabularComparison(DagWrapper.serializableInstance(), +// DagWrapper.serializableInstance(), +// new Parameters()); +// } //==============================PUBLIC METHODS========================// diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsFciRunner2.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsFciRunner2.java index 9396c012a8..ea414a8a93 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsFciRunner2.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsFciRunner2.java @@ -124,7 +124,7 @@ public void execute() { fci.setKnowledge(knowledge); fci.setCompleteRuleSetUsed(params.getBoolean("completeRuleSetUsed", false)); fci.setMaxPathLength(params.getInt("maxReachablePathLength", -1)); - fci.setMaxIndegree(params.getInt("maxIndegree")); + fci.setMaxDegree(params.getInt("maxIndegree")); fci.setCompleteRuleSetUsed(false); fci.setFaithfulnessAssumed(params.getBoolean("faithfulnessAssumed", true)); graph = fci.search(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsFgsRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsFgesRunner.java similarity index 81% rename from tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsFgsRunner.java rename to tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsFgesRunner.java index a350276f56..4fe08a3559 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsFgsRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsFgesRunner.java @@ -47,7 +47,7 @@ * @author Daniel Malinsky */ -public class TsFgsRunner extends AbstractAlgorithmRunner implements IFgsRunner, GraphSource, +public class TsFgesRunner extends AbstractAlgorithmRunner implements IFgesRunner, GraphSource, PropertyChangeListener, IGesRunner, Indexable, DoNotAddOldModel, Unmarshallable { static final long serialVersionUID = 23L; private LinkedHashMap allParamSettings; @@ -57,36 +57,36 @@ public enum Type {CONTINUOUS, DISCRETE, MIXED, GRAPH} private transient List listeners; private List topGraphs; private int index; - private transient TsFgs2 fgs; + private transient TsFges2 fges; private transient Graph initialGraph; //============================CONSTRUCTORS============================// - public TsFgsRunner(DataWrapper[] dataWrappers, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { + public TsFgesRunner(DataWrapper[] dataWrappers, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { super(new MergeDatasetsWrapper(dataWrappers, params), params, knowledgeBoxModel); } - public TsFgsRunner(DataWrapper[] dataWrappers, Parameters params) { + public TsFgesRunner(DataWrapper[] dataWrappers, Parameters params) { super(new MergeDatasetsWrapper(dataWrappers, params), params, null); } - public TsFgsRunner(DataWrapper[] dataWrappers, GraphSource graph, Parameters params) { + public TsFgesRunner(DataWrapper[] dataWrappers, GraphSource graph, Parameters params) { super(new MergeDatasetsWrapper(dataWrappers, params), params, null); if (graph == this) throw new IllegalArgumentException(); this.initialGraph = graph.getGraph(); } - public TsFgsRunner(DataWrapper[] dataWrappers, GraphSource graph, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { + public TsFgesRunner(DataWrapper[] dataWrappers, GraphSource graph, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { super(new MergeDatasetsWrapper(dataWrappers, params), params, knowledgeBoxModel); if (graph == this) throw new IllegalArgumentException(); this.initialGraph = graph.getGraph(); } - public TsFgsRunner(GraphWrapper graphWrapper, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { + public TsFgesRunner(GraphWrapper graphWrapper, Parameters params, KnowledgeBoxModel knowledgeBoxModel) { super(graphWrapper.getGraph(), params, knowledgeBoxModel); } - public TsFgsRunner(GraphWrapper graphWrapper, Parameters params) { + public TsFgesRunner(GraphWrapper graphWrapper, Parameters params) { super(graphWrapper.getGraph(), params, null); } @@ -125,9 +125,9 @@ public void execute() { if (model instanceof Graph) { GraphScore gesScore = new GraphScore((Graph) model); - fgs = new TsFgs2(gesScore); - fgs.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); - fgs.setVerbose(true); + fges = new TsFges2(gesScore); + fges.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); + fges.setVerbose(true); } else { double penaltyDiscount = params.getDouble("penaltyDiscount", 4); @@ -141,24 +141,24 @@ public void execute() { // SvrScore gesScore = new SvrScore((DataSet) model); gesScore.setPenaltyDiscount(penaltyDiscount); System.out.println("Score done"); - fgs = new TsFgs2(gesScore); + fges = new TsFges2(gesScore); } else if (dataSet.isDiscrete()) { double samplePrior = getParams().getDouble("samplePrior", 1); double structurePrior = getParams().getDouble("structurePrior", 1); BDeuScore score = new BDeuScore(dataSet); score.setSamplePrior(samplePrior); score.setStructurePrior(structurePrior); - fgs = new TsFgs2(score); + fges = new TsFges2(score); } else { MixedBicScore gesScore = new MixedBicScore(dataSet); gesScore.setPenaltyDiscount(penaltyDiscount); - fgs = new TsFgs2(gesScore); + fges = new TsFges2(gesScore); } } else if (model instanceof ICovarianceMatrix) { SemBicScore gesScore = new SemBicScore((ICovarianceMatrix) model); gesScore.setPenaltyDiscount(penaltyDiscount); gesScore.setPenaltyDiscount(penaltyDiscount); - fgs = new TsFgs2(gesScore); + fges = new TsFges2(gesScore); } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; @@ -171,7 +171,7 @@ else if (model instanceof DataModelList) { } if (list.size() != 1) { - throw new IllegalArgumentException("FGS takes exactly one data set, covariance matrix, or initialGraph " + + throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or initialGraph " + "as input. For multiple data sets as input, use IMaGES."); } @@ -179,26 +179,26 @@ else if (model instanceof DataModelList) { double penalty = getParams().getDouble("penaltyDiscount", 4); if (params.getBoolean("firstNontriangular", false)) { - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); - fgs = new TsFgs2(fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); + fges = new TsFges2(fgesScore); } else { - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); - fgs = new TsFgs2(fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); + fges = new TsFges2(fgesScore); } } else if (allDiscrete(list)) { double structurePrior = getParams().getDouble("structurePrior", 1); double samplePrior = getParams().getDouble("samplePrior", 1); - BdeuScoreImages fgsScore = new BdeuScoreImages(list); - fgsScore.setSamplePrior(samplePrior); - fgsScore.setStructurePrior(structurePrior); + BdeuScoreImages fgesScore = new BdeuScoreImages(list); + fgesScore.setSamplePrior(samplePrior); + fgesScore.setStructurePrior(structurePrior); if (params.getBoolean("firstNontriangular", false)) { - fgs = new TsFgs2(fgsScore); + fges = new TsFges2(fgesScore); } else { - fgs = new TsFgs2(fgsScore); + fges = new TsFges2(fgesScore); } } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); @@ -208,13 +208,13 @@ else if (model instanceof DataModelList) { } } - fgs.setInitialGraph(initialGraph); - fgs.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); - fgs.setNumPatternsToStore(params.getInt("numPatternsToSave", 1)); - fgs.setVerbose(true); -// fgs.setHeuristicSpeedup(((Parameters) params.getIndTestParams()).isFaithfulnessAssumed()); -// * there is no setHeuristicSpeedup option in Fgs2 and so likewise TsFgs2. * - Graph graph = fgs.search(); + fges.setInitialGraph(initialGraph); + fges.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); + fges.setNumPatternsToStore(params.getInt("numPatternsToSave", 1)); + fges.setVerbose(true); +// fges.setHeuristicSpeedup(((Parameters) params.getIndTestParams()).isFaithfulnessAssumed()); +// * there is no setHeuristicSpeedup option in Fges2 and so likewise TsFges2. * + Graph graph = fges.search(); if (getSourceGraph() != null) { GraphUtils.arrangeBySourceGraph(graph, getSourceGraph()); @@ -226,7 +226,7 @@ else if (model instanceof DataModelList) { setResultGraph(graph); - this.topGraphs = new ArrayList<>(fgs.getTopGraphs()); + this.topGraphs = new ArrayList<>(fges.getTopGraphs()); if (topGraphs.isEmpty()) { @@ -240,8 +240,8 @@ else if (model instanceof DataModelList) { * Executes the algorithm, producing (at least) a result workbench. Must be * implemented in the extending class. */ -// public FgsRunner.Type getType() { return FgsRunner.getType(); } - public FgsRunner.Type getType() { +// public FgesRunner.Type getType() { return FgesRunner.getType(); } + public FgesRunner.Type getType() { Object model = getDataModel(); if (model == null && getSourceGraph() != null) { @@ -255,32 +255,32 @@ public FgsRunner.Type getType() { "file when you save the session. It can, however, be recreated from the saved seed."); } - FgsRunner.Type type; + FgesRunner.Type type; if (model instanceof Graph) { - type = FgsRunner.Type.GRAPH; + type = FgesRunner.Type.GRAPH; } else if (model instanceof DataSet) { DataSet dataSet = (DataSet) model; if (dataSet.isContinuous()) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (dataSet.isDiscrete()) { - type = FgsRunner.Type.DISCRETE; + type = FgesRunner.Type.DISCRETE; } else { - type = FgsRunner.Type.MIXED; + type = FgesRunner.Type.MIXED; // throw new IllegalStateException("Data set must either be continuous or discrete."); } } else if (model instanceof ICovarianceMatrix) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; if (allContinuous(list)) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (allDiscrete(list)) { - type = FgsRunner.Type.DISCRETE; + type = FgesRunner.Type.DISCRETE; } else { - type = FgsRunner.Type.MIXED; + type = FgesRunner.Type.MIXED; // throw new IllegalArgumentException("Data must be either all discrete or all continuous."); } } else { @@ -369,7 +369,7 @@ public Map getParamSettings() { @Override public String getAlgorithmName() { - return "FGS"; + return "FGES"; } public void propertyChange(PropertyChangeEvent evt) { @@ -398,15 +398,15 @@ public List getTopGraphs() { } public String getBayesFactorsReport(Graph dag) { - if (fgs == null) { + if (fges == null) { return "Please re-run IMaGES."; } else { - return fgs.logEdgeBayesFactorsString(dag); + return fges.logEdgeBayesFactorsString(dag); } } public GraphScorer getGraphScorer() { - return fgs; + return fges; } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsGFciRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsGFciRunner.java index b5f8e2535d..cd37fea6e9 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsGFciRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsGFciRunner.java @@ -229,7 +229,7 @@ public void execute() { } if (list.size() != 1) { - throw new IllegalArgumentException("FGS takes exactly one data set, covariance matrix, or initialGraph " + + throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or initialGraph " + "as input. For multiple data sets as input, use IMaGES."); } @@ -239,21 +239,21 @@ public void execute() { if (allContinuous(list)) { double penalty = params.getDouble("penaltyDiscount", 4); - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); IndependenceTest test = new IndTestDSep((Graph) model); - gfci = new TsGFci(test, fgsScore); + gfci = new TsGFci(test, fgesScore); gfci.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); } // else if (allDiscrete(list)) { // double structurePrior = ((Parameters) getParameters()).getStructurePrior(); // double samplePrior = ((Parameters) getParameters()).getSamplePrior(); // -// BdeuScoreImages fgsScore = new BdeuScoreImages(list); -// fgsScore.setSamplePrior(samplePrior); -// fgsScore.setStructurePrior(structurePrior); +// BdeuScoreImages fgesScore = new BdeuScoreImages(list); +// fgesScore.setSamplePrior(samplePrior); +// fgesScore.setStructurePrior(structurePrior); // -// gfci = new GFci(fgsScore); +// gfci = new GFci(fgesScore); // } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsImagesRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsImagesRunner.java index 3ae0d758d6..5a10a20624 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsImagesRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/TsImagesRunner.java @@ -43,20 +43,20 @@ * @author Daniel Malinsky */ -public class TsImagesRunner extends AbstractAlgorithmRunner implements IFgsRunner, GraphSource, +public class TsImagesRunner extends AbstractAlgorithmRunner implements IFgesRunner, GraphSource, PropertyChangeListener, IGesRunner, Indexable { static final long serialVersionUID = 23L; - public FgsRunner.Type getType() { + public FgesRunner.Type getType() { return type; } private transient List listeners; private List topGraphs = new LinkedList<>(); private int index; - private transient TsGFci fgs; + private transient TsGFci fges; private Graph graph; - private FgsRunner.Type type; + private FgesRunner.Type type; //============================CONSTRUCTORS============================// @@ -121,7 +121,7 @@ public void execute() { if (model instanceof Graph) { GraphScore gesScore = new GraphScore((Graph) model); IndependenceTest test = new IndTestScore(gesScore); - fgs = new TsGFci(test, gesScore); + fges = new TsGFci(test, gesScore); } else if (model instanceof DataSet) { DataSet dataSet = (DataSet) model; @@ -129,7 +129,7 @@ public void execute() { SemBicScore gesScore = new SemBicScore(new CovarianceMatrixOnTheFly((DataSet) model)); gesScore.setPenaltyDiscount(penaltyDiscount); IndependenceTest test = new IndTestScore(gesScore); - fgs = new TsGFci(test, gesScore); + fges = new TsGFci(test, gesScore); } else if (dataSet.isDiscrete()) { double samplePrior = getParams().getDouble("samplePrior", 1); double structurePrior = getParams().getDouble("structurePrior", 1); @@ -137,7 +137,7 @@ public void execute() { score.setSamplePrior(samplePrior); score.setStructurePrior(structurePrior); IndependenceTest test = new IndTestScore(score); - fgs = new TsGFci(test, score); + fges = new TsGFci(test, score); } else { throw new IllegalStateException("Data set must either be continuous or discrete."); } @@ -145,7 +145,7 @@ public void execute() { SemBicScore gesScore = new SemBicScore((ICovarianceMatrix) model); gesScore.setPenaltyDiscount(penaltyDiscount); IndependenceTest test = new IndTestScore(gesScore); - fgs = new TsGFci(test, gesScore); + fges = new TsGFci(test, gesScore); } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; @@ -157,7 +157,7 @@ public void execute() { } // if (list.size() != 1) { -// throw new IllegalArgumentException("FGS takes exactly one data set, covariance matrix, or graph " + +// throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or graph " + // "as input. For multiple data sets as input, use IMaGES."); // } @@ -166,20 +166,20 @@ public void execute() { if (allContinuous(list)) { double penalty = penaltyDiscount; - SemBicScoreImages fgsScore = new SemBicScoreImages(list); - fgsScore.setPenaltyDiscount(penalty); - IndependenceTest test = new IndTestScore(fgsScore); - fgs = new TsGFci(test, fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(list); + fgesScore.setPenaltyDiscount(penalty); + IndependenceTest test = new IndTestScore(fgesScore); + fges = new TsGFci(test, fgesScore); } else if (allDiscrete(list)) { double structurePrior = getParams().getDouble("structurePrior", 1); double samplePrior = getParams().getDouble("samplePrior", 1); - BdeuScoreImages fgsScore = new BdeuScoreImages(list); - fgsScore.setSamplePrior(samplePrior); - fgsScore.setStructurePrior(structurePrior); - IndependenceTest test = new IndTestScore(fgsScore); - fgs = new TsGFci(test, fgsScore); + BdeuScoreImages fgesScore = new BdeuScoreImages(list); + fgesScore.setSamplePrior(samplePrior); + fgesScore.setStructurePrior(structurePrior); + IndependenceTest test = new IndTestScore(fgesScore); + fges = new TsGFci(test, fgesScore); } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); @@ -188,11 +188,11 @@ public void execute() { System.out.println("No viable input."); } - fgs.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); -// fgs.setNumPatternsToStore(params.getNumPatternsToSave()); // removed for TsGFci -// fgs.setHeuristicSpeedup(((Parameters) params.getIndTestParams()).isFaithfulnessAssumed()); // removed for TsGFci - fgs.setVerbose(true); - Graph graph = fgs.search(); + fges.setKnowledge((IKnowledge) getParams().get("knowledge", new Knowledge2())); +// fges.setNumPatternsToStore(params.getNumPatternsToSave()); // removed for TsGFci +// fges.setHeuristicSpeedup(((Parameters) params.getIndTestParams()).isFaithfulnessAssumed()); // removed for TsGFci + fges.setVerbose(true); + Graph graph = fges.search(); if (getSourceGraph() != null) { GraphUtils.arrangeBySourceGraph(graph, getSourceGraph()); @@ -223,7 +223,7 @@ public void execute() { * Executes the algorithm, producing (at least) a result workbench. Must be * implemented in the extending class. */ - private FgsRunner.Type computeType() { + private FgesRunner.Type computeType() { Object model = getDataModel(); if (model == null && getSourceGraph() != null) { @@ -238,26 +238,26 @@ private FgsRunner.Type computeType() { } if (model instanceof Graph) { - type = FgsRunner.Type.GRAPH; + type = FgesRunner.Type.GRAPH; } else if (model instanceof DataSet) { DataSet dataSet = (DataSet) model; if (dataSet.isContinuous()) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (dataSet.isDiscrete()) { - type = FgsRunner.Type.DISCRETE; + type = FgesRunner.Type.DISCRETE; } else { throw new IllegalStateException("Data set must either be continuous or discrete."); } } else if (model instanceof ICovarianceMatrix) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; if (allContinuous(list)) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (allDiscrete(list)) { - type = FgsRunner.Type.DISCRETE; + type = FgesRunner.Type.DISCRETE; } else { throw new IllegalArgumentException("Data must be either all discrete or all continuous."); } @@ -363,18 +363,18 @@ public List getTopGraphs() { // never gets used, commented out // public String getBayesFactorsReport(Graph dag) { -// if (fgs == null) { +// if (fges == null) { // return "Please re-run IMaGES."; // } else { -// return fgs.logEdgeBayesFactorsString(dag); +// return fges.logEdgeBayesFactorsString(dag); // } // } // public GraphScorer getGraphScorer() { -// return fgs; +// return fges; // } public TsGFci getGraphScorer() { - return fgs; + return fges; } // changed return type for TsGFci public void setGraph(Graph graph) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/WFgsRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/WFgesRunner.java similarity index 91% rename from tetrad-gui/src/main/java/edu/cmu/tetradapp/model/WFgsRunner.java rename to tetrad-gui/src/main/java/edu/cmu/tetradapp/model/WFgesRunner.java index 64e4ec009b..6d42c0e7b4 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/WFgsRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/WFgesRunner.java @@ -46,7 +46,7 @@ * @author Ricardo Silva */ -public class WFgsRunner extends AbstractAlgorithmRunner implements IFgsRunner, GraphSource, +public class WFgesRunner extends AbstractAlgorithmRunner implements IFgesRunner, GraphSource, PropertyChangeListener, IGesRunner, Indexable, DoNotAddOldModel, Unmarshallable { static final long serialVersionUID = 23L; private LinkedHashMap allParamSettings; @@ -58,7 +58,7 @@ public class WFgsRunner extends AbstractAlgorithmRunner implements IFgsRunner, G //============================CONSTRUCTORS============================// - public WFgsRunner(DataWrapper[] dataWrappers, Parameters params) { + public WFgesRunner(DataWrapper[] dataWrappers, Parameters params) { super(new MergeDatasetsWrapper(dataWrappers, params), params, null); } @@ -86,9 +86,9 @@ public void execute() { double penaltyDiscount = params.getDouble("penaltyDiscount", 4); - WFgs fgs = new WFgs(dataSet); - fgs.setPenaltyDiscount(penaltyDiscount); - Graph graph = fgs.search(); + WFges fges = new WFges(dataSet); + fges.setPenaltyDiscount(penaltyDiscount); + Graph graph = fges.search(); if (getSourceGraph() != null) { GraphUtils.arrangeBySourceGraph(graph, getSourceGraph()); @@ -109,9 +109,9 @@ public void execute() { * Executes the algorithm, producing (at least) a result workbench. Must be * implemented in the extending class. */ - public FgsRunner.Type getType() { + public FgesRunner.Type getType() { if (true) { - return FgsRunner.Type.CONTINUOUS; + return FgesRunner.Type.CONTINUOUS; } Object model = getDataModel(); @@ -127,32 +127,32 @@ public FgsRunner.Type getType() { "file when you save the session. It can, however, be recreated from the saved seed."); } - FgsRunner.Type type; + FgesRunner.Type type; if (model instanceof Graph) { - type = FgsRunner.Type.GRAPH; + type = FgesRunner.Type.GRAPH; } else if (model instanceof DataSet) { DataSet dataSet = (DataSet) model; if (dataSet.isContinuous()) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (dataSet.isDiscrete()) { - type = FgsRunner.Type.DISCRETE; + type = FgesRunner.Type.DISCRETE; } else { - type = FgsRunner.Type.MIXED; + type = FgesRunner.Type.MIXED; // throw new IllegalStateException("Data set must either be continuous or discrete."); } } else if (model instanceof ICovarianceMatrix) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (model instanceof DataModelList) { DataModelList list = (DataModelList) model; if (allContinuous(list)) { - type = FgsRunner.Type.CONTINUOUS; + type = FgesRunner.Type.CONTINUOUS; } else if (allDiscrete(list)) { - type = FgsRunner.Type.DISCRETE; + type = FgesRunner.Type.DISCRETE; } else { - type = FgsRunner.Type.MIXED; + type = FgesRunner.Type.MIXED; // throw new IllegalArgumentException("Data must be either all discrete or all continuous."); } } else { @@ -241,7 +241,7 @@ public Map getParamSettings() { @Override public String getAlgorithmName() { - return "FGS"; + return "FGES"; } public void propertyChange(PropertyChangeEvent evt) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixDifferenceWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixDifferenceWrapper.java index af085e246e..1a8b34cf1a 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixDifferenceWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixDifferenceWrapper.java @@ -70,60 +70,60 @@ public CovMatrixDifferenceWrapper(DataWrapper wrapper1, DataWrapper wrapper2, Pa } - public CovMatrixDifferenceWrapper(SemEstimatorWrapper wrapper1, DataWrapper wrapper2, Parameters params) { - if (wrapper1 == null || wrapper2 == null) { - throw new NullPointerException("The data must not be null"); - } - - DataModel model2 = wrapper2.getSelectedDataModel(); - - if (!(model2 instanceof ICovarianceMatrix)) { - throw new IllegalArgumentException("Expecting corrariance matrices."); - } - - TetradMatrix corr1 = wrapper1.getEstimatedSemIm().getImplCovarMeas(); - TetradMatrix corr2 = ((ICovarianceMatrix) model2).getMatrix(); - - TetradMatrix corr3 = calcDifference(corr1, corr2); - - ICovarianceMatrix corrWrapper = new CovarianceMatrix(model2.getVariables(), corr3, - ((ICovarianceMatrix) model2).getSampleSize()); - - setDataModel(corrWrapper); - setSourceGraph(wrapper2.getSourceGraph()); - LogDataUtils.logDataModelList("Difference of matrices.", getDataModelList()); - - } - - public CovMatrixDifferenceWrapper(SemImWrapper wrapper1, DataWrapper wrapper2, Parameters params) { - try { - if (wrapper1 == null || wrapper2 == null) { - throw new NullPointerException("The data must not be null"); - } - - DataModel model2 = wrapper2.getSelectedDataModel(); - - if (!(model2 instanceof ICovarianceMatrix)) { - throw new IllegalArgumentException("Expecting corrariance matrices."); - } - - TetradMatrix corr1 = wrapper1.getSemIm().getImplCovarMeas(); - TetradMatrix corr2 = ((ICovarianceMatrix) model2).getMatrix(); - - TetradMatrix corr3 = calcDifference(corr1, corr2); - - ICovarianceMatrix corrWrapper = new CovarianceMatrix(model2.getVariables(), corr3, - ((ICovarianceMatrix) model2).getSampleSize()); - - setDataModel(corrWrapper); - setSourceGraph(wrapper2.getSourceGraph()); - LogDataUtils.logDataModelList("Difference of matrices.", getDataModelList()); - } catch (Exception e) { - e.printStackTrace(); - throw new RuntimeException(e); - } - - } +// public CovMatrixDifferenceWrapper(SemEstimatorWrapper wrapper1, DataWrapper wrapper2, Parameters params) { +// if (wrapper1 == null || wrapper2 == null) { +// throw new NullPointerException("The data must not be null"); +// } +// +// DataModel model2 = wrapper2.getSelectedDataModel(); +// +// if (!(model2 instanceof ICovarianceMatrix)) { +// throw new IllegalArgumentException("Expecting corrariance matrices."); +// } +// +// TetradMatrix corr1 = wrapper1.getEstimatedSemIm().getImplCovarMeas(); +// TetradMatrix corr2 = ((ICovarianceMatrix) model2).getMatrix(); +// +// TetradMatrix corr3 = calcDifference(corr1, corr2); +// +// ICovarianceMatrix corrWrapper = new CovarianceMatrix(model2.getVariables(), corr3, +// ((ICovarianceMatrix) model2).getSampleSize()); +// +// setDataModel(corrWrapper); +// setSourceGraph(wrapper2.getSourceGraph()); +// LogDataUtils.logDataModelList("Difference of matrices.", getDataModelList()); +// +// } + +// public CovMatrixDifferenceWrapper(SemImWrapper wrapper1, DataWrapper wrapper2, Parameters params) { +// try { +// if (wrapper1 == null || wrapper2 == null) { +// throw new NullPointerException("The data must not be null"); +// } +// +// DataModel model2 = wrapper2.getSelectedDataModel(); +// +// if (!(model2 instanceof ICovarianceMatrix)) { +// throw new IllegalArgumentException("Expecting corrariance matrices."); +// } +// +// TetradMatrix corr1 = wrapper1.getSemIm().getImplCovarMeas(); +// TetradMatrix corr2 = ((ICovarianceMatrix) model2).getMatrix(); +// +// TetradMatrix corr3 = calcDifference(corr1, corr2); +// +// ICovarianceMatrix corrWrapper = new CovarianceMatrix(model2.getVariables(), corr3, +// ((ICovarianceMatrix) model2).getSampleSize()); +// +// setDataModel(corrWrapper); +// setSourceGraph(wrapper2.getSourceGraph()); +// LogDataUtils.logDataModelList("Difference of matrices.", getDataModelList()); +// } catch (Exception e) { +// e.printStackTrace(); +// throw new RuntimeException(e); +// } +// +// } private TetradMatrix calcDifference(TetradMatrix corr1, TetradMatrix corr2) { if (corr1.rows() != corr2.rows()) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixSumWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixSumWrapper.java index b1372bdaa6..09cacdbdf2 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixSumWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixSumWrapper.java @@ -69,60 +69,60 @@ public CovMatrixSumWrapper(DataWrapper wrapper1, DataWrapper wrapper2) { } - public CovMatrixSumWrapper(SemEstimatorWrapper wrapper1, DataWrapper wrapper2) { - if (wrapper1 == null || wrapper2 == null) { - throw new NullPointerException("The data must not be null"); - } - - DataModel model2 = wrapper2.getSelectedDataModel(); - - if (!(model2 instanceof ICovarianceMatrix)) { - throw new IllegalArgumentException("Expecting corrariance matrices."); - } - - TetradMatrix corr1 = wrapper1.getEstimatedSemIm().getImplCovarMeas(); - TetradMatrix corr2 = ((ICovarianceMatrix) model2).getMatrix(); - - TetradMatrix corr3 = calcSum(corr1, corr2); - - ICovarianceMatrix corrWrapper = new CovarianceMatrix(model2.getVariables(), corr3, - ((ICovarianceMatrix) model2).getSampleSize()); - - setDataModel(corrWrapper); - setSourceGraph(wrapper2.getSourceGraph()); - LogDataUtils.logDataModelList("Difference of matrices.", getDataModelList()); - - } - - public CovMatrixSumWrapper(SemImWrapper wrapper1, DataWrapper wrapper2) { - try { - if (wrapper1 == null || wrapper2 == null) { - throw new NullPointerException("The data must not be null"); - } - - DataModel model2 = wrapper2.getSelectedDataModel(); - - if (!(model2 instanceof ICovarianceMatrix)) { - throw new IllegalArgumentException("Expecting corrariance matrices."); - } - - TetradMatrix corr1 = wrapper1.getSemIm().getImplCovarMeas(); - TetradMatrix corr2 = ((ICovarianceMatrix) model2).getMatrix(); - - TetradMatrix corr3 = calcSum(corr1, corr2); - - ICovarianceMatrix corrWrapper = new CovarianceMatrix(model2.getVariables(), corr3, - ((ICovarianceMatrix) model2).getSampleSize()); - - setDataModel(corrWrapper); - setSourceGraph(wrapper2.getSourceGraph()); - LogDataUtils.logDataModelList("Difference of matrices.", getDataModelList()); - } catch (Exception e) { - e.printStackTrace(); - throw new RuntimeException(e); - } - - } +// public CovMatrixSumWrapper(SemEstimatorWrapper wrapper1, DataWrapper wrapper2) { +// if (wrapper1 == null || wrapper2 == null) { +// throw new NullPointerException("The data must not be null"); +// } +// +// DataModel model2 = wrapper2.getSelectedDataModel(); +// +// if (!(model2 instanceof ICovarianceMatrix)) { +// throw new IllegalArgumentException("Expecting corrariance matrices."); +// } +// +// TetradMatrix corr1 = wrapper1.getEstimatedSemIm().getImplCovarMeas(); +// TetradMatrix corr2 = ((ICovarianceMatrix) model2).getMatrix(); +// +// TetradMatrix corr3 = calcSum(corr1, corr2); +// +// ICovarianceMatrix corrWrapper = new CovarianceMatrix(model2.getVariables(), corr3, +// ((ICovarianceMatrix) model2).getSampleSize()); +// +// setDataModel(corrWrapper); +// setSourceGraph(wrapper2.getSourceGraph()); +// LogDataUtils.logDataModelList("Difference of matrices.", getDataModelList()); +// +// } + +// public CovMatrixSumWrapper(SemImWrapper wrapper1, DataWrapper wrapper2) { +// try { +// if (wrapper1 == null || wrapper2 == null) { +// throw new NullPointerException("The data must not be null"); +// } +// +// DataModel model2 = wrapper2.getSelectedDataModel(); +// +// if (!(model2 instanceof ICovarianceMatrix)) { +// throw new IllegalArgumentException("Expecting corrariance matrices."); +// } +// +// TetradMatrix corr1 = wrapper1.getSemIm().getImplCovarMeas(); +// TetradMatrix corr2 = ((ICovarianceMatrix) model2).getMatrix(); +// +// TetradMatrix corr3 = calcSum(corr1, corr2); +// +// ICovarianceMatrix corrWrapper = new CovarianceMatrix(model2.getVariables(), corr3, +// ((ICovarianceMatrix) model2).getSampleSize()); +// +// setDataModel(corrWrapper); +// setSourceGraph(wrapper2.getSourceGraph()); +// LogDataUtils.logDataModelList("Difference of matrices.", getDataModelList()); +// } catch (Exception e) { +// e.printStackTrace(); +// throw new RuntimeException(e); +// } +// +// } private TetradMatrix calcSum(TetradMatrix corr1, TetradMatrix corr2) { if (corr1.rows() != corr2.rows()) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixWrapper.java index 7751db2e48..d99b94d42c 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/CovMatrixWrapper.java @@ -95,22 +95,22 @@ public CovMatrixWrapper(DataWrapper wrapper, Parameters params) { } - public CovMatrixWrapper(SemImWrapper wrapper) { - if (wrapper == null) { - throw new NullPointerException("The Sem IM must not be null."); - } - - SemIm semIm = wrapper.getSemIm(); - - TetradMatrix matrix = semIm.getImplCovar(true); - List variables = semIm.getSemPm().getVariableNodes(); - - ICovarianceMatrix covarianceMatrix = new CovarianceMatrix(variables, matrix, semIm.getSampleSize()); - setDataModel(covarianceMatrix); - setSourceGraph(semIm.getSemPm().getGraph()); - - LogDataUtils.logDataModelList("Conversion of data to covariance matrix form.", getDataModelList()); - } +// public CovMatrixWrapper(SemImWrapper wrapper) { +// if (wrapper == null) { +// throw new NullPointerException("The Sem IM must not be null."); +// } +// +// SemIm semIm = wrapper.getSemIm(); +// +// TetradMatrix matrix = semIm.getImplCovar(true); +// List variables = semIm.getSemPm().getVariableNodes(); +// +// ICovarianceMatrix covarianceMatrix = new CovarianceMatrix(variables, matrix, semIm.getSampleSize()); +// setDataModel(covarianceMatrix); +// setSourceGraph(semIm.getSemPm().getGraph()); +// +// LogDataUtils.logDataModelList("Conversion of data to covariance matrix form.", getDataModelList()); +// } /** diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/InverseMatrixWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/InverseMatrixWrapper.java index 031f24c074..86ceb1371f 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/InverseMatrixWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/InverseMatrixWrapper.java @@ -75,22 +75,22 @@ public InverseMatrixWrapper(DataWrapper wrapper, Parameters params) { } - public InverseMatrixWrapper(SemImWrapper wrapper, Parameters params) { - if (wrapper == null) { - throw new NullPointerException("The Sem IM must not be null."); - } - - SemIm semIm = wrapper.getSemIm(); - - TetradMatrix matrix = semIm.getImplCovar(true); - List variables = semIm.getSemPm().getVariableNodes(); - - ICovarianceMatrix covarianceMatrix = new CovarianceMatrix(variables, matrix, semIm.getSampleSize()); - setDataModel(covarianceMatrix); - setSourceGraph(semIm.getSemPm().getGraph()); - - LogDataUtils.logDataModelList("Conversion of data to covariance matrix form.", getDataModelList()); - } +// public InverseMatrixWrapper(SemImWrapper wrapper, Parameters params) { +// if (wrapper == null) { +// throw new NullPointerException("The Sem IM must not be null."); +// } +// +// SemIm semIm = wrapper.getSemIm(); +// +// TetradMatrix matrix = semIm.getImplCovar(true); +// List variables = semIm.getSemPm().getVariableNodes(); +// +// ICovarianceMatrix covarianceMatrix = new CovarianceMatrix(variables, matrix, semIm.getSampleSize()); +// setDataModel(covarianceMatrix); +// setSourceGraph(semIm.getSemPm().getGraph()); +// +// LogDataUtils.logDataModelList("Conversion of data to covariance matrix form.", getDataModelList()); +// } /** diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/TimeSeriesWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/TimeSeriesWrapper.java index dacf6101b5..106c7ed754 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/TimeSeriesWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/datamanip/TimeSeriesWrapper.java @@ -21,10 +21,8 @@ package edu.cmu.tetradapp.model.datamanip; -import edu.cmu.tetrad.data.DataModel; -import edu.cmu.tetrad.data.DataModelList; -import edu.cmu.tetrad.data.DataSet; -import edu.cmu.tetrad.data.LogDataUtils; +import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; +import edu.cmu.tetrad.data.*; import edu.cmu.tetrad.search.TimeSeriesUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.TetradSerializableUtils; @@ -37,6 +35,8 @@ public class TimeSeriesWrapper extends DataWrapper { static final long serialVersionUID = 23L; + private IKnowledge knowledge = new Knowledge2(); + /** * Constructs a new time series dataset. * @@ -57,6 +57,7 @@ public TimeSeriesWrapper(DataWrapper data, Parameters params) { if (dataSet.getName() != null) { timeSeries.setName(dataSet.getName()); } + knowledge = timeSeries.getKnowledge(); timeSeriesDataSets.add(timeSeries); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/GraphUtils.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/GraphUtils.java index bc30818a87..e818a1e8b8 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/GraphUtils.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/GraphUtils.java @@ -1,16 +1,11 @@ package edu.cmu.tetradapp.util; import edu.cmu.tetrad.data.DataGraphUtils; -import edu.cmu.tetrad.graph.EdgeListGraph; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphNode; -import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.PointXy; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; +import java.util.*; /** * Created by jdramsey on 12/8/15. @@ -183,4 +178,67 @@ private static Graph makeRandomScaleFree(int numNodes, int numLatents, double al alpha, beta, deltaIn, deltaOut); return graph; } + + // Returns true if a path consisting of undirected and directed edges toward 'to' exists of + // length at most 'bound' except for an edge from->to itself. Cycle checker in other words. + public static boolean existsSemiDirectedPathExcept(Node from, Node to, int bound, Graph graph) { + Queue Q = new LinkedList<>(); + Set V = new HashSet<>(); + Q.offer(from); + V.add(from); + Node e = null; + int distance = 0; + + while (!Q.isEmpty()) { + Node t = Q.remove(); +// if (t == to) { +// return true; +// } + + if (e == t) { + e = null; + distance++; + if (distance > (bound == -1 ? 1000 : bound)) return false; + } + + for (Node u : graph.getAdjacentNodes(t)) { + Edge edge = graph.getEdge(t, u); + Node c = edu.cmu.tetrad.graph.GraphUtils.traverseSemiDirected(t, edge); + if (c == null) continue; + + if (t == from && c == to) { + continue; + } + + if (c == to) { + return true; + } + + if (!V.contains(c)) { + V.add(c); + Q.offer(c); + + if (e == null) { + e = u; + } + } + } + } + + return false; + } + + // Used to find semidirected paths for cycle checking. + public static Node traverseSemiDirected(Node node, Edge edge) { + if (node == edge.getNode1()) { + if (edge.getEndpoint1() == Endpoint.TAIL || edge.getEndpoint1() == Endpoint.CIRCLE) { + return edge.getNode2(); + } + } else if (node == edge.getNode2()) { + if (edge.getEndpoint2() == Endpoint.TAIL || edge.getEndpoint2() == Endpoint.CIRCLE) { + return edge.getNode1(); + } + } + return null; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/SplashScreen.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/SplashScreen.java index 84b6e48b9d..cc322329d1 100755 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/SplashScreen.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/SplashScreen.java @@ -38,12 +38,13 @@ public class SplashScreen { private static int MAX; private static int COUNTER; private static SplashWindow WINDOW; + private boolean skipLatest; - public static void show(Frame parent, String title, int max) { + public static void show(Frame parent, String title, int max, boolean skipLatest) { hide(); SplashScreen.COUNTER = 0; SplashScreen.MAX = max; - WINDOW = new SplashWindow(parent, null, title); + WINDOW = new SplashWindow(parent, null, title, skipLatest); } public static void hide() { @@ -83,7 +84,7 @@ private static class SplashWindow extends Window { final Image splashIm; final JProgressBar bar; - SplashWindow(Frame parent, Image image, String title) { + SplashWindow(Frame parent, Image image, String title, boolean skipLatest) { super(parent); this.splashIm = image; //setSize(200, 100); @@ -108,16 +109,17 @@ private static class SplashWindow extends Window { String text = LicenseUtils.copyright(); - // check if we are running latest version - LatestClient latestClient = LatestClient.getInstance(); + // optionally check if we are running latest version String version = this.getClass().getPackage().getImplementationVersion(); - - // if no version it means we are not running a jar so probably development - if (version == null) version = "DEVELOPMENT"; - latestClient.checkLatest("tetrad", version); - StringBuilder latestResult = new StringBuilder(latestClient.getLatestResult(60)); - text = text + "\n" + latestResult.toString(); - + if (! skipLatest) { + LatestClient latestClient = LatestClient.getInstance(); + + // if no version it means we are not running a jar so probably development + if (version == null) version = "DEVELOPMENT"; + latestClient.checkLatest("tetrad", version); + StringBuilder latestResult = new StringBuilder(latestClient.getLatestResult(60)); + text = text + "\n" + latestResult.toString(); + } JTextArea textArea = new JTextArea(text); textArea.setBorder(new EmptyBorder(5, 5, 5, 5)); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/AbstractWorkbench.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/AbstractWorkbench.java index a6c44515fa..fe6cdb5be8 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/AbstractWorkbench.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/AbstractWorkbench.java @@ -23,9 +23,10 @@ import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.util.JOptionUtils; -import edu.cmu.tetradapp.util.ImageUtils; -import edu.cmu.tetradapp.util.LayoutEditable; +import edu.cmu.tetradapp.model.SessionWrapper; +import edu.cmu.tetradapp.util.*; import javax.swing.*; import java.awt.*; @@ -402,9 +403,8 @@ public final List getSelectedComponents() { } /** - * @return the model edge for the given display edge. - * * @param displayEdge Ibid. + * @return the model edge for the given display edge. */ public final Edge getModelEdge(IDisplayEdge displayEdge) { return (Edge) getDisplayToModel().get(displayEdge); @@ -977,7 +977,16 @@ private void setGraphWithoutNotify(Graph graph) { throw new IllegalArgumentException("Graph model cannot be null."); } - this.graph = graph; + if (graph instanceof SessionWrapper) { + this.graph = graph; + } else { + this.graph = graph; + + if (graph.isPag()) { + GraphUtils.addPagColoring(new EdgeListGraph(graph)); + } + } + this.modelEdgesToDisplay = new HashMap<>(); this.modelNodesToDisplay = new HashMap<>(); this.displayToModel = new HashMap(); @@ -1200,7 +1209,8 @@ private void adjustForNewModelNodes() { * * @param modelEdge the mode edge. */ - private void addEdge(Edge modelEdge) { + private void + addEdge(Edge modelEdge) { if (modelEdge == null) { return; } @@ -1232,12 +1242,14 @@ private void addEdge(Edge modelEdge) { } IDisplayEdge displayEdge = getNewDisplayEdge(modelEdge); - if (graph.isHighlighted(modelEdge)) displayEdge.setHighlighted(true); - if (displayEdge == null) { return; } + if (graph.isHighlighted(modelEdge)) displayEdge.setHighlighted(true); + displayEdge.setLineColor(modelEdge.getLineColor()); + displayEdge.setDashed(modelEdge.isDashed()); + // Link the display edge to the model edge. getModelEdgesToDisplay().put(modelEdge, displayEdge); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/DisplayEdge.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/DisplayEdge.java index 84f19b1e4a..e52b848ff0 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/DisplayEdge.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/DisplayEdge.java @@ -164,7 +164,6 @@ public class DisplayEdge extends JComponent implements IDisplayEdge { // private Color lineColor = Color.black; - // public Color lineColor = new Color(99, 101, 188); // public Color lineColor = new Color(0, 4, 255); // public Color lineColor = new Color(52, 55, 217); @@ -203,6 +202,7 @@ public class DisplayEdge extends JComponent implements IDisplayEdge { */ private final PropertyChangeHandler propertyChangeHandler = new PropertyChangeHandler(); + private boolean dashed = false; //==========================CONSTRUCTORS============================// @@ -215,6 +215,10 @@ public class DisplayEdge extends JComponent implements IDisplayEdge { * @param type the type of the edge, either UNRANDOMIZED or RANDOMIZED. */ protected DisplayEdge(DisplayNode node1, DisplayNode node2, int type) { + this(node1, node2, type, null); + } + + protected DisplayEdge(DisplayNode node1, DisplayNode node2, int type, Color color) { if (node1 == null) { throw new NullPointerException("Node1 must not be null."); } @@ -232,6 +236,11 @@ protected DisplayEdge(DisplayNode node1, DisplayNode node2, int type) { this.node1 = node1; this.node2 = node2; this.type = type; + + if (color != null) { + this.lineColor = color; + } + this.mode = ANCHORED_UNSELECTED; node1.addComponentListener(compHandler); @@ -251,6 +260,10 @@ protected DisplayEdge(DisplayNode node1, DisplayNode node2, int type) { * @param node2 the 'to' component. */ public DisplayEdge(Edge modelEdge, DisplayNode node1, DisplayNode node2) { + this(modelEdge, node1, node2, null); + } + + public DisplayEdge(Edge modelEdge, DisplayNode node1, DisplayNode node2, Color color) { if (modelEdge == null) { throw new NullPointerException("Model edge must not be null."); @@ -267,6 +280,10 @@ public DisplayEdge(Edge modelEdge, DisplayNode node1, DisplayNode node2) { this.modelEdge = modelEdge; this.node1 = node1; this.node2 = node2; + + if (color != null) { + this.lineColor = color; + } this.mode = ANCHORED_UNSELECTED; node1.addComponentListener(compHandler); @@ -292,6 +309,10 @@ public DisplayEdge(Edge modelEdge, DisplayNode node1, DisplayNode node2) { * @see #updateTrackPoint */ public DisplayEdge(DisplayNode node1, Point mouseTrackPoint, int type) { + this(node1, mouseTrackPoint, type, null); + } + + public DisplayEdge(DisplayNode node1, Point mouseTrackPoint, int type, Color color) { if (node1 == null) { throw new NullPointerException("Node1 must not be null."); @@ -311,6 +332,11 @@ public DisplayEdge(DisplayNode node1, Point mouseTrackPoint, int type) { this.node1 = node1; this.mouseTrackPoint = mouseTrackPoint; this.type = type; + + if (color != null) { + this.lineColor = color; + } + this.mode = HALF_ANCHORED; resetBounds(); @@ -407,7 +433,24 @@ private void drawEdge(Graphics g) { // width <= 1.0 seems to cause the problem, so we pick a stroke // width slightly greater than 1.0. jdramsey 4/16/2005 // g2d.setStroke(new BasicStroke(1.000001f)); - g2d.setStroke(new BasicStroke(getStrokeWidth() + 0.000001f)); + BasicStroke s; + + if (dashed) { + float dash1[] = {10.0f}; + s = new BasicStroke(1.0f, + BasicStroke.CAP_BUTT, + BasicStroke.JOIN_MITER, + 10.0f, dash1, 0.0f); + } else { + s = new BasicStroke(getStrokeWidth() + 0.000001f); + } + + g2d.setStroke(s); + + if (!isSelected()) { + g2d.setColor(lineColor); + } + g2d.drawLine(x1, y1, x2, y2); if (!isShowAdjacenciesOnly()) { @@ -626,7 +669,7 @@ protected final void setClickRegion(Polygon clickRegion) { * rectangles but doesn't always...) * * @return a point pair which represents the connecting line segment through - * the center of each rectangle touching the edge of each. + * the center of each rectangle touching the edge of each. */ protected final PointPair calculateEdge(DisplayNode comp1, DisplayNode comp2) { Rectangle r1 = comp1.getBounds(); @@ -842,7 +885,7 @@ private Point getBoundaryIntersection(DisplayNode comp, Point pIn, * * @param pp the point pair representing the line segment of the edge. * @return the Polygon representing the sleeve, or null if no such Polygon - * exists (because, e.g., one of the endpoints is null). + * exists (because, e.g., one of the endpoints is null). */ private Polygon getSleeve(PointPair pp) { if ((pp == null) || (pp.getFrom() == null) || (pp.getTo() == null)) { @@ -990,7 +1033,19 @@ public Color getLineColor() { } public void setLineColor(Color lineColor) { - this.lineColor = lineColor; + if (lineColor != null) { + this.lineColor = lineColor; + } + } + + @Override + public boolean getDashed() { + return dashed; + } + + @Override + public void setDashed(boolean dashed) { + this.dashed = dashed; } public Color getSelectedColor() { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/GraphWorkbench.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/GraphWorkbench.java index 821bac57fa..a76f5e093d 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/GraphWorkbench.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/GraphWorkbench.java @@ -218,19 +218,21 @@ public Edge getNewModelEdge(Node node1, Node node2) { * @return the new tracking edge (a display edge). */ public IDisplayEdge getNewTrackingEdge(DisplayNode node, Point mouseLoc) { + Color color = null; + switch (edgeMode) { case DIRECTED_EDGE: - return new DisplayEdge(node, mouseLoc, DisplayEdge.DIRECTED); + return new DisplayEdge(node, mouseLoc, DisplayEdge.DIRECTED, color); case NONDIRECTED_EDGE: - return new DisplayEdge(node, mouseLoc, DisplayEdge.NONDIRECTED); + return new DisplayEdge(node, mouseLoc, DisplayEdge.NONDIRECTED, color); case PARTIALLY_ORIENTED_EDGE: return new DisplayEdge(node, mouseLoc, - DisplayEdge.PARTIALLY_ORIENTED); + DisplayEdge.PARTIALLY_ORIENTED, color); case BIDIRECTED_EDGE: - return new DisplayEdge(node, mouseLoc, DisplayEdge.BIDIRECTED); + return new DisplayEdge(node, mouseLoc, DisplayEdge.BIDIRECTED, color); default : throw new IllegalStateException(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/IDisplayEdge.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/IDisplayEdge.java index 289f5fdf2d..c9e57c0d43 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/IDisplayEdge.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/IDisplayEdge.java @@ -67,6 +67,10 @@ public interface IDisplayEdge { void setLineColor(Color lineColor); + boolean getDashed(); + + void setDashed(boolean dashed); + Color getSelectedColor(); void setSelectedColor(Color selectedColor); diff --git a/tetrad-gui/src/main/resources/resources/configplay.xml b/tetrad-gui/src/main/resources/resources/configplay.xml index f4796fd40a..e0e6f8d2b1 100644 --- a/tetrad-gui/src/main/resources/resources/configplay.xml +++ b/tetrad-gui/src/main/resources/resources/configplay.xml @@ -343,12 +343,8 @@ ]]> - + - - edu.cmu.tetradapp.model.DataWrapper - edu.cmu.tetradapp.editor.DataEditor - @@ -358,6 +354,36 @@ edu.cmu.tetradapp.model.Simulation edu.cmu.tetradapp.editor.SimulationEditor + + + + edu.cmu.tetradapp.app.CategorizingModelChooser + + + Sorry, this box certain inputs; please see the manual. Possible inputs are: +
(1) Data, OR +
(2) Data Manipulation Node +
Note that for the parent boxes, models need to have been created. + + ]]> +
+
+ + + + edu.cmu.tetradapp.model.DataWrapper + edu.cmu.tetradapp.editor.DataEditor + + + + + + + + + + @@ -1254,10 +1280,6 @@ - - - - @@ -1414,7 +1436,7 @@ - + @@ -1426,11 +1448,11 @@ - - + + - + @@ -1442,11 +1464,11 @@ - - + + - + @@ -1457,11 +1479,11 @@ - - + + - + @@ -1474,10 +1496,10 @@ - + - + @@ -1490,7 +1512,7 @@ - + @@ -1610,7 +1632,7 @@ - + @@ -1622,9 +1644,9 @@ - + - + diff --git a/tetrad-gui/src/main/resources/resources/configpost.xml b/tetrad-gui/src/main/resources/resources/configpost.xml index 21b429a599..4d6562bfb4 100644 --- a/tetrad-gui/src/main/resources/resources/configpost.xml +++ b/tetrad-gui/src/main/resources/resources/configpost.xml @@ -199,10 +199,11 @@ - The purpose of this box is to store a graph and allow you possibly to edit it. It can either take -
no parents or else a box with a graph already in it. Or if not that, a box in which variables are -
at least defined, so that a empty graph can be formed over those varialbes. + Sorry, this box certain inputs; please see the manual. Possible inputs are: +
(1) No inputs +
(2) Data (to extract variables) +
(3) Also (to extract a graph) any of these: Search, PM, IM, Estimator, Updator +
Note that for the parent boxes, models need to have been created. ]]>
@@ -246,9 +247,11 @@ - This box requires a graph as a parent; it's going to try to create a parameterized model -
using the graph structure. + Sorry, this box certain inputs; please see the manual. Possible inputs are: +
(1) Graph, OR +
(2) Graph AND Data (to extract discrete categories) +
(3) Also (to extract a PM) any of these: IM, Estimator +
Note that for the parent boxes, models need to have been created. ]]>
@@ -331,19 +334,17 @@ - An IM box requires a PM box as parent. It's going to try to pick values for each of the parameters -
in the PM. + Sorry, this box certain inputs; please see the manual. Possible inputs are: +
(1) PM (Bayes or SEM), OR +
(2) IM (to convert Dirichlet Bayes IM to Bayes IM), OR +
(3) Estimator (to extract IM) +
Note that for the parent boxes, models need to have been created. ]]>
- - edu.cmu.tetradapp.model.DataWrapper - edu.cmu.tetradapp.editor.DataEditor - @@ -353,6 +354,10 @@ edu.cmu.tetradapp.model.Simulation edu.cmu.tetradapp.editor.SimulationEditor + + edu.cmu.tetradapp.model.DataWrapper + edu.cmu.tetradapp.editor.DataEditor + @@ -403,12 +408,6 @@ edu.cmu.tetradapp.editor.CalculatorEditor - - - - - - @@ -852,12 +851,12 @@ - This box stores a dataset, either loaded from file or simulated, else transformed from another -
dataset. As such, it either requires no parents or else another dataset as parent. Sometimes -
it can take multiple datasets as parents. + Sorry, this box certain inputs; please see the manual. Possible inputs are: +
(1) Data, OR +
(2) Data Manipulation Node +
Note that for the parent boxes, models need to have been created. - ]]> + ]]>
@@ -919,22 +918,17 @@ edu.cmu.tetradapp.editor.EmBayesEstimatorEditor edu.cmu.tetradapp.editor.EMBayesEstimatorParamsEditor - - edu.cmu.tetradapp.model.PatternFitModel - - edu.cmu.tetradapp.editor.PatternFitEditor - - The purpose of this box is to estimate a model from data. So it needs a model in which the parameters -
of the model are specified (a PM box) and a dataset (a data box). + Sorry, this box certain inputs; please see the manual. Possible inputs are: +
(1) PM and Data, OR +
(2) IM and Data (shortcut for IM->PM + Data) +
Note that for the parent boxes, models need to have been created. - ]]> + ]]>
@@ -994,10 +988,10 @@ - The purpose of this box is to do updating (learning) from a model in which the parameter values have -
been specified. As such, it needs an instantiated model (IM) as input. - + Sorry, this box certain inputs; please see the manual. Possible inputs are: +
(1) IM (Bayes only) +
Note that for the parent boxes, models need to have been created. + ]]>
@@ -1254,10 +1248,6 @@ - - - - @@ -1414,7 +1404,7 @@ - + @@ -1426,11 +1416,11 @@ - - + + - + @@ -1442,11 +1432,11 @@ - - + + - + @@ -1457,11 +1447,11 @@ - - + + - + @@ -1474,10 +1464,10 @@ - + - + @@ -1490,7 +1480,7 @@ - + @@ -1504,18 +1494,18 @@ - - - - - - - edu.cmu.tetradapp.model.MimBuildRunner - - edu.cmu.tetradapp.editor.MimbuildEditor - - + + + + + + + + + + + + @@ -1579,17 +1569,17 @@ - - - - - - - edu.cmu.tetradapp.model.PurifyRunner - edu.cmu.tetradapp.editor.MimSearchEditor2 - - + + + + + + + + + + + @@ -1610,7 +1600,7 @@ - + @@ -1622,9 +1612,9 @@ - + - + @@ -1685,11 +1675,11 @@ - The purpose of this box is to do a search over data, yielding a graph that estimates the causal -
structure over the variables in the data. As such, it needs to have a dataset as input, or -
at least a graph, for algorithms that can search using d-separation information alone. - Sorry, this box certain inputs; please see the manual. Possible inputs are: +
(1) Data or Simulation, OR +
(2) Graph (searches directly over graph) OR +
Note that for the parent boxes, models need to have been created. + ]]>
@@ -1722,8 +1712,8 @@ - The purpose of this box is to a regression on data. As such, it requires a dataset as input. + Allowable Input Nodes: +
(1) Data (Continuous) ]]>
@@ -1744,14 +1734,13 @@ - The purpose of this box is to perform a classification using a model with parameter values defined -
(an IM) from data. As such, it needs as parents an IM box and a data box. + Allowable Input Combinations: +
(1) IM (Bayes only) + Data (discrete only) ]]>
- + @@ -1827,6 +1816,12 @@ edu.cmu.tetradapp.editor.EdgeWeightComparisonEditor + + edu.cmu.tetradapp.model.PatternFitModel + + edu.cmu.tetradapp.editor.PatternFitEditor + @@ -1834,11 +1829,11 @@ - This box is trying to compare things and so needs things to compare. Only certain combinations of -
things can be compared. You can compare a simulation to a graph output (in certain case). You -
can compare two graphs, or sometime several graphs. You can compare independence tests from -
multiple sources. See the manual for more information. + Sorry, this box certain inputs; please see the manual. Possible inputs are: +
(1) Simulation and Search +
(2) Several boxes containing graphs. +
(3) Several boxes containing independence tests. +
Note that for the parent boxes, models need to have been created. ]]>
@@ -1905,11 +1900,9 @@ - The purpose of this box is to store background knowledge for use in a a search that can use such -
background knowledge, such as PC or FGS, or many other searchs. It needs as input a model that -
contains variable names, because it needs to know what the variable names are over which the -
knowledge is defined. + Sorry, this box certain inputs; please see the manual. Possible inputs are: +
(1) Any box that contains a set of variables. +
Note that for the parent boxes, models need to have been created. ]]>
diff --git a/tetrad-gui/src/main/resources/resources/javahelp/Map.jhm b/tetrad-gui/src/main/resources/resources/javahelp/Map.jhm index c65aaf150a..dd9b542fce 100644 --- a/tetrad-gui/src/main/resources/resources/javahelp/Map.jhm +++ b/tetrad-gui/src/main/resources/resources/javahelp/Map.jhm @@ -89,9 +89,9 @@ - - - + + + @@ -120,7 +120,7 @@ - + diff --git a/tetrad-gui/src/main/resources/resources/javahelp/TetradHelpTOC.xml b/tetrad-gui/src/main/resources/resources/javahelp/TetradHelpTOC.xml index 1535203e3f..c7261e73ea 100644 --- a/tetrad-gui/src/main/resources/resources/javahelp/TetradHelpTOC.xml +++ b/tetrad-gui/src/main/resources/resources/javahelp/TetradHelpTOC.xml @@ -54,9 +54,9 @@ - - - + + + @@ -84,7 +84,7 @@ - + diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/bdeu.html b/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/bdeu.html index dec82b3303..cc53f666ae 100644 --- a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/bdeu.html +++ b/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/bdeu.html @@ -18,7 +18,7 @@

Score: BDeu

Calculates the BDeu score for discrete variables. The formula is given in Chickering (2002), although a different structure prior is used, as given in Ramsey et al. This score is suitable for use in GES-style algorithms, - including FGS and FGS-MB. + including FGES and FGES-MB.

diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/fgs.html b/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/fgs.html deleted file mode 100644 index e308b81543..0000000000 --- a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/fgs.html +++ /dev/null @@ -1,144 +0,0 @@ - - - - Search Algorithms: GES - - - - - - - - - -
-

Search Algorithms: FGS

-
-


- The FGS ("Fast Greedy Search") algorithm is an optimization (in time) of the GES ("Greedy Equivalence Search") - due to Chickering and Meek (give references). This methods of optimization are described in this paper - (give reference), as well as the performance characteristics. We describe the GES algorithm on which it i - based. -

-

GES (Greedy Equivalence Search) is a Bayesian algorithm that searches over Markov equivalence classes, represented - by patterns, for a data set D over a set of variables V.

-

A pattern is an acyclic graph that consists whose edges are either directed (-->) or undirected (---) and - represents an equivalence class of DAGs, as follows: each directed edge in the pattern is so directed in every DAG - in the equivalence class, and for each undirected edge X---Y in the pattern, a DAG exists in the equivalence class - with that edge directed as X<--Y and a DAG exists in the equivalence class with that edge directed as X-->Y. - To put it differently, a pattern represent the set of edges that can be determined by the search, with as many of - these edges oriented as can be, using the available information.

-

It is assumed (as with PC) that the true causal graph is acyclic and the no common hidden causes exist between pairs - of variables in the graph. GES can be run on datasets that are either entirely continuous or entirely discrete (but - not directly on graphs using d-separation). In the continuous case, it is assumed that the direct causal influence - of any variable into any other is linear, with the distribution of each variable being Normal. Under these - assumptions, the algorithm is pointwise consistent.

-

GES searches over patterns by scoring the patterns themselves. There is a forward sweep in the algorithm and - a backward sweep. In the forward sweep, at each step, GES tries to find the edge which, once added, increases the - score the most over not adding any edge at all. (After adding each such edge, the pattern is rebuilt by orienting - any edge as --- that does not participate in a collider and then applying Meek's PC orientation rules to add any - implied orientations.) Once the algorithm gets to the point where there is no edge that can be added that would - increase the score, the backward sweep begins. In the backward sweep, GES tries at each step to find the one edge it - can remove that will increase the score of the resulting the most over the previous pattern. Once it gets to the - point where there is no edge anymore than once removed increases the score, the algorithm stops.

-

There are some differences in assumptions and - expected behavior between this algorithm and the PC algorithm. When, contrary to assumptions, there is actually a - latent common cause of two measured - variables the PC algorithm will sometimes discover that fact; GES will - not.

-

Information about how precisely GES makes decisions about adding or removing edges can be found - in the logs, which can be accessed using the Logging menu.

-

 

-

Entering GES parameters

-

Consider the following example:
-
-

-
-


-
-

-
-

When the PC algorithm is chosen from the Search Object combo box, - the following window appears:
-
-

-
-


-
-

-
-

The parameters that are used by the GES algorithm can be specified - in this window. The parameters are as follows:

-
    -
  • view background knowledge: this button gives - access to a background knowledge editor - that is analogous to the one used in most search algorithms. -
  • -
-

Execute the search.

-


-
Interpreting the output

-

The GES algorithm returns a partially oriented - graph where the nodes represent the variables given as input. In our - example, the outcome should be as follows if the sample is - representative of the population:
-
-

-
-


-
-

-
-

The are basically two types of edges that can appear in GES output:

-
    -
  • a directed edge: -

    -

    In this case, the GES algorithm deduced - that A is a direct cause of B, i.e., the causal effect goes from A to B - and it is not intermediated by any of the other observed variable

    -
  • -
  • a undirected edge: -

    -

    In this case, the GES algorithm cannot - tell if A causes B or if B causes A.

    -
  • -
-

The absence of an edge between any pair of - nodes means they are independent, or that the causal effect of one modelNode - in the other is intermediate by other observed variables. Unlike the PC - algorithm, no accidental double-directed edges can appear. It does not - mean that GES will be immune to the sample variation that caused the - unexpected behavior of the PC search. It is a good idea to run both - searches and compare the result.

-

Finally, a triplet of nodes may assume the following pattern:

-
-

-
-

In other words, in such patterns, A and B are connected by an - undirected edge, A and C are connected by an undirected edge, and B and - C are not connected by an edge. By the PC search assumptions, this - means that B and C cannot both be cause of A. The three possible - scenarios are:

-
    -
  • A is a common cause of B and C
  • -
  • B is a direct cause of A, and A is a direct cause of C
  • -
  • C is a direct cause of A, and A is a direct cause of B
  • -
-

In our example, some edges were compelled to be directed: X2 and X3 - are causes of X4, and X4 is a cause of X5. However, we cannot tell much - about the triplet (X1, X2, X3), but we know that X2 and X3 cannot both - be causes of X1.

-

References:

-

Chickering (2002). Optimal structure identification with greedy search. Journal of Machine Learning - Research.

-

 

- - diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/fgsmb.html b/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/fgsmb.html deleted file mode 100644 index ae8019a61b..0000000000 --- a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/fgsmb.html +++ /dev/null @@ -1,29 +0,0 @@ - - - - Search Algorithms: FGS - - - - - - - - - -
-

Search Algorithms: FGS-MB

-
-


-

The FGS-MB ("Fast Greedy Search, Markov Blanket) algorithm is a restriction of the - FGS algorithm (see) to the problem of calculating the joint set of variables in the - Markov blanket of a list of target variables, as well as the structure over those - variables. It is guaranteed to produce the same result as if the FGS algorithm had - been run and the result then restricted to the union of the Markov blankets of the - target variables. (Give reference.) -

- -

 

- - diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/gccd.html b/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/gccd.html index beef699609..ccbf908cce 100644 --- a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/gccd.html +++ b/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/gccd.html @@ -15,10 +15,9 @@

Search Algorithms: GCCD

-

The GCCD algorithm is a modification of the CCD algorithm (due to - Thomas Richardson, see). It uses FGS for the initial adjacency search, - in the style of GFCI (see). In simulation, it is more accurate at - recovering cyclic structures than CCD, so we include it. +

The CCD-Max algorithm is a modification of the CCD algorithm (due to + Thomas Richardson, see). It uses the PC-Max strategy to modify the CCD + algorithm.

The algorithm is pointwise consistent for linear systems with Normal distributions, no latent variables and no correlated twoCycleErrors. It is diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/gfci.html b/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/gfci.html index f8c1510f06..b38d2c7ace 100644 --- a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/gfci.html +++ b/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/gfci.html @@ -17,7 +17,7 @@

Search Algorithms: GFCI

GFCI ("Greedy Fast Causal Inference", reference) modifies FCI (see) by replacing the - initial adjacency search with FGS followed by edge removal reasoning (using + initial adjacency search with FGES followed by edge removal reasoning (using conditional independence) inside triangles.

diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/wfgs.html b/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/wfges.html similarity index 86% rename from tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/wfgs.html rename to tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/wfges.html index 1b2cb9ff4e..2bd2c865ea 100644 --- a/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/wfgs.html +++ b/tetrad-gui/src/main/resources/resources/javahelp/manual/boxes/search/wfges.html @@ -15,8 +15,8 @@

Search Algorithms: WGFCI

-

WFGS ("Whimsical Fast Greedy Search") is an experimental algorithm for the mixed variable case. - Each discrete variable is rewritten as a set of indicator variables, and then FGS (see) is run +

WFGES ("Whimsical Fast Greedy Equivalence Search") is an experimental algorithm for the mixed variable case. + Each discrete variable is rewritten as a set of indicator variables, and then FGES (see) is run over the indicators and continuous variables. A causal connection is inferred between X and Y if X and Y are connected and continuous, or if there is a connection between one of the indicators of the X or of Y to an indicator of the other, or to a continuous variable, if X or Y is diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/images/calvin_hobbes_instructions.gif b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/calvin_hobbes_instructions.gif new file mode 100644 index 0000000000..8020b0e4f7 Binary files /dev/null and b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/calvin_hobbes_instructions.gif differ diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/images/choose_data.png b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/choose_data.png new file mode 100644 index 0000000000..74da872df4 Binary files /dev/null and b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/choose_data.png differ diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/images/choose_load_data.png b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/choose_load_data.png new file mode 100644 index 0000000000..ece1129e60 Binary files /dev/null and b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/choose_load_data.png differ diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/images/click_done_data.png b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/click_done_data.png new file mode 100644 index 0000000000..fe2ce33f4a Binary files /dev/null and b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/click_done_data.png differ diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/images/click_load.png b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/click_load.png new file mode 100644 index 0000000000..42b87ab845 Binary files /dev/null and b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/click_load.png differ diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/images/place_data_box.png b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/place_data_box.png new file mode 100644 index 0000000000..8c9c97c7b8 Binary files /dev/null and b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/place_data_box.png differ diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/images/save_data.png b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/save_data.png new file mode 100644 index 0000000000..1295ddd417 Binary files /dev/null and b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/save_data.png differ diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/images/search_box.png b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/search_box.png new file mode 100644 index 0000000000..debd7fe513 Binary files /dev/null and b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/search_box.png differ diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/images/search_done.png b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/search_done.png new file mode 100644 index 0000000000..001aeb1caf Binary files /dev/null and b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/search_done.png differ diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/images/search_params.png b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/search_params.png new file mode 100644 index 0000000000..08bccc3bf7 Binary files /dev/null and b/tetrad-gui/src/main/resources/resources/javahelp/manual/images/search_params.png differ diff --git a/tetrad-gui/src/main/resources/resources/javahelp/manual/tetrad_tutorial.html b/tetrad-gui/src/main/resources/resources/javahelp/manual/tetrad_tutorial.html index d3f2ad36e8..1a05d3f3fb 100644 --- a/tetrad-gui/src/main/resources/resources/javahelp/manual/tetrad_tutorial.html +++ b/tetrad-gui/src/main/resources/resources/javahelp/manual/tetrad_tutorial.html @@ -3,309 +3,309 @@ - - Tetrad_Tutorial - - - + } +} +@media print { + table, pre { + page-break-inside: avoid; + } + pre { + word-wrap: break-word; + } +} + @@ -314,84 +314,115 @@

Tetrad Tutorial

-

calvin hobbes instructions

+

calvin hobbes instructions

+ +

Calvin and Hobbes, Bill Watterson, April 19, 1988, (source).1

+ +

Table of Contents

Tetrad includes a huge variety of tools for causal inference. It has been under development since the early 90s. The algorithms in Tetrad were designed by many people, but the vast majority of the implementation was done by Joe Ramsey.

+
+

Things you can do with Tetrad

When people say 'causal inference', they mean lots of different things. Here are some things you might want to do with Tetrad:

    -
  • You have a dataset, and: - -
      -
    • You want to learn the causal graph that describes what causes what ("search")
    • -
    • You want to test whether a specific variable causes a target variable, and if so, what is the size of the effect
    • -
    • You want to find the set of variables that affects some target of interest ("feature selection")
    • -
    • You want to predict what will happen if you intervene on some variable
    • -
    • You want to find a set of experiments that are likely to produce large effects (on one or more targets)
    • -
  • -
  • You have a dataset and a known causal graph, and: - -
      -
    • You want to estimate the strength of a particular causal effect, or all of them
    • -
    • You want to evaluate how well the search algorithms recover your graph from the data
    • -
    • You want to evaluate how well your graph fits your data, and maybe find other structures that fit better
    • -
  • -
  • You have a search algorithm, and you want to evaluate how well it recovers causal graphs from synthetic data ("simulation")
  • +
  • You have a dataset, and: + +
      +
    • You want to learn the causal graph that describes what causes what ("search")
    • +
    • You want to test whether a specific variable causes a target variable, and if so, what is the size of the effect
    • +
    • You want to find the set of variables that affects some target of interest ("feature selection")
    • +
    • You want to predict what will happen if you intervene on some variable
    • +
    • You want to find a set of experiments that are likely to produce large effects (on one or more targets)
    • +
  • +
  • You have a dataset and a known causal graph, and: + +
      +
    • You want to estimate the strength of a particular causal effect, or all of them
    • +
    • You want to evaluate how well the search algorithms recover your graph from the data
    • +
    • You want to evaluate how well your graph fits your data, and maybe find other structures that fit better
    • +
  • +
  • You have a search algorithm, and you want to evaluate how well it recovers causal graphs from synthetic data ("simulation")

All of these tasks can be called 'causal inference'.

@@ -400,17 +431,17 @@

Things you can do with Tetrad

To understand what is possible with Tetrad, let's talk about what it contains.

-

What's under the hood

+
-

Comic ontology

+

What's under the hood

-

Tetrad is written in Java, an object-oriented programming language. Tetrad uses the following kinds of objects:1

+

Tetrad is written in Java, an object-oriented programming language. Tetrad uses the following kinds of objects:2

Variables = Nodes = Vertices

variables schema

-

Causal inference is a scientific discovery problem, so random variables are the basic objects. Variables are identified with "nodes" or "vertices" in causal graphs.2

+

Causal inference is a scientific discovery problem, so random variables are the basic objects. Variables are identified with "nodes" or "vertices" in causal graphs.3

In other graph software, you first create a graph, then populate it with nodes; if the graph disappears the nodes do too. By contrast, in Tetrad the nodes are basic objects. You can build multiple graphs over the same set of nodes. This represents the scientific problem: we start out knowing what the variables are, and we learn the causal relationships among them.

@@ -418,9 +449,15 @@

Variables = Nodes = Vertices

How they're made: You create new variables when you load your data into Tetrad, create a random graph, or create a new graph by hand (with no input).

-

Example: our set of variables might be {Sunscreen, Temperature, Ice-cream}.

+

Examples

-

Datasets

+
+

Schematic Example: our set of variables might be {Sunscreen, Temperature, Ice-cream}.

+
+ +
+ +

Datasets

dataset schema

@@ -428,48 +465,51 @@

Datasets

How they're made: You create a dataset when you load your data into Tetrad, or generate data from an instantiated model.

-

Example: our dataset might look something like this table of observations:

+

Examples

+ +
+

Schematic Example: our dataset might look something like this table of observations:

Variables: {Sunscreen, Temperature, Ice-cream}

Data:

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Person/DateSunscreenTemperatureIce-cream
Hemank, June 120ml32°C150g
Mahdi, June 1215ml32°C120g
Benedict, June 1430ml36°C200g
............
Person/DateSunscreenTemperatureIce-cream
Hemank, June 120ml32°C150g
Mahdi, June 1215ml32°C120g
Benedict, June 1430ml36°C200g
............

Or this correlation matrix:

@@ -479,38 +519,96 @@

Datasets

Data:

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
SunscreenTemperatureIce-cream
Sunscreen10.30.12
Temperature0.310.4
Ice-cream0.120.41
SunscreenTemperatureIce-cream
Sunscreen10.30.12
Temperature0.310.4
Ice-cream0.120.41
+
+ +

.

+ +
+

GUI example:

+ +

I'll use the 'Acute Inflammations' (AI) dataset from UC Irvine's repository of datasets for machine learning. I've done four things to make the data cleaner and easier to import into Tetrad:

+ +
    +
  1. The AI dataset uses commas for a decimal point. I replaced them with periods.
  2. +
  3. I removed invisible characters.
  4. +
  5. I put variable names at the top of the file.
  6. +
  7. I dichotomized the 'temperature' variable into a binary indicator I named 'fever', using 39°C as the cutoff (the temperature distribution is bimodal and the low point between the modes is about 39).
  8. +
+ +

You can download my modified file here.

+ +

To create a dataset object in Tetrad, do the following:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
StepScreenshots
1. Place a data box on the workspace; double click to open it.place data
2. Select "data" from the drop-down menu.choose data
3. Select File -> Load. Choose the AI dataset file.choose load data
4. Make sure the loading options are set like so, and click "Load".
5. Click "Save".
6. Click "Done".
+
-

Graphs

+
+ +

Graphs

graph schema

@@ -518,15 +616,21 @@

Graphs

How they're made: There are three ways to create graphs in Tetrad: by hand, using a random graph generator, or using a search algorithm.

-

Example: If our causal graph looks like this: Sunscreen Temperature Ice-cream, it would be represented in Tetrad like so:

+

Example

+ +
+

Schematic Example: If our causal graph looks like this: Sunscreen Temperature Ice-cream, it would be represented in Tetrad like so:

Variables: {Sunscreen, Temperature, Ice-cream}

Edges: - {(Sunscreen, Temperature, >, -), - (Temperature, Ice-cream, -, >)}

+{(Sunscreen, Temperature, >, -), +(Temperature, Ice-cream, -, >)}

+
-

Search algorithms

+
+ +

Search algorithms

search schema

@@ -535,47 +639,47 @@

Search algorithms

How many graphs are we looking through?

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Number of variablesNumber of Directed Acyclic Graphs
11
23
325
4543
529281
63781503
......
20more than the number of atoms in the observable universe
Number of variablesNumber of Directed Acyclic Graphs
11
23
325
4543
529281
63781503
......
20more than the number of atoms in the observable universe

This is why we need an algorithm to search, rather than inspecting all the graphs by hand. Search algorithms use various tricks to find the answer quickly, without inspecting every single graph.

@@ -583,32 +687,70 @@

Search algorithms

How they're made: A search algorithm is a function: it takes input and produces output. The inputs are:

    -
  • A dataset (required)3; note that this includes the variable set
  • -
  • Background knowledge about the causal relationships (optional)
  • -
  • Other settings, like tuning parameters, which depend on the specific algorithm
  • +
  • A dataset (required)4; note that this includes the variable set
  • +
  • Background knowledge about the causal relationships (optional)
  • +
  • Other settings, like tuning parameters, which depend on the specific algorithm

The output is a graph, or a set of graphs that are equally compatible with the data (a.k.a. an "equivalence class" of graphs). The type of graph you get depends on the type of algorithm you use.

-

Knowledge

+

Example

+ +
+

GUI example:

+ + + + + + + + + + + + + + + + + + + + + + + +
StepScreenshots
1. Put a Search box in the workspace, and add an arrow from the Data box to the Search box.
2. (a) Choose output type (here a PAG). (b) Choose an algorithm (here GFCI). (c) Choose parameters (here alpha = 0.05, and "one-edge faithfulness" = "no"). (d) Click "search".
3. Your results will pop up. If you wish, you can drag the variables into a nicer layout. Then click "Done".
+
+ +
+ +

Knowledge

knowledge schema

-

As mentioned in the Search Algorithms section, we can use background knowledge as an input to search. Tetrad represents knowledge as a set of variables, a list of forbidden edges4 and a list of required edges.

+

As mentioned in the Search Algorithms section, we can use background knowledge as an input to search. Tetrad represents knowledge as a set of variables, a list of forbidden edges5 and a list of required edges.

How they're made: You might think of knowledge as being independent of everything else – that's what makes it "background" knowledge! However, Tetrad won't let you create a knowledge object without giving it input: a dataset or search algorithm that tells it the names of your variables. Only then can you list the forbidden and required edges. It is as if Tetrad is asking, "knowledge about what?"

-

Example: Say we know that neither ice-cream nor sunscreen can influence the temperature. We would represent this as a pair of forbidden edges. In Tetrad the knowledge would be represented like so:

+

Example

+ +
+

Schematic Example: Say we know that neither ice-cream nor sunscreen can influence the temperature. We would represent this as a pair of forbidden edges. In Tetrad the knowledge would be represented like so:

Variables: {Sunscreen, Temperature, Ice-cream}

Forbidden Edges: - {(Sunscreen, Temperature, -, >), - (Ice-cream, Temperature, -, >)}

+{(Sunscreen, Temperature, -, >), +(Ice-cream, Temperature, -, >)}

Required Edges: {}

+
-

Parametric & Instantiated models

+
+ +

Parametric & Instantiated models

Causal graphs only give us qualitative information: which variables causally influence which others. But they don't tell us quantitatively how big the causal effects are. They put constraints on the probability distribution over variables in the graph, but they don't fully specify the probability distribution. For that, we need models.

@@ -617,43 +759,43 @@

Parametric & Instantiated models

We need models for several distinct tasks:

    -
  1. Given data and a graph we trust, we fit a model to learn the size of the causal effects.
  2. -
  3. Given data and a graph we wish to evaluate, we fit and then test a model to see how well that graph can describe our data.
  4. -
  5. Given a graph, we specify a model so we can generate synthetic data from that graph, which we can then use to evaluate a search algorithm.
  6. +
  7. Given data and a graph we trust, we fit a model to learn the size of the causal effects.
  8. +
  9. Given data and a graph we wish to evaluate, we fit and then test a model to see how well that graph can describe our data.
  10. +
  11. Given a graph, we specify a model so we can generate synthetic data from that graph, which we can then use to evaluate a search algorithm.

Tetrad has two confusing distinctions between types of model object. Here they are in one table:

- - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + +
Bayes modelStructural Equation Model (SEM)
Parametric ModelGraph (DAG) where the nodes are discrete variables, each with a set of possible valuesGraph (DAG) where the nodes are continuous variables (means and variances initialized but not assigned values), plus a set of linear parameters (coefficients initialized but not assigned values)
Instantiated modelProbabilities assigned to the possible values of each variable, conditional on its parents in the graphValues assigned to all parameters of linear structural equation model (means, variances, and edge coefficients)
Bayes modelStructural Equation Model (SEM)
Parametric ModelGraph (DAG) where the nodes are discrete variables, each with a set of possible valuesGraph (DAG) where the nodes are continuous variables (means and variances initialized but not assigned values), plus a set of linear parameters (coefficients initialized but not assigned values)
Instantiated modelProbabilities assigned to the possible values of each variable, conditional on its parents in the graphValues assigned to all parameters of linear structural equation model (means, variances, and edge coefficients)

Tetrad distinguishes between parametric models and instantiated models. The parametric model just initializes the object: it's where you decide what kind of model you're going to use (Bayes or SEM parameterization). The instantiated model then assigns values to the model parameters.

-

Bayes PMs and IMs

+

Bayes PMs and IMs

bayes pm

-

"Bayes model" just means the model fits discrete-valued data. It has no special relationship to Bayesian inference5. Tetrad uses the term "Bayes model" only because DAGs for discrete data have been called "Bayes nets" (again, not because they have a special relationship to Bayesian inference).

+

"Bayes model" just means the model fits discrete-valued data. It has no special relationship to Bayesian inference6. Tetrad uses the term "Bayes model" only because DAGs for discrete data have been called "Bayes nets" (again, not because they have a special relationship to Bayesian inference).

A Bayes Parametric Model (Bayes PM) object includes a graph, and a set of possible values for every variable in that graph. The graph must be a DAG.

@@ -665,7 +807,7 @@

Bayes PMs and IMs

How Bayes IMs are made: You can start with a Bayes PM and a dataset, in which case Tetrad will estimate the conditional probabilities from your data. If you want to generate synthetic data, you can start with a Bayes PM and specify the conditional probabilities (either by choosing them randomly, or inputting specific values by hand).

-

SEM PMs and IMs

+

SEM PMs and IMs

sem pm

@@ -674,8 +816,8 @@

SEM PMs and IMs

A linear model means the relationships between the variables can be described with linear equations. For example, if we have the graph X Y Z, we could describe this as the standard SEM parametric model:

\(X = \varepsilon_1 \\ - Z = \varepsilon_2 \\ - Y = \alpha X + \beta Z + \varepsilon_3\)

+Z = \varepsilon_2 \\ +Y = \alpha X + \beta Z + \varepsilon_3\)

Where the errors \(\varepsilon_1, \varepsilon_2, \varepsilon_3\) are independent random variables with Gaussian distributions.

@@ -693,35 +835,39 @@

SEM PMs and IMs

How SEM IMs are made: You can start with a SEM PM and a dataset, in which case Tetrad will estimate the model parameters from your data. If you want to generate synthetic data, you can start with a SEM PM and specify the parameter values (either by choosing them randomly, or inputting specific values by hand).

-

Other Objects

+
+ +

Other Objects

There are five other modules that I won't talk about here. See these other sections of the manual for more information:

    -
  • Comparisons between graphs
  • -
  • Updaters
  • -
  • Regression functions
  • -
  • Classifiers
  • -
  • Random graph generators
  • +
  • Comparisons between graphs
  • +
  • Updaters
  • +
  • Regression functions
  • +
  • Classifiers
  • +
  • Random graph generators
-

An example pipeline

+
+ +

An example pipeline

Say you start with data, and you want to learn a causal model and estimate the size of the causal effects. Your workflow or "pipeline" would look like the following schema.

-

But take note: This schema describes what's happening inside the Tetrad library. In the graphical interface, some steps may be combined. For example, in the current6 version of the Tetrad GUI, steps 4, 5 and 6 are grouped into a single box.

+

But take note: This schema describes what's happening inside the Tetrad library. In the graphical interface, some steps may be combined. For example, in the current7 version of the Tetrad GUI, steps 4, 5 and 6 are grouped into a single box.

pipeline

In text form:

    -
  1. Load your data into Tetrad, generating a Dataset object.
  2. -
  3. Feed your data into a Search Algorithm.
  4. -
  5. Choose search settings/assumptions that make sense, given how your data were collected.
  6. -
  7. The output will be an equivalence class of graphs. Choose one plausible DAG from the output equivalence class.
  8. -
  9. Choose a parametric model that makes sense for your data.
  10. -
  11. Use your dataset to learn the parameters of the instantiated model.
  12. +
  13. Load your data into Tetrad, generating a Dataset object.
  14. +
  15. Feed your data into a Search Algorithm.
  16. +
  17. Choose search settings/assumptions that make sense, given how your data were collected.
  18. +
  19. The output will be an equivalence class of graphs. Choose one plausible DAG from the output equivalence class.
  20. +
  21. Choose a parametric model that makes sense for your data.
  22. +
  23. Use your dataset to learn the parameters of the instantiated model.

You should also perform some sanity checks along the way:

@@ -730,41 +876,47 @@

An example pipeline

After estimating the model parameters: do the parameters look plausible? What changes if you choose a different graph from the equivalence class?

-

Takeaway Messages

+
+ +

Takeaway Messages

Tetrad is a modular, object-oriented program for causal inference. "Causal inference" includes a variety of tasks; Tetrad objects can be combined in various ways to accomplish many of those tasks. This tutorial describes some of the most important objects in Tetrad. It is meant to be schematic yet independent of Tetrad's graphical user interface (which may change in the future). I have included an example of one pipeline – one way of combining Tetrad objects to achieve a particular aim – but that is only the beginning of what is possible with Tetrad.

This tutorial is an introduction to the Tetrad software. For an introduction to causal inference in general, and guidance on interpreting your results, see the companion tutorial.

-
-
    +
    +
      + +
    1. +

      This comic is under copyright, held by Universal Uclick. We believe our use of the material is covered under Fair Use for three reasons: (1) The purpose of the use is education, not profit. (2) The portion of the work used is tiny relative to the whole corpus of Calvin and Hobbes comics (one panel of one strip). (3) The use of this panel will have no effect on the market value of Calvin and Hobbes. However, should Universal Uclick disagree with our judgment and ask us to remove the comic from this documentation, we will comply. 

      +
    2. -
    3. -

      For brevity, this is a simplified version of Tetrad's ontology, emphasizing the objects that you see in the GUI, and their dependences. If you want to learn what's really under the hood you can look at the Tetrad library source code in the Git repository. 

      -
    4. +
    5. +

      For brevity, this is a simplified version of Tetrad's ontology, emphasizing the objects that you see in the GUI, and their dependences. If you want to learn what's really under the hood you can look at the Tetrad library source code in the Git repository. 

      +
    6. -
    7. -

      In the guts of Tetrad there are differences between node objects and variables, and what you're using depends on whether you load data first or define a graph and generate data from it. These details should not matter to the user. 

      -
    8. +
    9. +

      In the guts of Tetrad there are differences between node objects and variables, and what you're using depends on whether you load data first or define a graph and generate data from it. These details should not matter to the user. 

      +
    10. -
    11. -

      You may instead use some kind of 'oracle', which gives the algorithm the information that it would normally estimate from the dataset (e.g. conditional independence facts). This is useful if you're trying to figure out how the algorithms perform when given perfect information. 

      -
    12. +
    13. +

      You may instead use some kind of 'oracle', which gives the algorithm the information that it would normally estimate from the dataset (e.g. conditional independence facts). This is useful if you're trying to figure out how the algorithms perform when given perfect information. 

      +
    14. -
    15. -

      We can also use tiers to forbid many edges at once. This is often useful, for example, if you have time-ordered measurements, and you want to prevent any edges going back in time. For more information look at the module on Knowledge. 

      -
    16. +
    17. +

      We can also use tiers to forbid many edges at once. This is often useful, for example, if you have time-ordered measurements, and you want to prevent any edges going back in time. For more information look at the module on Knowledge. 

      +
    18. -
    19. -

      Of course you can learn a Bayes model using Bayesian updating. However, you can also learn a Structural Equation Model using Bayesian updating.  

      -
    20. +
    21. +

      Of course you can learn a Bayes model using Bayesian updating. However, you can also learn a Structural Equation Model using Bayesian updating.  

      +
    22. -
    23. -

      Current as of 10/21/2016. 

      -
    24. +
    25. +

      Current as of 10/21/2016. 

      +
    26. -
    +
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java index 0b922661c7..0b3e1f9cc4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java @@ -29,6 +29,7 @@ import edu.cmu.tetrad.algcomparison.score.BdeuScore; import edu.cmu.tetrad.algcomparison.score.ScoreWrapper; import edu.cmu.tetrad.algcomparison.simulation.LoadContinuousDataAndGraphs; +import edu.cmu.tetrad.algcomparison.simulation.LoadDataAndGraphs; import edu.cmu.tetrad.algcomparison.simulation.Simulation; import edu.cmu.tetrad.algcomparison.simulation.Simulations; import edu.cmu.tetrad.algcomparison.statistic.ElapsedTime; @@ -94,7 +95,7 @@ public void compareFromFiles(String filePath, Algorithms algorithms, } for (File dir : dirs) { - simulations.add(new LoadContinuousDataAndGraphs(dir.getAbsolutePath())); + simulations.add(new LoadDataAndGraphs(dir.getAbsolutePath())); } compareFromSimulations(filePath, simulations, algorithms, statistics, parameters); @@ -115,7 +116,9 @@ public void compareFromSimulations(String filePath, Simulations simulations, Alg // Create output file. try { File dir = new File(filePath); - this.out = new PrintStream(new FileOutputStream(new File(dir, "Comparison.txt"))); + dir.mkdirs(); + File file = new File(dir, "Comparison.txt"); + this.out = new PrintStream(new FileOutputStream(file)); } catch (Exception e) { throw new RuntimeException(e); } @@ -1093,7 +1096,13 @@ private double[][][] calcStatTables(double[][][][] allStats, Mode mode, int numT for (String name : _parameterNames) { if (name.equals(statName)) { - stat = _parameters.getDouble(name); + try { + stat = _parameters.getDouble(name); + } catch (Exception e) { + boolean b = _parameters.getBoolean(name); + stat = b ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; + } + break; } } @@ -1186,8 +1195,15 @@ private void printStats(double[][][] statTables, Statistics statistics, Mode mod for (int t = 0; t < algorithmSimulationWrappers.size(); t++) { for (int statIndex = 0; statIndex < numStats; statIndex++) { double stat = statTables[u][newOrder[t]][statIndex]; - table.setToken(t + 1, initialColumn + statIndex, - Math.abs(stat) < 0.1 ? smallNf.format(stat) : nf.format(stat)); + + if (stat == Double.POSITIVE_INFINITY) { + table.setToken(t + 1, initialColumn + statIndex, "Yes"); + } else if (stat == Double.NEGATIVE_INFINITY) { + table.setToken(t + 1, initialColumn + statIndex, "No"); + } else { + table.setToken(t + 1, initialColumn + statIndex, + Math.abs(stat) < 0.1 ? smallNf.format(stat) : nf.format(stat)); + } } if (isShowUtilities()) { @@ -1353,7 +1369,7 @@ public List getParameters() { } public void setValue(String name, Object value) { - if (!(value instanceof Number)) { + if (!(value instanceof Number || value instanceof Boolean)) { throw new IllegalArgumentException(); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Bpc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Bpc.java index 3492f9df6c..129b29cbf4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Bpc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Bpc.java @@ -4,6 +4,7 @@ import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.BuildPureClusters; import edu.cmu.tetrad.search.SearchGraphUtils; @@ -49,7 +50,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Fofc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Fofc.java index b3de5d7887..8fda0c2d10 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Fofc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Fofc.java @@ -4,6 +4,7 @@ import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.FindOneFactorClusters; import edu.cmu.tetrad.search.SearchGraphUtils; @@ -57,7 +58,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Ftfc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Ftfc.java index 16c73af116..760b9fc868 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Ftfc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Ftfc.java @@ -4,6 +4,7 @@ import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.FindTwoFactorClusters; import edu.cmu.tetrad.search.SearchGraphUtils; @@ -46,7 +47,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/Lingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/Lingam.java index 5298332d25..ff376dfb8b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/Lingam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/continuous/dag/Lingam.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.data.DataType; import edu.cmu.tetrad.graph.Graph; @@ -25,7 +26,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } public String getDescription() { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/Mgm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/Mgm.java index e0172a55cf..c3dc721217 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/Mgm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/Mgm.java @@ -4,6 +4,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.data.DataType; import edu.cmu.tetrad.graph.Graph; @@ -28,8 +29,9 @@ public Graph search(DataModel ds, Parameters parameters) { return m.search(); } + // Need to marry the parents on this. public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return GraphUtils.undirectedGraph(graph); } public String getDescription() { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgsDiscretingContinuousVariables.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgesDiscretingContinuousVariables.java similarity index 83% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgsDiscretingContinuousVariables.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgesDiscretingContinuousVariables.java index 81bff02327..c2e773ae0f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgsDiscretingContinuousVariables.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgesDiscretingContinuousVariables.java @@ -7,7 +7,7 @@ import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.search.Fgs; +import edu.cmu.tetrad.search.Fges; import edu.cmu.tetrad.search.SearchGraphUtils; import edu.cmu.tetrad.util.Parameters; @@ -16,11 +16,11 @@ /** * @author jdramsey */ -public class MixedFgsDiscretingContinuousVariables implements Algorithm { +public class MixedFgesDiscretingContinuousVariables implements Algorithm { static final long serialVersionUID = 23L; private ScoreWrapper score; - public MixedFgsDiscretingContinuousVariables(ScoreWrapper score) { + public MixedFgesDiscretingContinuousVariables(ScoreWrapper score) { this.score = score; } @@ -36,20 +36,20 @@ public Graph search(DataModel dataSet, Parameters parameters) { dataSet = discretizer.discretize(); DataSet _dataSet = DataUtils.getDiscreteDataSet(dataSet); - Fgs fgs = new Fgs(score.getScore(_dataSet, parameters)); - Graph p = fgs.search(); + Fges fges = new Fges(score.getScore(_dataSet, parameters)); + Graph p = fges.search(); return convertBack(_dataSet, p); } @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override public String getDescription() { - return "FGS after discretizing the continuous variables in the data set using " + score.getDescription(); + return "FGES after discretizing the continuous variables in the data set using " + score.getDescription(); } private Graph convertBack(DataSet Dk, Graph p) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgsTreatingDiscreteAsContinuous.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgesTreatingDiscreteAsContinuous.java similarity index 86% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgsTreatingDiscreteAsContinuous.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgesTreatingDiscreteAsContinuous.java index 17dce9b08a..2db63795a3 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgsTreatingDiscreteAsContinuous.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/mixed/pattern/MixedFgesTreatingDiscreteAsContinuous.java @@ -6,7 +6,7 @@ import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.search.Fgs; +import edu.cmu.tetrad.search.Fges; import edu.cmu.tetrad.search.SearchGraphUtils; import edu.cmu.tetrad.search.SemBicScore; import edu.cmu.tetrad.util.Parameters; @@ -17,15 +17,15 @@ /** * @author jdramsey */ -public class MixedFgsTreatingDiscreteAsContinuous implements Algorithm { +public class MixedFgesTreatingDiscreteAsContinuous implements Algorithm { static final long serialVersionUID = 23L; public Graph search(DataModel Dk, Parameters parameters) { DataSet mixedDataSet = DataUtils.getMixedDataSet(Dk); mixedDataSet = DataUtils.convertNumericalDiscreteToContinuous(mixedDataSet); SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(mixedDataSet)); score.setPenaltyDiscount(parameters.getDouble("penaltyDiscount")); - Fgs fgs = new Fgs(score); - Graph p = fgs.search(); + Fges fges = new Fges(score); + Graph p = fges.search(); return convertBack(mixedDataSet, p); } @@ -53,11 +53,11 @@ private Graph convertBack(DataSet Dk, Graph p) { } public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } public String getDescription() { - return "FGS2, using the SEM BIC score, treating all discrete variables as " + + return "FGES2, using the SEM BIC score, treating all discrete variables as " + "continuous"; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesBDeu.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesBDeu.java index fef62e3087..be325e3e43 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesBDeu.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesBDeu.java @@ -1,10 +1,11 @@ package edu.cmu.tetrad.algcomparison.algorithm.multi; import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm; -import edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fgs; +import edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges; import edu.cmu.tetrad.algcomparison.score.BdeuScore; import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.BdeuScoreImages; @@ -37,8 +38,9 @@ public Graph search(List dataSets, Parameters parameters) { dataModels.add(dataSet); } - edu.cmu.tetrad.search.Fgs search = new edu.cmu.tetrad.search.Fgs(new BdeuScoreImages(dataModels)); + edu.cmu.tetrad.search.Fges search = new edu.cmu.tetrad.search.Fges(new BdeuScoreImages(dataModels)); search.setFaithfulnessAssumed(true); + IKnowledge knowledge = dataModels.get(0).getKnowledge(); search.setKnowledge(knowledge); return search.search(); @@ -51,7 +53,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override @@ -66,7 +68,7 @@ public DataType getDataType() { @Override public List getParameters() { - List parameters = new Fgs(new BdeuScore()).getParameters(); + List parameters = new Fges(new BdeuScore()).getParameters(); parameters.add("randomSelectionSize"); return parameters; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesCcd.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesCcd.java new file mode 100644 index 0000000000..413a2ab9a7 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesCcd.java @@ -0,0 +1,98 @@ +package edu.cmu.tetrad.algcomparison.algorithm.multi; + +import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm; +import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; +import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.*; +import edu.cmu.tetrad.util.Parameters; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Wraps the IMaGES algorithm for continuous variables. + *

+ * Requires that the parameter 'randomSelectionSize' be set to indicate how many + * datasets should be taken at a time (randomly). This cannot given multiple values. + * + * @author jdramsey + */ +public class ImagesCcd implements MultiDataSetAlgorithm, HasKnowledge { + static final long serialVersionUID = 23L; + private IKnowledge knowledge = new Knowledge2(); + + public ImagesCcd() { + } + + @Override + public Graph search(List dataSets, Parameters parameters) { + List dataModels = new ArrayList<>(); + + for (DataSet dataSet : dataSets) { + dataModels.add(dataSet); + } + + SemBicScoreImages2 score = new SemBicScoreImages2(dataModels); + score.setPenaltyDiscount(parameters.getDouble("penaltyDiscount")); + IndependenceTest test = new IndTestScore(score); + edu.cmu.tetrad.search.CcdMax search = new edu.cmu.tetrad.search.CcdMax(test); + search.setUseHeuristic(parameters.getBoolean("useMaxPOrientationHeuristic")); + search.setMaxPathLength(parameters.getInt("maxPOrientationMaxPathLength")); + search.setKnowledge(knowledge); + search.setDepth(parameters.getInt("depth")); + search.setApplyOrientAwayFromCollider(parameters.getBoolean("applyR1")); + search.setUseOrientTowardDConnections(parameters.getBoolean("orientTowardDConnections")); + return search.search(); + } + + @Override + public Graph search(DataModel dataSet, Parameters parameters) { + return search(Collections.singletonList(DataUtils.getContinuousDataSet(dataSet)), parameters); + } + + @Override + public Graph getComparisonGraph(Graph graph) { + return new EdgeListGraph(graph); + } + + @Override + public String getDescription() { + return "CCD-Max using the IMaGEs score for continuous variables"; + } + + @Override + public DataType getDataType() { + return DataType.Continuous; + } + + @Override + public List getParameters() { + List parameters = new ArrayList<>(); + parameters.add("penaltyDiscount"); + + parameters.add("depth"); + parameters.add("orientVisibleFeedbackLoops"); + parameters.add("useMaxPOrientationHeuristic"); + parameters.add("maxPOrientationMaxPathLength"); + parameters.add("applyR1"); + parameters.add("orientTowardDConnections"); + + parameters.add("numRandomSelections"); + parameters.add("randomSelectionSize"); + + return parameters; + } + + @Override + public IKnowledge getKnowledge() { + return knowledge; + } + + @Override + public void setKnowledge(IKnowledge knowledge) { + this.knowledge = knowledge; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesSemBic.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesSemBic.java index fb02791fb6..8a54aeaf0c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesSemBic.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/ImagesSemBic.java @@ -1,10 +1,11 @@ package edu.cmu.tetrad.algcomparison.algorithm.multi; import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm; -import edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fgs; +import edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges; import edu.cmu.tetrad.algcomparison.score.SemBicScore; import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.SemBicScoreImages; @@ -37,8 +38,9 @@ public Graph search(List dataSets, Parameters parameters) { dataModels.add(dataSet); } - edu.cmu.tetrad.search.Fgs search = new edu.cmu.tetrad.search.Fgs(new SemBicScoreImages(dataModels)); + edu.cmu.tetrad.search.Fges search = new edu.cmu.tetrad.search.Fges(new SemBicScoreImages(dataModels)); search.setFaithfulnessAssumed(true); + IKnowledge knowledge = dataModels.get(0).getKnowledge(); search.setKnowledge(knowledge); return search.search(); @@ -51,7 +53,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return new TsDagToPag(graph).convert(); + return new TsDagToPag(new EdgeListGraph(graph)).convert(); } @Override @@ -66,7 +68,7 @@ public DataType getDataType() { @Override public List getParameters() { - List parameters = new Fgs(new SemBicScore()).getParameters(); + List parameters = new Fges(new SemBicScore()).getParameters(); parameters.add("numRandomSelections"); parameters.add("randomSelectionSize"); return parameters; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/TsImagesSemBic.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/TsImagesSemBic.java index 23b7757a84..62e4cd3082 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/TsImagesSemBic.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/TsImagesSemBic.java @@ -1,10 +1,11 @@ package edu.cmu.tetrad.algcomparison.algorithm.multi; import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm; -import edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fgs; +import edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges; import edu.cmu.tetrad.algcomparison.score.SemBicScore; import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.search.IndTestScore; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.graph.Graph; @@ -54,7 +55,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override @@ -69,7 +70,7 @@ public DataType getDataType() { @Override public List getParameters() { - List parameters = new Fgs(new SemBicScore()).getParameters(); + List parameters = new Fges(new SemBicScore()).getParameters(); parameters.add("randomSelectionSize"); return parameters; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Ccd.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Ccd.java index 752230d66d..36258cbc84 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Ccd.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Ccd.java @@ -4,6 +4,7 @@ import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.IKnowledge; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.data.DataType; import edu.cmu.tetrad.graph.Graph; @@ -37,7 +38,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/CcdMax.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/CcdMax.java new file mode 100644 index 0000000000..acfef38722 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/CcdMax.java @@ -0,0 +1,81 @@ +package edu.cmu.tetrad.algcomparison.algorithm.oracle.pag; + +import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; +import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper; +import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; +import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.IndependenceTest; +import edu.cmu.tetrad.util.Parameters; + +import java.util.List; + +/** + * FGES (the heuristic version). + * + * @author jdramsey + */ +public class CcdMax implements Algorithm, HasKnowledge { + static final long serialVersionUID = 23L; + private IndependenceWrapper test; + private IKnowledge knowledge = new Knowledge2(); + + public CcdMax(IndependenceWrapper test) { + this.test = test; + } + + @Override + public Graph search(DataModel dataSet, Parameters parameters) { +// DataSet continuousDataSet = DataUtils.getContinuousDataSet(dataSet); + IndependenceTest test = this.test.getTest(dataSet, parameters); + edu.cmu.tetrad.search.CcdMax search = new edu.cmu.tetrad.search.CcdMax(test); + search.setOrientVisibleFeedbackLoops(parameters.getBoolean("orientVisibleFeedbackLoops")); + search.setDoColliderOrientations(parameters.getBoolean("doColliderOrientation")); + search.setUseHeuristic(parameters.getBoolean("useMaxPOrientationHeuristic")); + search.setMaxPathLength(parameters.getInt("maxPOrientationMaxPathLength")); + search.setKnowledge(knowledge); + search.setDepth(parameters.getInt("depth")); + search.setApplyOrientAwayFromCollider(parameters.getBoolean("applyR1")); + search.setUseOrientTowardDConnections(parameters.getBoolean("orientTowardDConnections")); + return search.search(); + } + + @Override + public Graph getComparisonGraph(Graph graph) { + return new EdgeListGraph(graph); + } + + @Override + public String getDescription() { + return "CCD-Max (Cyclic Discovery Search Max) using " + test.getDescription(); + } + + @Override + public DataType getDataType() { + return test.getDataType(); + } + + @Override + public List getParameters() { + List parameters = test.getParameters(); + parameters.add("depth"); + parameters.add("orientVisibleFeedbackLoops"); + parameters.add("doColliderOrientation"); + parameters.add("useMaxPOrientationHeuristic"); + parameters.add("maxPOrientationMaxPathLength"); + parameters.add("applyR1"); + parameters.add("orientTowardDConnections"); + return parameters; + } + + @Override + public IKnowledge getKnowledge() { + return knowledge; + } + + @Override + public void setKnowledge(IKnowledge knowledge) { + this.knowledge = knowledge; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Cfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Cfci.java index 5a8af008b7..4036bdbc79 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Cfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Cfci.java @@ -6,6 +6,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.data.DataType; import edu.cmu.tetrad.graph.Graph; @@ -31,12 +32,13 @@ public Cfci(IndependenceWrapper test) { public Graph search(DataModel dataSet, Parameters parameters) { edu.cmu.tetrad.search.Cfci search = new edu.cmu.tetrad.search.Cfci(test.getTest(dataSet, parameters)); search.setKnowledge(knowledge); + search.setCompleteRuleSetUsed(parameters.getBoolean("completeRuleSetUsed")); return search.search(); } @Override public Graph getComparisonGraph(Graph graph) { - return new DagToPag(graph).convert(); + return new DagToPag(new EdgeListGraph(graph)).convert(); } @Override @@ -53,6 +55,7 @@ public DataType getDataType() { public List getParameters() { List parameters = test.getParameters(); parameters.add("depth"); + parameters.add("completeRuleSetUsed"); return parameters; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java index 27070b1085..40dd425d5c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java @@ -6,6 +6,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataType; @@ -44,6 +45,8 @@ public Graph search(DataModel dataSet, Parameters parameters) { edu.cmu.tetrad.search.Fci search = new edu.cmu.tetrad.search.Fci(test.getTest(dataSet, parameters)); search.setKnowledge(knowledge); + search.setMaxPathLength(parameters.getInt("maxPathLength")); + search.setCompleteRuleSetUsed(parameters.getBoolean("completeRuleSetUsed")); // if (initial != null) { // search.setInitialGraph(initial); @@ -54,7 +57,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return new DagToPag(graph).convert(); + return new DagToPag(new EdgeListGraph(graph)).convert(); } public String getDescription() { @@ -72,6 +75,8 @@ public DataType getDataType() { public List getParameters() { List parameters = test.getParameters(); parameters.add("depth"); + parameters.add("maxPathLength"); + parameters.add("completeRuleSetUsed"); return parameters; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GCcd.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GCcd.java deleted file mode 100644 index 8616426d1e..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GCcd.java +++ /dev/null @@ -1,64 +0,0 @@ -package edu.cmu.tetrad.algcomparison.algorithm.oracle.pag; - -import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; -import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper; -import edu.cmu.tetrad.algcomparison.score.ScoreWrapper; -import edu.cmu.tetrad.data.*; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.search.IndependenceTest; -import edu.cmu.tetrad.search.Score; -import edu.cmu.tetrad.search.SearchGraphUtils; -import edu.cmu.tetrad.util.Parameters; - -import java.util.List; - -/** - * FGS (the heuristic version). - * - * @author jdramsey - */ -public class GCcd implements Algorithm { - static final long serialVersionUID = 23L; - private IndependenceWrapper test; - private ScoreWrapper score; - private IKnowledge knowledge = new Knowledge2(); - - public GCcd(IndependenceWrapper test, ScoreWrapper score) { - this.test = test; - this.score = score; - } - - @Override - public Graph search(DataModel dataSet, Parameters parameters) { - DataSet continuousDataSet = DataUtils.getContinuousDataSet(dataSet); - IndependenceTest test = this.test.getTest(continuousDataSet, parameters); - Score score = this.score.getScore(continuousDataSet, parameters); - edu.cmu.tetrad.search.GCcd search = new edu.cmu.tetrad.search.GCcd(test, score); - search.setApplyR1(parameters.getBoolean("applyR1")); - search.setKnowledge(knowledge); - return search.search(); - } - - @Override - public Graph getComparisonGraph(Graph graph) { - return graph; - } - - @Override - public String getDescription() { - return "GCCD (Greedy Cyclic Discovery Search) using " + test.getDescription(); - } - - @Override - public DataType getDataType() { - return test.getDataType(); - } - - @Override - public List getParameters() { - List parameters = test.getParameters(); - parameters.add("depth"); - parameters.add("applyR1"); - return parameters; - } -} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java index 4d037bb586..4d6286bab6 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java @@ -7,6 +7,7 @@ import edu.cmu.tetrad.data.*; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.DagToPag; +import edu.cmu.tetrad.search.GFci; import edu.cmu.tetrad.search.GFciMax; import edu.cmu.tetrad.util.Parameters; import java.io.PrintStream; @@ -31,11 +32,13 @@ public Gfci(IndependenceWrapper test, ScoreWrapper score) { @Override public Graph search(DataModel dataSet, Parameters parameters) { - GFciMax search = new GFciMax(test.getTest(dataSet, parameters), score.getScore(dataSet, parameters)); + GFci search = new GFci(test.getTest(dataSet, parameters), score.getScore(dataSet, parameters)); search.setMaxDegree(parameters.getInt("maxDegree")); search.setKnowledge(knowledge); search.setVerbose(parameters.getBoolean("verbose")); search.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed")); + search.setMaxPathLength(parameters.getInt("maxPathLength")); + search.setCompleteRuleSetUsed(parameters.getBoolean("completeRuleSetUsed")); Object obj = parameters.get("printStream"); @@ -68,6 +71,8 @@ public List getParameters() { parameters.add("faithfulnessAssumed"); parameters.add("maxDegree"); parameters.add("printStream"); + parameters.add("maxPathLength"); + parameters.add("completeRuleSetUsed"); return parameters; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Rfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Rfci.java index 5f15dc478c..5f17f4efa9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Rfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Rfci.java @@ -6,6 +6,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.data.DataType; import edu.cmu.tetrad.graph.Graph; @@ -31,12 +32,14 @@ public Rfci(IndependenceWrapper test) { public Graph search(DataModel dataSet, Parameters parameters) { edu.cmu.tetrad.search.Rfci search = new edu.cmu.tetrad.search.Rfci(test.getTest(dataSet, parameters)); search.setKnowledge(knowledge); + search.setMaxPathLength(parameters.getInt("maxPathLength")); + search.setCompleteRuleSetUsed(parameters.getBoolean("completeRuleSetUsed")); return search.search(); } @Override public Graph getComparisonGraph(Graph graph) { - return new DagToPag(graph).convert(); + return new DagToPag(new EdgeListGraph(graph)).convert(); } public String getDescription() { @@ -52,6 +55,8 @@ public DataType getDataType() { public List getParameters() { List parameters = test.getParameters(); parameters.add("depth"); + parameters.add("maxPathLength"); + parameters.add("completeRuleSetUsed"); return parameters; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsFci.java index c112fec665..59cdbfb210 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsFci.java @@ -4,6 +4,7 @@ import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper; import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataType; @@ -36,25 +37,13 @@ public TsFci(IndependenceWrapper type, Algorithm initialGraph) { @Override public Graph search(DataModel dataSet, Parameters parameters) { - Graph initial = null; - - if (initialGraph != null) { - initial = initialGraph.search(dataSet, parameters); - } - edu.cmu.tetrad.search.TsFci search = new edu.cmu.tetrad.search.TsFci(test.getTest(dataSet, parameters)); - -// if (initial != null) { -// search.setInitialGraph(initial); -// } - - search.setKnowledge(knowledge); - + search.setKnowledge(dataSet.getKnowledge()); return search.search(); } @Override - public Graph getComparisonGraph(Graph graph) { return new TsDagToPag(graph).convert(); } + public Graph getComparisonGraph(Graph graph) { return new TsDagToPag(new EdgeListGraph(graph)).convert(); } public String getDescription() { return "tsFCI (Time Series Fast Causal Inference) using " + test.getDescription() + diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsGfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsGfci.java index 6d892a8a97..129ff147ec 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsGfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsGfci.java @@ -8,6 +8,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataType; import edu.cmu.tetrad.data.IKnowledge; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.TsDagToPag; import edu.cmu.tetrad.util.Parameters; @@ -36,12 +37,12 @@ public TsGfci(IndependenceWrapper type, ScoreWrapper score) { public Graph search(DataModel dataSet, Parameters parameters) { edu.cmu.tetrad.search.TsGFci search = new edu.cmu.tetrad.search.TsGFci(test.getTest(dataSet, parameters), score.getScore(dataSet, parameters)); - search.setKnowledge(knowledge); + search.setKnowledge(dataSet.getKnowledge()); return search.search(); } @Override - public Graph getComparisonGraph(Graph graph) { return new TsDagToPag(graph).convert(); } + public Graph getComparisonGraph(Graph graph) { return new TsDagToPag(new EdgeListGraph(graph)).convert(); } public String getDescription() { return "tsGFCI (Time Series GFCI) using " + test.getDescription() + " and " + score.getDescription() + diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsImages.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsImages.java index 4fb83aed34..8611b45d5e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsImages.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/TsImages.java @@ -1,16 +1,19 @@ package edu.cmu.tetrad.algcomparison.algorithm.oracle.pag; import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; +import edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm; import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper; +import edu.cmu.tetrad.algcomparison.score.ScoreWrapper; import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; -import edu.cmu.tetrad.data.DataModel; -import edu.cmu.tetrad.data.DataType; -import edu.cmu.tetrad.data.IKnowledge; +import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.search.TsDagToPag; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.search.*; import edu.cmu.tetrad.util.Parameters; +import java.util.ArrayList; import java.util.List; /** @@ -19,57 +22,50 @@ * @author jdramsey * @author Daniel Malinsky */ -public class TsImages implements Algorithm, TakesInitialGraph, HasKnowledge { +public class TsImages implements Algorithm, HasKnowledge, MultiDataSetAlgorithm { static final long serialVersionUID = 23L; - private IndependenceWrapper test; + private ScoreWrapper score; private Algorithm initialGraph = null; private IKnowledge knowledge = null; - public TsImages(IndependenceWrapper type) { - this.test = type; - } + public TsImages(ScoreWrapper score) { + if (!(score instanceof SemBicScore || score instanceof BDeuScore)) { + throw new IllegalArgumentException("Only SEM BIC score or BDeu score can be used with this, sorry."); + } - public TsImages(IndependenceWrapper type, Algorithm initialGraph) { - this.test = type; - this.initialGraph = initialGraph; + this.score = score; } @Override - public Graph search(DataModel dataSet, Parameters parameters) { - Graph initial = null; - - if (initialGraph != null) { - initial = initialGraph.search(dataSet, parameters); - } - - edu.cmu.tetrad.search.TsFci search = new edu.cmu.tetrad.search.TsFci(test.getTest(dataSet, parameters)); - -// if (initial != null) { -// search.setInitialGraph(initial); -// } - - search.setKnowledge(knowledge); - + public Graph search(DataModel dataModel, Parameters parameters) { + DataSet dataSet = (DataSet) dataModel; + TsGFci search; + Score score1 = score.getScore(dataSet, parameters); + IndependenceTest test = new IndTestScore(score1); + search = new TsGFci(test, score1); + search.setKnowledge(dataSet.getKnowledge()); return search.search(); } @Override - public Graph getComparisonGraph(Graph graph) { return new TsDagToPag(graph).convert(); } + public Graph getComparisonGraph(Graph graph) { + return new TsDagToPag(new EdgeListGraph(graph)).convert(); + } public String getDescription() { - return "tsFCI (Time Series Fast Causal Inference) using " + test.getDescription() + + return "tsFCI (Time Series Fast Causal Inference) using " + score.getDescription() + (initialGraph != null ? " with initial graph from " + initialGraph.getDescription() : ""); } @Override public DataType getDataType() { - return test.getDataType(); + return score.getDataType(); } @Override public List getParameters() { - return test.getParameters(); + return score.getParameters(); } @Override @@ -81,4 +77,36 @@ public IKnowledge getKnowledge() { public void setKnowledge(IKnowledge knowledge) { this.knowledge = knowledge; } + + @Override + public Graph search(List dataSets, Parameters parameters) { + List dataModels = new ArrayList<>(); + + for (DataSet dataSet : dataSets) { + dataModels.add(dataSet); + } + + TsGFci search; + + if (score instanceof SemBicScore) { + SemBicScoreImages gesScore = new SemBicScoreImages(dataModels); + gesScore.setPenaltyDiscount(parameters.getDouble("penaltyDiscount")); + IndependenceTest test = new IndTestScore(gesScore); + search = new TsGFci(test, gesScore); + } else if (score instanceof BDeScore) { + double samplePrior = parameters.getDouble("samplePrior", 1); + double structurePrior = parameters.getDouble("structurePrior", 1); + BdeuScoreImages score = new BdeuScoreImages(dataModels); + score.setSamplePrior(samplePrior); + score.setStructurePrior(structurePrior); + IndependenceTest test = new IndTestScore(score); + search = new TsGFci(test, score); + } else { + throw new IllegalStateException("Sorry, data must either be all continuous or all discrete."); + } + + IKnowledge knowledge = dataModels.get(0).getKnowledge(); + search.setKnowledge(knowledge); + return search.search(); + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Cpc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Cpc.java index 12c42d1cc4..ec04f7778c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Cpc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Cpc.java @@ -6,6 +6,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataType; @@ -54,7 +55,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/CpcStable.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/CpcStable.java index a6b13f82a5..45af6cf241 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/CpcStable.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/CpcStable.java @@ -6,6 +6,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.data.DataType; import edu.cmu.tetrad.graph.Graph; @@ -36,7 +37,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FAS.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FAS.java index ac562ef862..b2b5ee8b3e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FAS.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FAS.java @@ -7,6 +7,7 @@ import edu.cmu.tetrad.data.DataType; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.SearchGraphUtils; import edu.cmu.tetrad.util.Parameters; @@ -37,7 +38,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fgs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java similarity index 73% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fgs.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java index edb8b05f79..d2b45edf21 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fgs.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java @@ -4,30 +4,35 @@ import edu.cmu.tetrad.algcomparison.score.ScoreWrapper; import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; -import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.data.DataType; +import edu.cmu.tetrad.data.IKnowledge; +import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.SearchGraphUtils; import edu.cmu.tetrad.util.Parameters; + import java.io.PrintStream; import java.util.List; /** - * FGS (the heuristic version). + * FGES (the heuristic version). * * @author jdramsey */ -public class Fgs implements Algorithm, TakesInitialGraph, HasKnowledge { +public class Fges implements Algorithm, TakesInitialGraph, HasKnowledge { static final long serialVersionUID = 23L; private ScoreWrapper score; private Algorithm initialGraph = null; private IKnowledge knowledge = new Knowledge2(); - public Fgs(ScoreWrapper score) { + public Fges(ScoreWrapper score) { this.score = score; } - public Fgs(ScoreWrapper score, Algorithm initialGraph) { + public Fges(ScoreWrapper score, Algorithm initialGraph) { this.score = score; this.initialGraph = initialGraph; } @@ -40,14 +45,14 @@ public Graph search(DataModel dataSet, Parameters parameters) { initial = initialGraph.search(dataSet, parameters); } - edu.cmu.tetrad.search.Fgs search = - new edu.cmu.tetrad.search.Fgs(score.getScore(dataSet, parameters)); + edu.cmu.tetrad.search.Fges search + = new edu.cmu.tetrad.search.Fges(score.getScore(dataSet, parameters)); search.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed")); search.setKnowledge(knowledge); search.setVerbose(parameters.getBoolean("verbose")); search.setMaxDegree(parameters.getInt("maxDegree")); - Object obj = parameters.get("printStream"); + Object obj = parameters.get("printStedu.cmream"); if (obj instanceof PrintStream) { search.setOut((PrintStream) obj); } @@ -61,12 +66,13 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); +// return new EdgeListGraph(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override public String getDescription() { - return "FGS (Fast Greedy Search) using " + score.getDescription(); + return "FGES (Fast Greedy Equivalence Search) using " + score.getDescription(); } @Override @@ -80,7 +86,6 @@ public List getParameters() { parameters.add("faithfulnessAssumed"); parameters.add("maxDegree"); parameters.add("verbose"); - parameters.add("printStream"); return parameters; } @@ -93,4 +98,5 @@ public IKnowledge getKnowledge() { public void setKnowledge(IKnowledge knowledge) { this.knowledge = knowledge; } + } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgsMb.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgesMb.java similarity index 82% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgsMb.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgesMb.java index 1cfe1b5cfe..b4d27e7f4d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgsMb.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgesMb.java @@ -5,9 +5,11 @@ import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.search.FgesMb2; import edu.cmu.tetrad.search.Score; import edu.cmu.tetrad.util.Parameters; @@ -15,22 +17,22 @@ import java.util.List; /** - * FGS (the heuristic version). + * FGES (the heuristic version). * * @author jdramsey */ -public class FgsMb implements Algorithm, TakesInitialGraph, HasKnowledge { +public class FgesMb implements Algorithm, TakesInitialGraph, HasKnowledge { static final long serialVersionUID = 23L; private ScoreWrapper score; private Algorithm initialGraph = null; private IKnowledge knowledge = new Knowledge2(); private String targetName; - public FgsMb(ScoreWrapper score) { + public FgesMb(ScoreWrapper score) { this.score = score; } - public FgsMb(ScoreWrapper score, Algorithm initialGraph) { + public FgesMb(ScoreWrapper score, Algorithm initialGraph) { this.score = score; this.initialGraph = initialGraph; } @@ -44,8 +46,8 @@ public Graph search(DataModel dataSet, Parameters parameters) { } Score score = this.score.getScore(DataUtils.getContinuousDataSet(dataSet), parameters); - edu.cmu.tetrad.search.FgsMb2 search - = new edu.cmu.tetrad.search.FgsMb2(score); + FgesMb2 search + = new FgesMb2(score); search.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed")); search.setKnowledge(knowledge); @@ -62,12 +64,12 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { Node target = graph.getNode(targetName); - return GraphUtils.markovBlanketDag(target, graph); + return GraphUtils.markovBlanketDag(target, new EdgeListGraph(graph)); } @Override public String getDescription() { - return "FGS (Fast Greedy Search) using " + score.getDescription(); + return "FGES (Fast Greedy Search) using " + score.getDescription(); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgsMeasurement.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgesMeasurement.java similarity index 82% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgsMeasurement.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgesMeasurement.java index 40f50ee502..3caff409c4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgsMeasurement.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/FgesMeasurement.java @@ -5,7 +5,9 @@ import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.Fges; import edu.cmu.tetrad.search.SearchGraphUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.RandomUtil; @@ -13,21 +15,21 @@ import java.util.List; /** - * FGS (the heuristic version). + * FGES (the heuristic version). * * @author jdramsey */ -public class FgsMeasurement implements Algorithm, TakesInitialGraph, HasKnowledge { +public class FgesMeasurement implements Algorithm, TakesInitialGraph, HasKnowledge { static final long serialVersionUID = 23L; private ScoreWrapper score; private Algorithm initialGraph = null; private IKnowledge knowledge = new Knowledge2(); - public FgsMeasurement(ScoreWrapper score) { + public FgesMeasurement(ScoreWrapper score) { this.score = score; } - public FgsMeasurement(ScoreWrapper score, Algorithm initialGraph) { + public FgesMeasurement(ScoreWrapper score, Algorithm initialGraph) { this.score = score; this.initialGraph = initialGraph; } @@ -50,7 +52,7 @@ public Graph search(DataModel dataModel, Parameters parameters) { } } - edu.cmu.tetrad.search.Fgs search = new edu.cmu.tetrad.search.Fgs(score.getScore(dataSet, parameters)); + Fges search = new Fges(score.getScore(dataSet, parameters)); search.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed")); search.setKnowledge(knowledge); search.setVerbose(parameters.getBoolean("verbose")); @@ -60,12 +62,12 @@ public Graph search(DataModel dataModel, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override public String getDescription() { - return "FGS adding measuremnt noise using " + score.getDescription(); + return "FGES adding measuremnt noise using " + score.getDescription(); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Jcpc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Jcpc.java index 770e24a800..1644a7cf52 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Jcpc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Jcpc.java @@ -6,6 +6,7 @@ import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.SearchGraphUtils; import edu.cmu.tetrad.util.Parameters; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/MBFS.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/MBFS.java index 7f2e40e5d1..e19e15ee7c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/MBFS.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/MBFS.java @@ -6,6 +6,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.search.IndependenceTest; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { Node target = graph.getNode(targetName); - return GraphUtils.markovBlanketDag(target, graph); + return GraphUtils.markovBlanketDag(target, new EdgeListGraph(graph)); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Pc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Pc.java index d5001aa40f..bc1a381ef7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Pc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Pc.java @@ -6,6 +6,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataType; @@ -44,7 +45,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/PcMax.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/PcMax.java index d5b8f0105f..fe480104f2 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/PcMax.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/PcMax.java @@ -8,6 +8,7 @@ import edu.cmu.tetrad.data.DataType; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.search.SearchGraphUtils; import edu.cmu.tetrad.util.Parameters; @@ -15,7 +16,7 @@ import java.util.List; /** - * PC. + * PC-Max * * @author jdramsey */ @@ -33,13 +34,16 @@ public PcMax(IndependenceWrapper test) { public Graph search(DataModel dataSet, Parameters parameters) { edu.cmu.tetrad.search.PcMax search = new edu.cmu.tetrad.search.PcMax( test.getTest(dataSet, parameters)); + search.setUseHeuristic(parameters.getBoolean("useMaxPOrientationHeuristic")); + search.setMaxPathLength(parameters.getInt("maxPOrientationMaxPathLength")); search.setKnowledge(knowledge); return search.search(); } @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); +// return new EdgeListGraph(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override @@ -58,6 +62,8 @@ public DataType getDataType() { public List getParameters() { List parameters = test.getParameters(); parameters.add("depth"); + parameters.add("useMaxPOrientationHeuristic"); + parameters.add("maxPOrientationMaxPathLength"); return parameters; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/PcStable.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/PcStable.java index 736356acad..5d712d121a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/PcStable.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/PcStable.java @@ -6,6 +6,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataType; @@ -55,7 +56,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return SearchGraphUtils.patternForDag(new EdgeListGraph(graph)); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/SingleGraphAlg.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/SingleGraphAlg.java index 63c0842742..4f158d7009 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/SingleGraphAlg.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/SingleGraphAlg.java @@ -7,6 +7,7 @@ import edu.cmu.tetrad.data.DataType; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.util.Parameters; @@ -33,7 +34,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/other/Glasso.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/other/Glasso.java index 286cb4595a..beb03a9549 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/other/Glasso.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/other/Glasso.java @@ -6,6 +6,7 @@ import edu.cmu.tetrad.data.*; import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.search.SearchGraphUtils; import edu.cmu.tetrad.util.Parameters; @@ -52,7 +53,7 @@ public Graph search(DataModel ds, Parameters parameters) { } public Graph getComparisonGraph(Graph graph) { - return SearchGraphUtils.patternForDag(graph); + return GraphUtils.undirectedGraph(graph); } public String getDescription() { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/EB.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/EB.java index 9d0c22f9c4..054015452b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/EB.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/EB.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataSet; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R1.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R1.java index 30542d1e9f..42d2e5e612 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R1.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R1.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataSet; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R2.java index 9c7b8e2785..c745c6a044 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R2.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataSet; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R3.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R3.java index 82a20fa0f1..89cbd46e97 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R3.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R3.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataSet; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R4.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R4.java index 45aa4e703d..2fd0cf74d1 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R4.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/R4.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataSet; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/RSkew.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/RSkew.java index 76ec9f7ee9..5a0471ed1d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/RSkew.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/RSkew.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataSet; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/RSkewE.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/RSkewE.java index ca651ba8a4..afc18a2ae1 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/RSkewE.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/RSkewE.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataSet; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/Skew.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/Skew.java index 3964ac7ce8..af3f3adb45 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/Skew.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/Skew.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataSet; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/SkewE.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/SkewE.java index 231517def7..758e0be98a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/SkewE.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/SkewE.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataSet; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/Tanh.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/Tanh.java index 9045e2668a..54f021775b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/Tanh.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/pairwise/Tanh.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataUtils; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph; import edu.cmu.tetrad.data.DataSet; @@ -48,7 +49,7 @@ public Graph search(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { - return graph; + return new EdgeListGraph(graph); } @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/ExampleCompareSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/ExampleCompareSimulation.java index 5cd0d3e328..201ace74f6 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/ExampleCompareSimulation.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/ExampleCompareSimulation.java @@ -66,7 +66,7 @@ public static void main(String... args) { Algorithms algorithms = new Algorithms(); algorithms.add(new Pc(new FisherZ())); - algorithms.add(new Cpc(new FisherZ(), new Fgs(new SemBicScore()))); + algorithms.add(new Cpc(new FisherZ(), new Fges(new SemBicScore()))); algorithms.add(new PcStable(new FisherZ())); algorithms.add(new CpcStable(new FisherZ())); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/ExampleCompareSimulationTimeSeries.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/ExampleCompareSimulationTimeSeries.java index f0bedf3e76..79d9d118d7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/ExampleCompareSimulationTimeSeries.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/ExampleCompareSimulationTimeSeries.java @@ -73,7 +73,7 @@ public static void main(String... args) { algorithms.add(new TsFci(new FisherZ())); algorithms.add(new TsGfci(new FisherZ(), new SemBicScore())); - algorithms.add(new TsImages(new FisherZ())); + algorithms.add(new TsImages(new SemBicScore())); Simulations simulations = new Simulations(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/ConditionalGaussianLRT.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/ConditionalGaussianLRT.java index 1e121ef7c1..5de3dde18b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/ConditionalGaussianLRT.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/independence/ConditionalGaussianLRT.java @@ -21,8 +21,12 @@ public class ConditionalGaussianLRT implements IndependenceWrapper, Experimental @Override public IndependenceTest getTest(DataModel dataSet, Parameters parameters) { - return new IndTestConditionalGaussianLRT(DataUtils.getMixedDataSet(dataSet), + final IndTestConditionalGaussianLRT test + = new IndTestConditionalGaussianLRT(DataUtils.getMixedDataSet(dataSet), parameters.getDouble("alpha")); + test.setPenaltyDiscount(parameters.getDouble("penaltyDiscount")); + test.setDenominatorMixed(parameters.getBoolean("assumeMixed")); + return test; } @Override @@ -37,9 +41,11 @@ public DataType getDataType() { @Override public List getParameters() { - List params = new ArrayList<>(); - params.add("alpha"); - return params; + List parameters = new ArrayList<>(); + parameters.add("alpha"); + parameters.add("assumeMixed"); + parameters.add("penaltyDiscount"); + return parameters; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ConditionalGaussianBicScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ConditionalGaussianBicScore.java index f6b7eb1c0c..fb595895e0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ConditionalGaussianBicScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/ConditionalGaussianBicScore.java @@ -21,7 +21,11 @@ public class ConditionalGaussianBicScore implements ScoreWrapper, Experimental @Override public Score getScore(DataModel dataSet, Parameters parameters) { - return new ConditionalGaussianScore(DataUtils.getMixedDataSet(dataSet)); + final ConditionalGaussianScore conditionalGaussianScore + = new ConditionalGaussianScore(DataUtils.getMixedDataSet(dataSet)); + conditionalGaussianScore.setPenaltyDiscount(parameters.getDouble("penaltyDiscount")); + conditionalGaussianScore.setDenominatorMixed(parameters.getBoolean("assumeMixed")); + return conditionalGaussianScore; } @Override @@ -38,6 +42,8 @@ public DataType getDataType() { public List getParameters() { List parameters = new ArrayList<>(); parameters.add("penaltyDiscount"); + parameters.add("cgExact"); + parameters.add("assumeMixed"); return parameters; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java new file mode 100644 index 0000000000..99bc9ac403 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java @@ -0,0 +1,448 @@ +package edu.cmu.tetrad.algcomparison.simulation; + +import edu.cmu.tetrad.algcomparison.graph.RandomGraph; +import edu.cmu.tetrad.bayes.BayesIm; +import edu.cmu.tetrad.bayes.BayesPm; +import edu.cmu.tetrad.bayes.MlBayesIm; +import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.sem.*; +import edu.cmu.tetrad.util.*; + +import java.util.*; + +/** + * A simulation method based on the conditional Gaussian assumption. + * + * @author jdramsey + */ +public class ConditionalGaussianSimulation implements Simulation { + static final long serialVersionUID = 23L; + private RandomGraph randomGraph; + private List dataSets = new ArrayList<>(); + private List graphs = new ArrayList<>(); + private DataType dataType; + private List shuffledOrder; + private double varLow = 1; + private double varHigh = 3; + private double coefLow = 0.05; + private double coefHigh = 1.5; + private boolean coefSymmetric = true; + private double meanLow = -1; + private double meanHigh = 1; + + public ConditionalGaussianSimulation(RandomGraph graph) { + this.randomGraph = graph; + } + + @Override + public void createData(Parameters parameters) { + setVarLow(parameters.getDouble("varLow")); + setVarHigh(parameters.getDouble("varHigh")); + setCoefLow(parameters.getDouble("coefLow")); + setCoefHigh(parameters.getDouble("coefHigh")); + setCoefSymmetric(parameters.getBoolean("coefSymmetric")); + setMeanLow(parameters.getDouble("meanLow")); + setMeanHigh(parameters.getDouble("meanHigh")); + + double percentDiscrete = parameters.getDouble("percentDiscrete"); + + boolean discrete = parameters.getString("dataType").equals("discrete"); + boolean continuous = parameters.getString("dataType").equals("continuous"); + + if (discrete && percentDiscrete != 100.0) { + throw new IllegalArgumentException("To simulate discrete data, 'percentDiscrete' must be set to 0.0."); + } else if (continuous && percentDiscrete != 0.0) { + throw new IllegalArgumentException("To simulate continuoue data, 'percentDiscrete' must be set to 100.0."); + } + + if (discrete) this.dataType = DataType.Discrete; + if (continuous) this.dataType = DataType.Continuous; + + this.shuffledOrder = null; + + Graph graph = randomGraph.createGraph(parameters); + + dataSets = new ArrayList<>(); + graphs = new ArrayList<>(); + + for (int i = 0; i < parameters.getInt("numRuns"); i++) { + System.out.println("Simulating dataset #" + (i + 1)); + + if (parameters.getBoolean("differentGraphs") && i > 0) { + graph = randomGraph.createGraph(parameters); + } + + graphs.add(graph); + + DataSet dataSet = simulate(graph, parameters); + dataSet.setName("" + (i + 1)); + dataSets.add(dataSet); + } + } + + @Override + public Graph getTrueGraph(int index) { + return graphs.get(index); + } + + @Override + public DataSet getDataSet(int index) { + return dataSets.get(index); + } + + @Override + public String getDescription() { + return "Conditional Gaussian simulation using " + randomGraph.getDescription(); + } + + @Override + public List getParameters() { + List parameters = randomGraph.getParameters(); + parameters.add("numCategories"); + parameters.add("percentDiscrete"); + parameters.add("numRuns"); + parameters.add("differentGraphs"); + parameters.add("sampleSize"); + parameters.add("varLow"); + parameters.add("varHigh"); + parameters.add("coefLow"); + parameters.add("coefHigh"); + parameters.add("coefSymmetric"); + parameters.add("meanLow"); + parameters.add("meanHigh"); + return parameters; + } + + @Override + public int getNumDataSets() { + return dataSets.size(); + } + + @Override + public DataType getDataType() { + return dataType; + } + + private DataSet simulate(Graph G, Parameters parameters) { + HashMap nd = new HashMap<>(); + + List nodes = G.getNodes(); + + Collections.shuffle(nodes); + + if (this.shuffledOrder == null) { + List shuffledNodes = new ArrayList<>(nodes); + Collections.shuffle(shuffledNodes); + this.shuffledOrder = shuffledNodes; + } + + for (int i = 0; i < nodes.size(); i++) { + if (i < nodes.size() * parameters.getDouble("percentDiscrete") * 0.01) { + nd.put(shuffledOrder.get(i).getName(), parameters.getInt("numCategories")); + } else { + nd.put(shuffledOrder.get(i).getName(), 0); + } + } + + G = makeMixedGraph(G, nd); + nodes = G.getNodes(); + + DataSet mixedData = new BoxDataSet(new MixedDataBox(nodes, parameters.getInt("sampleSize")), nodes); + + List X = new ArrayList<>(); + List A = new ArrayList<>(); + + for (Node node : G.getNodes()) { + if (node instanceof ContinuousVariable) { + X.add(node); + } else { + A.add(node); + } + } + + Graph AG = G.subgraph(A); + Graph XG = G.subgraph(X); + + Map erstatzNodes = new HashMap<>(); + Map erstatzNodesReverse = new HashMap<>(); + + for (Node y : A) { + for (Node x : G.getParents(y)) { + if (x instanceof ContinuousVariable) { + DiscreteVariable ersatz = erstatzNodes.get(x); + + if (ersatz == null) { + ersatz = new DiscreteVariable("Ersatz_" + x.getName(), 3); + erstatzNodes.put((ContinuousVariable) x, ersatz); + erstatzNodesReverse.put(ersatz.getName(), (ContinuousVariable) x); + AG.addNode(ersatz); + } + + AG.addDirectedEdge(ersatz, y); + } + } + } + + BayesPm bayesPm = new BayesPm(AG); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + + SemPm semPm = new SemPm(XG); + + Map paramValues = new HashMap<>(); + + List tierOrdering = G.getCausalOrdering(); + + int[] tiers = new int[tierOrdering.size()]; + + for (int t = 0; t < tierOrdering.size(); t++) { + tiers[t] = nodes.indexOf(tierOrdering.get(t)); + } + + for (int mixedIndex : tiers) { + for (int i = 0; i < parameters.getInt("sampleSize"); i++) { + if (nodes.get(mixedIndex) instanceof DiscreteVariable) { + int bayesIndex = bayesIm.getNodeIndex(nodes.get(mixedIndex)); + + int[] bayesParents = bayesIm.getParents(bayesIndex); + int[] parentValues = new int[bayesParents.length]; + + for (int k = 0; k < parentValues.length; k++) { + int bayesParentColumn = bayesParents[k]; + + Node bayesParent = bayesIm.getVariables().get(bayesParentColumn); + DiscreteVariable _parent = (DiscreteVariable) bayesParent; + int value; + + ContinuousVariable orig = erstatzNodesReverse.get(_parent.getName()); + + if (orig != null) { + int mixedParentColumn = mixedData.getColumn(orig); + double d = mixedData.getDouble(i, mixedParentColumn); + double[] breakpoints = getBreakpoints(mixedData, _parent, mixedParentColumn); + value = breakpoints.length; + + for (int j = 0; j < breakpoints.length; j++) { + if (d < breakpoints[j]) { + value = j; + break; + } + } + } else { + int mixedColumn = mixedData.getColumn(bayesParent); + value = mixedData.getInt(i, mixedColumn); + } + + parentValues[k] = value; + } + + int rowIndex = bayesIm.getRowIndex(bayesIndex, parentValues); + double sum = 0.0; + + double r = RandomUtil.getInstance().nextDouble(); + + for (int k = 0; k < bayesIm.getNumColumns(bayesIndex); k++) { + double probability = bayesIm.getProbability(bayesIndex, rowIndex, k); + sum += probability; + + if (sum >= r) { + mixedData.setInt(i, mixedIndex, k); + break; + } + } + } else { + Node y = nodes.get(mixedIndex); + + Set discreteParents = new HashSet<>(); + Set continuousParents = new HashSet<>(); + + for (Node node : G.getParents(y)) { + if (node instanceof DiscreteVariable) { + discreteParents.add((DiscreteVariable) node); + } else { + continuousParents.add((ContinuousVariable) node); + } + } + + Parameter varParam = semPm.getParameter(y, y); + Parameter muParam = semPm.getMeanParameter(y); + + Combination varComb = new Combination(varParam); + Combination muComb = new Combination(muParam); + + for (DiscreteVariable v : discreteParents) { + varComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v))); + muComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v))); + } + + double value = RandomUtil.getInstance().nextNormal(0, getParamValue(varComb, paramValues)); + + for (Node x : continuousParents) { + Parameter coefParam = semPm.getParameter(x, y); + Combination coefComb = new Combination(coefParam); + + for (DiscreteVariable v : discreteParents) { + coefComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v))); + } + + int parent = nodes.indexOf(x); + double parentValue = mixedData.getDouble(i, parent); + double parentCoef = getParamValue(coefComb, paramValues); + value += parentValue * parentCoef; + } + + value += getParamValue(muComb, paramValues); + mixedData.setDouble(i, mixedIndex, value); + } + } + } + + return mixedData; + } + + private double[] getBreakpoints(DataSet mixedData, DiscreteVariable _parent, int mixedParentColumn) { + double[] data = new double[mixedData.getNumRows()]; + + for (int r = 0; r < mixedData.getNumRows(); r++) { + data[r] = mixedData.getDouble(r, mixedParentColumn); + } + + return Discretizer.getEqualFrequencyBreakPoints(data, _parent.getNumCategories()); + } + + private Double getParamValue(Combination values, Map map) { + Double d = map.get(values); + + if (d == null) { + Parameter parameter = values.getParameter(); + + if (parameter.getType() == ParamType.VAR) { + d = RandomUtil.getInstance().nextUniform(varLow, varHigh); + map.put(values, d); + } else if (parameter.getType() == ParamType.COEF) { + double min = coefLow; + double max = coefHigh; + double value = RandomUtil.getInstance().nextUniform(min, max); + d = RandomUtil.getInstance().nextUniform(0, 1) < 0.5 && coefSymmetric ? -value : value; + map.put(values, d); + } else if (parameter.getType() == ParamType.MEAN) { + d = RandomUtil.getInstance().nextUniform(meanLow, meanHigh); + map.put(values, d); + } + } + + return d; + } + + public void setVarLow(double varLow) { + this.varLow = varLow; + } + + public void setVarHigh(double varHigh) { + this.varHigh = varHigh; + } + + public void setCoefLow(double coefLow) { + this.coefLow = coefLow; + } + + public void setCoefHigh(double coefHigh) { + this.coefHigh = coefHigh; + } + + public void setCoefSymmetric(boolean coefSymmetric) { + this.coefSymmetric = coefSymmetric; + } + + public void setMeanLow(double meanLow) { + this.meanLow = meanLow; + } + + public void setMeanHigh(double meanHigh) { + this.meanHigh = meanHigh; + } + + private class Combination { + private Parameter parameter; + private Set paramValues; + + public Combination(Parameter parameter) { + this.parameter = parameter; + this.paramValues = new HashSet<>(); + } + + public void addParamValue(DiscreteVariable variable, int value) { + this.paramValues.add(new VariableValues(variable, value)); + } + + public int hashCode() { + return parameter.hashCode() + paramValues.hashCode(); + } + + public boolean equals(Object o) { + if (o == this) return true; + if (!(o instanceof Combination)) return false; + Combination v = (Combination) o; + return v.parameter == this.parameter && v.paramValues.equals(this.paramValues); + } + + public Parameter getParameter() { + return parameter; + } + } + + private class VariableValues { + private DiscreteVariable variable; + private int value; + + public VariableValues(DiscreteVariable variable, int value) { + this.variable = variable; + this.value = value; + } + + public DiscreteVariable getVariable() { + return variable; + } + + public int getValue() { + return value; + } + + public int hashCode() { + return variable.hashCode() + value; + } + + public boolean equals(Object o) { + if (o == this) return true; + if (!(o instanceof VariableValues)) return false; + VariableValues v = (VariableValues) o; + return v.variable.equals(this.variable) && v.value == this.value; + } + } + + private static Graph makeMixedGraph(Graph g, Map m) { + List nodes = g.getNodes(); + for (int i = 0; i < nodes.size(); i++) { + Node n = nodes.get(i); + int nL = m.get(n.getName()); + if (nL > 0) { + Node nNew = new DiscreteVariable(n.getName(), nL); + nodes.set(i, nNew); + } else { + Node nNew = new ContinuousVariable(n.getName()); + nodes.set(i, nNew); + } + } + + Graph outG = new EdgeListGraph(nodes); + + for (Edge e : g.getEdges()) { + Node n1 = e.getNode1(); + Node n2 = e.getNode2(); + Edge eNew = new Edge(outG.getNode(n1.getName()), outG.getNode(n2.getName()), e.getEndpoint1(), e.getEndpoint2()); + outG.addEdge(eNew); + } + + return outG; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java index 4b6e1c61ce..3ddf47878f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java @@ -2,15 +2,13 @@ import edu.cmu.tetrad.algcomparison.graph.RandomGraph; import edu.cmu.tetrad.algcomparison.utils.TakesData; -import edu.cmu.tetrad.data.DataModel; -import edu.cmu.tetrad.data.DataSet; -import edu.cmu.tetrad.data.DataType; -import edu.cmu.tetrad.data.Discretizer; +import edu.cmu.tetrad.data.*; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.sem.LargeScaleSimulation; import edu.cmu.tetrad.util.JOptionUtils; import edu.cmu.tetrad.util.Parameters; +import edu.cmu.tetrad.util.dist.Split; import javax.swing.*; import java.util.ArrayList; @@ -37,11 +35,11 @@ public LinearFisherModel(RandomGraph graph, List shocks) { this.randomGraph = graph; this.shocks = shocks; - JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), - "The initial dataset you've provided will be used as initial shocks" + - "\nfor a Fisher model."); - if (shocks != null) { + JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), + "The initial dataset you've provided will be used as initial shocks" + + "\nfor a Fisher model."); + for (DataModel _shocks : shocks) { if (_shocks == null) throw new NullPointerException("Dataset containing shocks must not be null."); DataSet dataSet = (DataSet) _shocks; @@ -86,14 +84,35 @@ public void createData(Parameters parameters) { simulator.setVarRange( parameters.getDouble("varLow"), parameters.getDouble("varHigh")); + simulator.setCoefSymmetric(parameters.getBoolean("coefSymmetric")); + simulator.setVerbose(parameters.getBoolean("verbose")); DataSet dataSet; +// if (shocks == null) { +// dataSet = simulator.simulateDataFisher( +// simulator.getUncorrelatedGaussianShocks(parameters.getInt("sampleSize")), +// parameters.getInt("intervalBetweenShocks"), +// parameters.getInt("intervalBetweenRecordings"), +// parameters.getDouble("fisherEpsilon") +// ); +// } else { +// DataSet _shocks = (DataSet) shocks.get(i); +// +// dataSet = simulator.simulateDataFisher( +// _shocks.getDoubleData().toArray(), +// parameters.getInt("intervalBetweenShocks"), +// parameters.getInt("intervalBetweenRecordings"), +// parameters.getDouble("fisherEpsilon") +// ); +// } + if (shocks == null) { dataSet = simulator.simulateDataFisher( - simulator.getSoCalledPoissonShocks(parameters.getInt("sampleSize")), parameters.getInt("intervalBetweenShocks"), + parameters.getInt("intervalBetweenRecordings"), + parameters.getInt("sampleSize"), parameters.getDouble("fisherEpsilon") ); } else { @@ -127,7 +146,7 @@ public void createData(Parameters parameters) { dataSet.setName(name); } - dataSets.add(dataSet); + dataSets.add(DataUtils.restrictToMeasured(dataSet)); } } @@ -161,12 +180,14 @@ public List getParameters() { parameters.add("varLow"); parameters.add("varHigh"); parameters.add("verbose"); + parameters.add("coefSymmetric"); parameters.add("numRuns"); parameters.add("percentDiscrete"); parameters.add("numCategories"); parameters.add("differentGraphs"); parameters.add("sampleSize"); parameters.add("intervalBetweenShocks"); + parameters.add("intervalBetweenRecordings"); parameters.add("fisherEpsilon"); return parameters; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LoadDataAndGraphs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LoadDataAndGraphs.java new file mode 100644 index 0000000000..fe6e5020e5 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LoadDataAndGraphs.java @@ -0,0 +1,144 @@ +package edu.cmu.tetrad.algcomparison.simulation; + +import edu.cmu.tetrad.data.DataReader; +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.data.DataType; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.util.Parameters; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * @author jdramsey + */ +public class LoadDataAndGraphs implements Simulation { + static final long serialVersionUID = 23L; + private String path; + private List graphs = new ArrayList<>(); + private List dataSets = new ArrayList<>(); + private List usedParameters = new ArrayList<>(); + + public LoadDataAndGraphs(String path) { + this.path = path; + } + + @Override + public void createData(Parameters parameters) { + this.dataSets = new ArrayList<>(); + + if (new File(path + "/data").exists()) { + int numDataSets = new File(path + "/data").listFiles().length; + + try { + for (int i = 0; i < numDataSets; i++) { + File file2 = new File(path + "/graph/graph." + (i + 1) + ".txt"); + System.out.println("Loading graph from " + file2.getAbsolutePath()); + this.graphs.add(GraphUtils.loadGraphTxt(file2)); + + GraphUtils.circleLayout(this.graphs.get(i), 225, 200, 150); + + File file1 = new File(path + "/data/data." + (i + 1) + ".txt"); + + System.out.println("Loading data from " + file1.getAbsolutePath()); + DataReader reader = new DataReader(); + reader.setVariablesSupplied(true); + reader.setMaxIntegralDiscrete(parameters.getInt("maxDistinctValuesDiscrete")); + dataSets.add(reader.parseTabular(file1)); + } + + File file = new File(path, "parameters.txt"); + BufferedReader r = new BufferedReader(new FileReader(file)); + + String line; + + while ((line = r.readLine()) != null) { + if (line.contains(" = ")) { + String[] tokens = line.split(" = "); + String key = tokens[0]; + String value = tokens[1]; + + try { + double _value = Double.parseDouble(value); + usedParameters.add(key); + parameters.set(key, _value); + } catch (NumberFormatException e) { + usedParameters.add(key); + parameters.set(key, value); + } + } + } + + parameters.set("numRuns", numDataSets); + } catch (IOException e) { + e.printStackTrace(); + } + } + } + + @Override + public Graph getTrueGraph(int index) { + return graphs.get(index); + } + + @Override + public DataSet getDataSet(int index) { + return dataSets.get(index); + } + + @Override + public String getDescription() { + try { + File file = new File(path, "parameters.txt"); + BufferedReader r = new BufferedReader(new FileReader(file)); + + StringBuilder b = new StringBuilder(); + b.append("Load data sets and graphs from a directory.").append("\n\n"); + String line; + + while ((line = r.readLine()) != null) { + if (line.trim().isEmpty()) continue; + b.append(line).append("\n"); + } + + return b.toString(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public List getParameters() { + return usedParameters; + } + + @Override + public int getNumDataSets() { + return dataSets.size(); + } + + @Override + public DataType getDataType() { + boolean continuous = false; + boolean discrete = false; + boolean mixed = false; + + for (DataSet dataSet : dataSets) { + if (dataSet.isContinuous()) continuous = true; + if (dataSet.isDiscrete()) discrete = true; + if (dataSet.isMixed()) mixed = true; + } + + if (mixed) return DataType.Mixed; + else if (continuous && discrete) return DataType.Mixed; + else if (continuous) return DataType.Continuous; + else if (discrete) return DataType.Discrete; + + return DataType.Mixed; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/TimeSeriesSemSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/TimeSeriesSemSimulation.java index d5a794ebdc..75d5f19c38 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/TimeSeriesSemSimulation.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/TimeSeriesSemSimulation.java @@ -41,14 +41,14 @@ public void createData(Parameters parameters) { graphs = new ArrayList<>(); Graph graph = randomGraph.createGraph(parameters); - graph = TimeSeriesUtils.GraphToLagGraph(graph); + graph = TimeSeriesUtils.graphToLagGraph(graph); topToBottomLayout((TimeLagGraph) graph); this.knowledge = TimeSeriesUtils.getKnowledge(graph); for (int i = 0; i < parameters.getInt("numRuns"); i++) { if (parameters.getBoolean("differentGraphs") && i > 0) { graph = randomGraph.createGraph(parameters); - graph = TimeSeriesUtils.GraphToLagGraph(graph); + graph = TimeSeriesUtils.graphToLagGraph(graph); topToBottomLayout((TimeLagGraph) graph); } @@ -91,6 +91,7 @@ public void createData(Parameters parameters) { } //else System.out.println("Coefficient matrix is stable."); dataSet.setName("" + (i + 1)); + dataSet.setKnowledge(knowledge.copy()); dataSets.add(dataSet); } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/TwoCyclePrecision.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/TwoCyclePrecision.java new file mode 100644 index 0000000000..85a204d908 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/TwoCyclePrecision.java @@ -0,0 +1,41 @@ +package edu.cmu.tetrad.algcomparison.statistic; + +import edu.cmu.tetrad.algcomparison.statistic.utils.ArrowConfusion; +import edu.cmu.tetrad.graph.Graph; + +/** + * The 2-cycle precision. This counts 2-cycles manually, wherever they occur in the graphs. + * The true positives are the number of 2-cycles in both the true and estimated graphs. + * Thus, if the true does not contains X->Y,Y->X and estimated graph does contain it, + * one false positive is counted. + * + * @author jdramsey, rubens (November 2016) + */ +public class TwoCyclePrecision implements Statistic { + static final long serialVersionUID = 23L; + + @Override + public String getAbbreviation() { + return "2CP"; + } + + @Override + public String getDescription() { + return "2-cycle precision"; + } + + @Override + public double getValue(Graph trueGraph, Graph estGraph) { + ArrowConfusion adjConfusion = new ArrowConfusion(trueGraph, estGraph); + double TwoCycleTp = adjConfusion.getTwoCycleTp(); + double TwoCycleFp = adjConfusion.getTwoCycleFp(); + double den = TwoCycleTp + TwoCycleFp; + return TwoCycleTp / den; + + } + + @Override + public double getNormValue(double value) { + return value; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/TwoCycleRecall.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/TwoCycleRecall.java new file mode 100644 index 0000000000..be2a3e1704 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/TwoCycleRecall.java @@ -0,0 +1,41 @@ +package edu.cmu.tetrad.algcomparison.statistic; + +import edu.cmu.tetrad.algcomparison.statistic.utils.ArrowConfusion; +import edu.cmu.tetrad.graph.Graph; + +/** + * The 2-cycle recall. This counts 2-cycles manually, wherever they occur in the graphs. + * The true positives are the number of 2-cycles in both the true and estimated graphs. + * Thus, if the true contains X->Y,Y->X and estimated graph does not contain it, one false negative + * is counted. + * + * @author jdramsey, rubens (November 2016) + */ +public class TwoCycleRecall implements Statistic { + static final long serialVersionUID = 23L; + + @Override + public String getAbbreviation() { + return "2CR"; + } + + @Override + public String getDescription() { + return "2-cycle recall"; + } + + @Override + public double getValue(Graph trueGraph, Graph estGraph) { + ArrowConfusion adjConfusion = new ArrowConfusion(trueGraph, estGraph); + double TwoCycleTp = adjConfusion.getTwoCycleTp(); + double TwoCycleFn = adjConfusion.getTwoCycleFn(); + double den = TwoCycleTp + TwoCycleFn; + return TwoCycleTp / den; + + } + + @Override + public double getNormValue(double value) { + return value; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/ArrowConfusion.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/ArrowConfusion.java index a48995a10e..dea5698a48 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/ArrowConfusion.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/ArrowConfusion.java @@ -5,6 +5,7 @@ import edu.cmu.tetrad.graph.Graph; import java.util.HashSet; +import java.util.List; import java.util.Set; /** @@ -12,15 +13,19 @@ * A true positive arrow is counted for X*->Y in the estimated graph if X is not adjacent * to Y or X--Y or X<--Y. * - * @author jdramsey + * @author jdramsey, rubens (November, 2016) */ public class ArrowConfusion { + private Graph truth; private Graph est; private int arrowsTp; private int arrowsFp; private int arrowsFn; private int arrowsTn; + private int TCtp; + private int TCfn; + private int TCfp; public ArrowConfusion(Graph truth, Graph est) { this.truth = truth; @@ -28,23 +33,48 @@ public ArrowConfusion(Graph truth, Graph est) { arrowsTp = 0; arrowsFp = 0; arrowsFn = 0; + TCtp = 0; //for the two-cycle accuracy + TCfn = 0; + TCfp = 0; + + - Set allOriented = new HashSet<>(); - allOriented.addAll(this.truth.getEdges()); - allOriented.addAll(this.est.getEdges()); + // Get edges from the true Graph to compute TruePositives, TrueNegatives and FalseNeagtives + // System.out.println(this.truth.getEdges()); - for (Edge edge : allOriented) { - Edge edge1 = this.est.getEdge(edge.getNode1(), edge.getNode2()); + for (Edge edge : this.truth.getEdges()) { + + List edges1 = this.est.getEdges(edge.getNode1(), edge.getNode2()); + Edge edge1; + + if (edges1.size() == 1) { + edge1 = edges1.get(0); + } else { + edge1 = this.est.getDirectedEdge(edge.getNode1(), edge.getNode2()); + } + + // System.out.println(edge1 + "(est)"); Endpoint e1Est = null; Endpoint e2Est = null; if (edge1 != null) { - e1Est = edge.getProximalEndpoint(edge.getNode1()); - e2Est = edge.getProximalEndpoint(edge.getNode2()); + e1Est = edge1.getProximalEndpoint(edge.getNode1()); + e2Est = edge1.getProximalEndpoint(edge.getNode2()); } + // System.out.println(e1Est); + // System.out.println(e2Est); + + List edges2 = this.truth.getEdges(edge.getNode1(), edge.getNode2()); + Edge edge2; - Edge edge2 = this.truth.getEdge(edge.getNode1(), edge.getNode2()); + if (edges2.size() == 1) { + edge2 = edges2.get(0); + } else { + edge2 = this.truth.getDirectedEdge(edge.getNode1(), edge.getNode2()); + } + + // System.out.println(edge2 + "(truth)"); Endpoint e1True = null; Endpoint e2True = null; @@ -53,21 +83,9 @@ public ArrowConfusion(Graph truth, Graph est) { e1True = edge2.getProximalEndpoint(edge.getNode1()); e2True = edge2.getProximalEndpoint(edge.getNode2()); } + // System.out.println(e1True); + // System.out.println(e2True); - edge = this.est.getEdge(edge.getNode1(), edge.getNode2()); - - if (edge != null) { - e1Est = edge.getProximalEndpoint(edge.getNode1()); - e2Est = edge.getProximalEndpoint(edge.getNode2()); - } - - if (e1Est == Endpoint.ARROW && e1True != Endpoint.ARROW) { - arrowsFp++; - } - - if (e2Est == Endpoint.ARROW && e2True != Endpoint.ARROW) { - arrowsFp++; - } if (e1True == Endpoint.ARROW && e1Est != Endpoint.ARROW) { arrowsFn++; @@ -92,12 +110,118 @@ public ArrowConfusion(Graph truth, Graph est) { if (e2True != Endpoint.ARROW && e2Est != Endpoint.ARROW) { arrowsTn++; } + + } +// Get edges from the estimated graph to compute only FalsePositives + // System.out.println(this.est.getEdges()); + + for (Edge edge : this.est.getEdges()) { + + List edges1 = this.est.getEdges(edge.getNode1(), edge.getNode2()); + Edge edge1; + + if (edges1.size() == 1) { + edge1 = edges1.get(0); + } else { + edge1 = this.est.getDirectedEdge(edge.getNode1(), edge.getNode2()); + } + // System.out.println(edge1 + "(est)"); + + Endpoint e1Est = null; + Endpoint e2Est = null; + + if (edge1 != null) { + e1Est = edge1.getProximalEndpoint(edge.getNode1()); + e2Est = edge1.getProximalEndpoint(edge.getNode2()); + } + // System.out.println(e1Est); + // System.out.println(e2Est); + + List edges2 = this.truth.getEdges(edge.getNode1(), edge.getNode2()); + Edge edge2; + + if (edges2.size() == 1) { + edge2 = edges2.get(0); + } else { + edge2 = this.truth.getDirectedEdge(edge.getNode1(), edge.getNode2()); + } + + // System.out.println(edge2 + "(truth)"); + + Endpoint e1True = null; + Endpoint e2True = null; + + if (edge2 != null) { + e1True = edge2.getProximalEndpoint(edge.getNode1()); + e2True = edge2.getProximalEndpoint(edge.getNode2()); + } + // System.out.println(e1True); + // System.out.println(e2True); + + + if (e1Est == Endpoint.ARROW && e1True != Endpoint.ARROW) { + arrowsFp++; + } + + if (e2Est == Endpoint.ARROW && e2True != Endpoint.ARROW) { + arrowsFp++; + } + + + } + + + // test for 2-cycle + //Set allOriented = new HashSet<>(); + //allOriented.addAll(this.truth.getEdges()); + //allOriented.addAll(this.est.getEdges()); + + for (Edge edge : this.truth.getEdges()) { + + + + List TwoCycle1 = this.truth.getEdges(edge.getNode1(), edge.getNode2()); + List TwoCycle2 = this.est.getEdges(edge.getNode1(), edge.getNode2()); + + if (TwoCycle1.size() == 2 && TwoCycle2.size() == 2) { + // System.out.println("2-cycle correctly inferred " + TwoCycle1); + TCtp++; + } + + if (TwoCycle1.size() == 2 && TwoCycle2.size() != 2) { + // System.out.println("2-cycle not inferred " + TwoCycle1); + TCfn++; + } + } + + for (Edge edge : this.est.getEdges()) { + + List TwoCycle1 = this.truth.getEdges(edge.getNode1(), edge.getNode2()); + List TwoCycle2 = this.est.getEdges(edge.getNode1(), edge.getNode2()); + + if (TwoCycle1.size() != 2 && TwoCycle2.size() == 2) { + // System.out.println("2-cycle falsely inferred" + TwoCycle2); + TCfp++; + } + } + + /* System.out.println(arrowsTp); + System.out.println(arrowsTn); + System.out.println(arrowsFn); + System.out.println(arrowsFp); +*/ + //divide by 2, the 2cycle accuracy is duplicated due to how getEdges is used + TCtp = TCtp / 2; + TCfn = TCfn / 2; + TCfp = TCfp / 2; + // System.out.println(TCtp); + // System.out.println(TCfn); + // System.out.println(TCfp); -// int allEdges = this.truth.getNumNodes() * (this.truth.getNumNodes() - 1) / 2; -// arrowsTn = allEdges - arrowsFn; } + public int getArrowsTp() { return arrowsTp; } @@ -114,4 +238,17 @@ public int getArrowsTn() { return arrowsTn; } + public int getTwoCycleTp() { + return TCtp; + } + + public int getTwoCycleFp() { + return TCfp; + } + + public int getTwoCycleFn() { + return TCfn; + } + + } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/cmd/TetradCmd.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/cmd/TetradCmd.java index 27d8fd400d..668f6bbacd 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/cmd/TetradCmd.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/cmd/TetradCmd.java @@ -505,8 +505,8 @@ private void runAlgorithm() { runCfci(); } else if ("ccd".equalsIgnoreCase(algorithmName)) { runCcd(); - } else if ("fgs".equalsIgnoreCase(algorithmName)) { - runFgs(); + } else if ("fges".equalsIgnoreCase(algorithmName)) { + runFges(); } else if ("bayes_est".equalsIgnoreCase(algorithmName)) { runBayesEst(); } else if ("fofc".equalsIgnoreCase(algorithmName)) { @@ -638,13 +638,13 @@ private void runPcStable() { writeGraph(resultGraph); } - private void runFgs() { + private void runFges() { if (this.data == null && this.covarianceMatrix == null) { throw new IllegalStateException("Data did not load correctly."); } if (verbose) { - systemPrint("FGS"); + systemPrint("FGES"); systemPrint(getKnowledge().toString()); systemPrint(getVariables().toString()); @@ -657,12 +657,12 @@ private void runFgs() { TetradLogger.getInstance().log("info", "Testing it."); } - Fgs fgs; + Fges fges; if (useCovariance) { - SemBicScore fgsScore = new SemBicScore(covarianceMatrix); - fgsScore.setPenaltyDiscount(penaltyDiscount); - fgs = new Fgs(fgsScore); + SemBicScore fgesScore = new SemBicScore(covarianceMatrix); + fgesScore.setPenaltyDiscount(penaltyDiscount); + fges = new Fges(fgesScore); } else { if (data.isDiscrete()) { @@ -670,24 +670,24 @@ private void runFgs() { score.setSamplePrior(samplePrior); score.setStructurePrior(structurePrior); - fgs = new Fgs(score); + fges = new Fges(score); } else if (data.isContinuous()) { SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(data)); score.setPenaltyDiscount(penaltyDiscount); - fgs = new Fgs(score); + fges = new Fges(score); } else { throw new IllegalArgumentException(); } } if (initialGraph != null) { - fgs.setInitialGraph(initialGraph); + fges.setInitialGraph(initialGraph); } - fgs.setKnowledge(getKnowledge()); + fges.setKnowledge(getKnowledge()); // Convert back to Graph.. - Graph resultGraph = fgs.search(); + Graph resultGraph = fges.search(); // PrintUtil outputStreamPath problem and graphs. outPrint("\nResult graph:"); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/CovarianceMatrixOnTheFly.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/CovarianceMatrixOnTheFly.java index 3efc50145c..72457379e9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/CovarianceMatrixOnTheFly.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/CovarianceMatrixOnTheFly.java @@ -467,7 +467,7 @@ public final void setSampleSize(int sampleSize) { * @return the size of the square matrix. */ public final int getSize() { - return matrix.rows(); + return getVariables().size(); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataReader.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataReader.java index 414f2e7151..1e246bfd7d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataReader.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataReader.java @@ -121,6 +121,7 @@ public DataReader() { /** * Lines beginning with blanks or this marker will be skipped. */ + @Override public void setCommentMarker(String commentMarker) { if (commentMarker == null) { throw new NullPointerException("Cannot be null."); @@ -132,6 +133,7 @@ public void setCommentMarker(String commentMarker) { /** * This is the delimiter used to parse the data. Default is whitespace. */ + @Override public void setDelimiter(DelimiterType delimiterType) { if (delimiterType == null) { throw new NullPointerException("Cannot be null."); @@ -143,6 +145,7 @@ public void setDelimiter(DelimiterType delimiterType) { /** * Text between matched ones of these will treated as quoted text. */ + @Override public void setQuoteChar(char quoteChar) { this.quoteChar = quoteChar; } @@ -151,6 +154,7 @@ public void setQuoteChar(char quoteChar) { * Will read variable names from the first row if this is true; otherwise, * will make make up variables in the series X1, x2, ... Xn. */ + @Override public void setVariablesSupplied(boolean varNamesSupplied) { this.varNamesSupplied = varNamesSupplied; } @@ -158,6 +162,7 @@ public void setVariablesSupplied(boolean varNamesSupplied) { /** * If true, a column of ID's is supplied; otherwise, not. */ + @Override public void setIdsSupplied(boolean caseIdsPresent) { this.idsSupplied = caseIdsPresent; } @@ -166,6 +171,7 @@ public void setIdsSupplied(boolean caseIdsPresent) { * If null, ID's are in an unlabeled first column; otherwise, they are in * the column with the given label. */ + @Override public void setIdLabel(String caseIdsLabel) { this.idLabel = caseIdsLabel; } @@ -174,6 +180,7 @@ public void setIdLabel(String caseIdsLabel) { * Tokens that are blank or equal to this value will be counted as missing * values. */ + @Override public void setMissingValueMarker(String missingValueMarker) { if (missingValueMarker == null) { throw new NullPointerException("Cannot be null."); @@ -186,6 +193,7 @@ public void setMissingValueMarker(String missingValueMarker) { * Integral columns with up to this number of discrete values will be * treated as discrete. */ + @Override public void setMaxIntegralDiscrete(int maxIntegralDiscrete) { if (maxIntegralDiscrete < -1) { throw new IllegalArgumentException( @@ -199,6 +207,7 @@ public void setMaxIntegralDiscrete(int maxIntegralDiscrete) { * The known variables for a given name will usurp guess the variable by * that name. */ + @Override public void setKnownVariables(List knownVariables) { if (knownVariables == null) { throw new NullPointerException(); @@ -214,6 +223,7 @@ public void setKnownVariables(List knownVariables) { * @throws IOException if the file cannot be read. // * @deprecated use the data readers from edu.cmu.tetrad.io package // Can't deprecate this yet. */ + @Override public DataSet parseTabular(File file) throws IOException { FileReader reader = null, reader2 = null; @@ -256,6 +266,7 @@ public DataSet parseTabular(File file) throws IOException { * RectangularDataSet if successful. Log messages are written to the * LogUtils log; to view them, add System.out to that. */ + @Override public DataSet parseTabular(char[] chars) { // Do first pass to get a description of the file. @@ -570,6 +581,7 @@ private DataSet doSecondTabularPass(DataSetDescription description, Reader reade * * @throws IOException if the file cannot be read. */ + @Override public ICovarianceMatrix parseCovariance(File file) throws IOException { FileReader reader = null; @@ -607,6 +619,7 @@ public ICovarianceMatrix parseCovariance(File file) throws IOException { * new FileReader(file), " \t", "//"); * The initial "/covariance" is optional. */ + @Override public ICovarianceMatrix parseCovariance(char[] chars) { // Do first pass to get a description of the file. @@ -745,6 +758,7 @@ private ICovarianceMatrix doCovariancePass(Reader reader) { * Loads knowledge from a file. Assumes knowledge is the only thing in the * file. No jokes please. :) */ + @Override public IKnowledge parseKnowledge(File file) throws IOException { FileReader reader = new FileReader(file); Lineizer lineizer = new Lineizer(reader, commentMarker); @@ -757,6 +771,7 @@ public IKnowledge parseKnowledge(File file) throws IOException { * Parses knowledge from the char array, assuming that's all there is in the * char array. */ + @Override public IKnowledge parseKnowledge(char[] chars) { CharArrayReader reader = new CharArrayReader(chars); Lineizer lineizer = new Lineizer(reader, commentMarker); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataUtils.java index c0463cb6e1..36d0ae9e8f 100755 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/DataUtils.java @@ -22,7 +22,6 @@ package edu.cmu.tetrad.data; import cern.colt.list.DoubleArrayList; -import com.sun.nio.sctp.IllegalReceiveException; import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.util.*; import org.apache.commons.math3.distribution.NormalDistribution; @@ -1104,6 +1103,8 @@ public static TetradMatrix covDemeaned(TetradMatrix data) { } public static TetradMatrix cov(TetradMatrix data) { + + for (int j = 0; j < data.columns(); j++) { double sum = 0.0; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Dag.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Dag.java index 175f8e2d37..645c925371 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Dag.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Dag.java @@ -62,6 +62,9 @@ public final class Dag implements Graph { private Map nodesHash = new HashMap<>(); + private boolean pag; + private boolean pattern; + //===============================CONSTRUCTORS=======================// /** @@ -693,6 +696,26 @@ public List getTriplesClassificationTypes() { public List> getTriplesLists(Node node) { return null; } + + @Override + public boolean isPag() { + return pag; + } + + @Override + public void setPag(boolean pag) { + this.pag = pag; + } + + @Override + public boolean isPattern() { + return pattern; + } + + @Override + public void setPattern(boolean pattern) { + this.pattern = pattern; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java index 3c8f137e20..30ed064e28 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Edge.java @@ -23,6 +23,7 @@ import edu.cmu.tetrad.util.TetradSerializable; +import java.awt.*; import java.io.IOException; import java.io.ObjectInputStream; @@ -37,27 +38,16 @@ */ public class Edge implements TetradSerializable, Comparable { static final long serialVersionUID = 23L; - - /** - * @serial - */ private Node node1; - - /** - * @serial - */ private Node node2; - - /** - * @serial - */ private Endpoint endpoint1; - - /** - * @serial - */ private Endpoint endpoint2; + // Usual coloring--set to something else for a special line color. + private transient Color lineColor = null; + + private boolean dashed = false; + //=========================CONSTRUCTORS============================// /** @@ -95,6 +85,7 @@ public Edge(Node node1, Node node2, Endpoint endpoint1, public Edge(Edge edge) { this(edge.node1, edge.node2, edge.endpoint1, edge.endpoint2); + this.lineColor = edge.getLineColor(); } /** @@ -366,6 +357,24 @@ private void readObject(ObjectInputStream s) public boolean isNull() { return endpoint1 == Endpoint.NULL && endpoint2 == Endpoint.NULL; } + + public Color getLineColor() { + return this.lineColor; + } + + public void setLineColor(Color lineColor) { + if (lineColor != null) { + this.lineColor = lineColor; + } + } + + public boolean isDashed() { + return dashed; + } + + public void setDashed(boolean dashed) { + this.dashed = dashed; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java index d1e02bed98..4dead4bf4f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.io.ObjectInputStream; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import static edu.cmu.tetrad.graph.Edges.directedEdge; @@ -57,14 +58,14 @@ public class EdgeListGraph implements Graph, TripleClassifier { * * @serial */ - protected Set edgesSet; + Set edgesSet; /** * Map from each node to the List of edges connected to that node. * * @serial */ - protected Map> edgeLists; + Map> edgeLists; /** * Fires property change events. @@ -75,34 +76,38 @@ public class EdgeListGraph implements Graph, TripleClassifier { * Set of ambiguous triples. Note the name can't be changed due to * serialization. */ - protected Set ambiguousTriples = new HashSet<>(); + protected Set ambiguousTriples = Collections.newSetFromMap(new ConcurrentHashMap()); /** * @serial */ - protected Set underLineTriples = new HashSet<>(); + Set underLineTriples = Collections.newSetFromMap(new ConcurrentHashMap()); /** * @serial */ - protected Set dottedUnderLineTriples = new HashSet<>(); + Set dottedUnderLineTriples = Collections.newSetFromMap(new ConcurrentHashMap()); /** * True iff nodes were removed since the last call to an accessor for ambiguous, underline, or dotted underline * triples. If there are triples in the lists involving removed nodes, these need to be removed from the lists * first, so as not to cause confusion. */ - protected boolean stuffRemovedSinceLastTripleAccess = false; + boolean stuffRemovedSinceLastTripleAccess = false; /** * The set of highlighted edges. */ - protected Set highlightedEdges = new HashSet<>(); + Set highlightedEdges = new HashSet<>(); /** * A hash from node names to nodes; */ - protected Map namesHash = new HashMap<>(); + Map namesHash = new HashMap<>(); + + private boolean pattern = false; + + private boolean pag = false; //==============================CONSTUCTORS===========================// @@ -142,7 +147,6 @@ public EdgeListGraph(Graph graph) throws IllegalArgumentException { this.underLineTriples = graph.getUnderLines(); this.dottedUnderLineTriples = graph.getDottedUnderlines(); - for (Edge edge : graph.getEdges()) { if (graph.isHighlighted(edge)) { setHighlighted(edge, true); @@ -152,6 +156,9 @@ public EdgeListGraph(Graph graph) throws IllegalArgumentException { for (Node node : nodes) { namesHash.put(node.getName(), node); } + + this.pag = graph.isPag(); + this.pattern = graph.isPattern(); } /** @@ -204,6 +211,8 @@ public static Graph shallowCopy(EdgeListGraph graph) { _graph.stuffRemovedSinceLastTripleAccess = graph.stuffRemovedSinceLastTripleAccess; _graph.highlightedEdges = new HashSet<>(graph.highlightedEdges); _graph.namesHash = new HashMap<>(graph.namesHash); + _graph.pag = graph.pag; + _graph.pattern = graph.pattern; return _graph; } @@ -720,11 +729,37 @@ public boolean isDSeparatedFrom(List x, List y, List z) { return !isDConnectedTo(x, y, z); } - protected static class Pair { + /** + * True if this graph has been stamped as a pattern. The search algorithm should do this. + */ + @Override + public boolean isPattern() { + return pattern; + } + + @Override + public void setPattern(boolean pattern) { + this.pattern = pattern; + } + + /** + * True if this graph has been "stamped" as a PAG. The search algorithm should do this. + */ + @Override + public boolean isPag() { + return pag; + } + + @Override + public void setPag(boolean pag) { + this.pag = pag; + } + + private static class Pair { private Node x; private Node y; - public Pair(Node x, Node y) { + Pair(Node x, Node y) { this.x = x; this.y = y; } @@ -959,14 +994,17 @@ public boolean isExogenous(Node node) { */ public List getAdjacentNodes(Node node) { List edges = edgeLists.get(node); - List adj = new ArrayList<>(edges.size()); + Set adj = new HashSet<>(edges.size()); for (Edge edge : edges) { if (edge == null) continue; - adj.add(edge.getDistalNode(node)); + Node z = edge.getDistalNode(node); + if (!adj.contains(z)) { + adj.add(z); + } } - return adj; + return new ArrayList<>(adj); } /** @@ -1597,7 +1635,8 @@ public void addUnderlineTriple(Node x, Node y, Node z) { Triple triple = new Triple(x, y, z); if (!triple.alongPathIn(this)) { - throw new IllegalArgumentException("<" + x + ", " + y + ", " + z + "> must lie along a path in the graph."); + return; +// throw new IllegalArgumentException("<" + x + ", " + y + ", " + z + "> must lie along a path in the graph."); } underLineTriples.add(new Triple(x, y, z)); @@ -1607,7 +1646,8 @@ public void addDottedUnderlineTriple(Node x, Node y, Node z) { Triple triple = new Triple(x, y, z); if (!triple.alongPathIn(this)) { - throw new IllegalArgumentException("<" + x + ", " + y + ", " + z + "> must lie along a path in the graph."); + return; +// throw new IllegalArgumentException("<" + x + ", " + y + ", " + z + "> must lie along a path in the graph."); } dottedUnderLineTriples.add(triple); @@ -1704,7 +1744,7 @@ public void removeTriplesNotInGraph() { } - protected void collectAncestorsVisit(Node node, Set ancestors) { + private void collectAncestorsVisit(Node node, Set ancestors) { if (ancestors.contains(node)) return; ancestors.add(node); @@ -1717,7 +1757,7 @@ protected void collectAncestorsVisit(Node node, Set ancestors) { } } - protected void collectDescendantsVisit(Node node, Set descendants) { + private void collectDescendantsVisit(Node node, Set descendants) { descendants.add(node); List children = getChildren(node); @@ -1732,7 +1772,7 @@ protected void collectDescendantsVisit(Node node, Set descendants) { /** * closure under the child relation */ - protected void doChildClosureVisit(Node node, Set closure) { + private void doChildClosureVisit(Node node, Set closure) { if (!closure.contains(node)) { closure.add(node); @@ -1757,7 +1797,7 @@ protected void doChildClosureVisit(Node node, Set closure) { * @param closure the closure of the conditioning set uner the parent * relation (to be calculated recursively). */ - protected void doParentClosureVisit(Node node, Set closure) { + private void doParentClosureVisit(Node node, Set closure) { if (closure.contains(node)) return; closure.add(node); @@ -1782,7 +1822,7 @@ protected PropertyChangeSupport getPcs() { /** * @return true iff there is a directed path from node1 to node2. */ - protected boolean existsUndirectedPathVisit(Node node1, Node node2, Set path) { + boolean existsUndirectedPathVisit(Node node1, Node node2, Set path) { path.add(node1); for (Edge edge : getEdges(node1)) { @@ -1809,7 +1849,7 @@ protected boolean existsUndirectedPathVisit(Node node1, Node node2, Set pa return false; } - protected boolean existsDirectedPathVisit(Node node1, Node node2, Set path) { + boolean existsDirectedPathVisit(Node node1, Node node2, Set path) { path.add(node1); for (Edge edge : getEdges(node1)) { @@ -1839,7 +1879,7 @@ protected boolean existsDirectedPathVisit(Node node1, Node node2, Set path /** * @return true iff there is a semi-directed path from node1 to node2 */ - protected boolean existsSemiDirectedPathVisit(Node node1, Set nodes2, + private boolean existsSemiDirectedPathVisit(Node node1, Set nodes2, LinkedList path) { path.addLast(node1); @@ -1951,6 +1991,7 @@ public List getTriplesClassificationTypes() { List names = new ArrayList<>(); names.add("Underlines"); names.add("Dotted Underlines"); + names.add("Ambiguous Triples"); return names; } @@ -1962,6 +2003,7 @@ public List> getTriplesLists(Node node) { List> triplesList = new ArrayList<>(); triplesList.add(GraphUtils.getUnderlinedTriplesFromGraph(node, this)); triplesList.add(GraphUtils.getDottedUnderlinedTriplesFromGraph(node, this)); + triplesList.add(GraphUtils.getAmbiguousTriplesFromGraph(node, this)); return triplesList; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EndpointMatrixGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EndpointMatrixGraph.java index c8d4df584b..dc2bfb0fe0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EndpointMatrixGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EndpointMatrixGraph.java @@ -88,6 +88,9 @@ public class EndpointMatrixGraph implements Graph { private HashMap endpointsToShorts; private int numEdges = 0; + private boolean pag; + private boolean pattern; + //==============================CONSTUCTORS===========================// /** @@ -716,6 +719,26 @@ public List> getTriplesLists(Node node) { return null; } + @Override + public boolean isPag() { + return pag; + } + + @Override + public void setPag(boolean pag) { + this.pag = pag; + } + + @Override + public boolean isPattern() { + return pattern; + } + + @Override + public void setPattern(boolean pattern) { + this.pattern = pattern; + } + private static class Pair { private Node x; private Node y; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/FruchtermanReingoldLayout.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/FruchtermanReingoldLayout.java index d0bd1822f2..af5110cd6e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/FruchtermanReingoldLayout.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/FruchtermanReingoldLayout.java @@ -57,7 +57,7 @@ public final class FruchtermanReingoldLayout { /** * Optimal distance between vertices. */ - private double optimalDistance; + private double optimalDistance = 100; /** * Temperature. @@ -82,6 +82,8 @@ public FruchtermanReingoldLayout(Graph graph) { //============================PUBLIC METHODS==========================// public void doLayout() { + GraphUtils.circleLayout(graph, 300, 300, 200); + List> components = GraphUtils.connectedComponents(this.graph()); @@ -112,7 +114,7 @@ private void layoutComponent(List nodes) { //pos[i][1] = RandomUtil.nextInt(600); } - List edges = new ArrayList<>(graph().getEdges()); + List edges = new ArrayList<>(GraphUtils.undirectedGraph(graph()).getEdges()); for (Iterator i = edges.iterator(); i.hasNext(); ) { Edge edge = i.next(); @@ -132,15 +134,17 @@ private void layoutComponent(List nodes) { this.edges()[i][1] = u; } - setOptimalDistance(60.0); + double avgDegree = 2 * graph.getNumEdges() / graph.getNumNodes(); + + setOptimalDistance(20.0 + 20.0 * avgDegree); setTemperature(5.0); for (int i = 0; i < numIterations(); i++) { // Calculate repulsive forces. for (int v = 0; v < numNodes; v++) { - nodeDisposition()[v][0] = 0.; - nodeDisposition()[v][1] = 0.; + nodeDisposition()[v][0] = 0.1; + nodeDisposition()[v][1] = 0.1; for (int u = 0; u < numNodes; u++) { double deltaX = nodePosition()[u][0] - nodePosition()[v][0]; @@ -149,12 +153,13 @@ private void layoutComponent(List nodes) { double norm = norm(deltaX, deltaY); if (norm == 0.0) { - continue; - } - - if (norm > 4.0 * optimalDistance()) { - continue; + norm = 0.1; +// continue; } +// +// if (norm > 4.0 * getOptimalDistance()) { +// continue; +// } double repulsiveForce = fr(norm); @@ -174,12 +179,13 @@ private void layoutComponent(List nodes) { double norm = norm(deltaX, deltaY); if (norm == 0.0) { - continue; + norm = 0.1; +// continue; } - if (norm < 1.5 * optimalDistance()) { - continue; - } +// if (norm < 1.5 * getOptimalDistance()) { +// continue; +// } double attractiveForce = fa(norm); double attractX = (deltaX / norm) * attractiveForce; @@ -203,17 +209,16 @@ private void layoutComponent(List nodes) { } for (int v = 0; v < numNodes; v++) { - double norm = - norm(nodeDisposition()[v][0], nodeDisposition()[v][1]); + double norm = norm(nodeDisposition()[v][0], nodeDisposition()[v][1]); - if (norm == 0.0) { - continue; - } +// if (norm == 0.0) { +// continue; +// } nodePosition()[v][0] += (nodeDisposition()[v][0] / norm) * - Math.min(norm, temperature()); + Math.min(norm, getTemperature()); nodePosition()[v][1] += (nodeDisposition()[v][1] / norm) * - Math.min(norm, temperature()); + Math.min(norm, getTemperature()); if (Double.isNaN(nodePosition()[v][0]) || Double.isNaN(nodePosition()[v][1])) { @@ -260,11 +265,11 @@ private void shiftComponentToRight(List componentNodes) { //============================PRIVATE METHODS=========================// \ private double fa(double d) { - return (d * d) / optimalDistance(); + return (d * d) / getOptimalDistance(); } private double fr(double d) { - return -(optimalDistance() * optimalDistance()) / d; + return -(getOptimalDistance() * getOptimalDistance()) / d; } private double norm(double x, double y) { @@ -287,20 +292,8 @@ private double[][] nodeDisposition() { return nodeDisposition; } - private double optimalDistance() { - return getOptimalDistance(); - } - private int numIterations() { - /* - The number of iterations. - */ - int numIterations = 6000; - return numIterations; - } - - private double temperature() { - return getTemperature(); + return 500; } private double leftmostX() { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Graph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Graph.java index decb5ac774..d35e3f5312 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Graph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Graph.java @@ -372,6 +372,14 @@ public interface Graph extends TetradSerializable, TripleClassifier { */ boolean isDConnectedTo(Node node1, Node node2, List z); + boolean isPattern(); + + void setPattern(boolean pattern); + + boolean isPag(); + + void setPag(boolean pag); + /** * Determines whether one node is d-separated from another. Two elements are E * d-separated just in case they are not d-connected. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java index 47450c3c11..f8dc82e072 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java @@ -28,6 +28,7 @@ import edu.cmu.tetrad.util.TaskManager; import edu.cmu.tetrad.util.TextTable; +import java.awt.*; import java.io.BufferedReader; import java.io.ByteArrayOutputStream; import java.io.CharArrayReader; @@ -1111,7 +1112,7 @@ public static Node getAssociatedNode(Node errorNode, Graph graph) { * @return true if set is a clique in graph.

* R. Silva, June 2004 */ - public static boolean isClique(Set set, Graph graph) { + public static boolean isClique(Collection set, Graph graph) { List setv = new LinkedList<>(set); for (int i = 0; i < setv.size() - 1; i++) { for (int j = i + 1; j < setv.size(); j++) { @@ -1764,6 +1765,12 @@ public static String pathString(List path, Graph graph) { return pathString(graph, path, new LinkedList()); } + public static String pathString(Graph graph, Node...x) { + List path = new ArrayList<>(); + Collections.addAll(path, x); + return pathString(graph, path, new LinkedList()); + } + private static String pathString(Graph graph, List path, List conditioningVars) { StringBuilder buf = new StringBuilder(); @@ -3417,6 +3424,74 @@ public static int[][] edgeMisclassificationCounts1(Graph leftGraph, Graph topGra return counts; } + public static void addPagColoring(Graph graph) { + for (Edge edge : graph.getEdges()) { + if (!Edges.isDirectedEdge(edge)) { + continue; + } + + Node x = Edges.getDirectedEdgeTail(edge); + Node y = Edges.getDirectedEdgeHead(edge); + + graph.removeEdge(edge); + final boolean dashed = existsSemiDirectedPath(x, y, -1, graph); + graph.addEdge(edge); + + if (dashed) { + edge.setDashed(dashed); + } + + if (graph.defVisible(edge)) { + edge.setLineColor(Color.green); + } + } + } + + // Returns true if a path consisting of undirected and directed edges toward 'to' exists of + // length at most 'bound'. Cycle checker in other words. + public static boolean existsSemiDirectedPath(Node from, Node to, int bound, Graph graph) { + Queue Q = new LinkedList<>(); + Set V = new HashSet<>(); + Q.offer(from); + V.add(from); + Node e = null; + int distance = 0; + + while (!Q.isEmpty()) { + Node t = Q.remove(); + if (t == to) { + return true; + } + + if (e == t) { + e = null; + distance++; + if (distance > (bound == -1 ? 1000 : bound)) return false; + } + + for (Node u : graph.getAdjacentNodes(t)) { + Edge edge = graph.getEdge(t, u); + Node c = GraphUtils.traverseSemiDirected(t, edge); + if (c == null) continue; + + if (c == to) { + return true; + } + + if (!V.contains(c)) { + V.add(c); + Q.offer(c); + + if (e == null) { + e = u; + } + } + } + } + + return false; + } + private static class Counts { private int[][] counts; @@ -3574,7 +3649,9 @@ private static int getTypeTop(Edge edgeTop) { return 4; } - throw new IllegalArgumentException("Unsupported edgeTop type : " + edgeTop); + return 5; + +// throw new IllegalArgumentException("Unsupported edge type : " + edgeTop); } private static int getTypeLeft(Edge edgeLeft, Edge edgeTop) { @@ -4948,13 +5025,13 @@ public static List existsUnblockedSemiDirectedPath(Node from, Node to, Set } // Used to find semidirected paths for cycle checking. - private static Node traverseSemiDirected(Node node, Edge edge) { + public static Node traverseSemiDirected(Node node, Edge edge) { if (node == edge.getNode1()) { - if (edge.getEndpoint1() == Endpoint.TAIL) { + if (edge.getEndpoint1() == Endpoint.TAIL || edge.getEndpoint1() == Endpoint.CIRCLE) { return edge.getNode2(); } } else if (node == edge.getNode2()) { - if (edge.getEndpoint2() == Endpoint.TAIL) { + if (edge.getEndpoint2() == Endpoint.TAIL || edge.getEndpoint2() == Endpoint.CIRCLE) { return edge.getNode1(); } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/KamadaKawaiLayout.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/KamadaKawaiLayout.java index d4763af242..1b937b9b63 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/KamadaKawaiLayout.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/KamadaKawaiLayout.java @@ -105,12 +105,14 @@ public KamadaKawaiLayout(Graph graph) { throw new NullPointerException(); } - this.graph = graph; + this.graph = GraphUtils.undirectedGraph(graph); } //============================PUBLIC METHODS==========================// public void doLayout() { + GraphUtils.circleLayout(graph, 300, 300, 200); + this.monitor = new ProgressMonitor(null, "Energy settling...", "Energy = ?", 0, 100); getMonitor().setMillisToDecideToPopup(10); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LagGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LagGraph.java index 6531a1b0b6..b51bc19da5 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LagGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/LagGraph.java @@ -38,6 +38,8 @@ public class LagGraph implements Graph { private List variables = new ArrayList<>(); private int numLags = 0; private Map> laggedVariables = new HashMap<>(); + private boolean pag; + private boolean pattern; // New methods. public boolean addVariable(String variable) { @@ -472,6 +474,26 @@ public List getTriplesClassificationTypes() { public List> getTriplesLists(Node node) { return null; } + + @Override + public boolean isPag() { + return pag; + } + + @Override + public void setPag(boolean pag) { + this.pag = pag; + } + + @Override + public boolean isPattern() { + return pattern; + } + + @Override + public void setPattern(boolean pattern) { + this.pattern = pattern; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/MisclassificationUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/MisclassificationUtils.java index b940b0cc73..5516c43166 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/MisclassificationUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/MisclassificationUtils.java @@ -255,7 +255,9 @@ private static int getTypeTop(Edge edgeTop) { return 5; } - throw new IllegalArgumentException("Unsupported edgeTop type : " + e1 + " " + e2); + return 5; + +// throw new IllegalArgumentException("Unsupported edge type : " + e1 + " " + e2); } private static int getTypeLeft(Edge edgeLeft, Edge edgeTop) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/SemGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/SemGraph.java index 4abee27752..ce76cb09d8 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/SemGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/SemGraph.java @@ -80,6 +80,9 @@ public final class SemGraph implements Graph, TetradSerializable { */ private boolean showErrorTerms = false; + private boolean pag; + private boolean pattern; + //=========================CONSTRUCTORS============================// /** @@ -973,6 +976,26 @@ public List getTriplesClassificationTypes() { public List> getTriplesLists(Node node) { return null; } + + @Override + public boolean isPag() { + return pag; + } + + @Override + public void setPag(boolean pag) { + this.pag = pag; + } + + @Override + public boolean isPattern() { + return pattern; + } + + @Override + public void setPattern(boolean pattern) { + this.pattern = pattern; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/TimeLagGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/TimeLagGraph.java index ecf20ce619..e9b737ca7f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/TimeLagGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/TimeLagGraph.java @@ -49,6 +49,9 @@ public class TimeLagGraph implements Graph { private int numInitialLags = 1; private List lag0Nodes = new ArrayList<>(); + private boolean pag; + private boolean pattern; + public TimeLagGraph() { } @@ -376,6 +379,26 @@ public List> getTriplesLists(Node node) { return null; } + @Override + public boolean isPag() { + return pag; + } + + @Override + public void setPag(boolean pag) { + this.pag = pag; + } + + @Override + public boolean isPattern() { + return pattern; + } + + @Override + public void setPattern(boolean pattern) { + this.pattern = pattern; + } + public static class NodeId { private String name; private int lag; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/Comparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/Comparison.java index fb028bc153..3473847fbd 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/Comparison.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/Comparison.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.bayes.BayesPm; import edu.cmu.tetrad.bayes.MlBayesIm; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.graph.Node; @@ -180,34 +181,34 @@ else if (params.getScore() == ComparisonParameters.ScoreType.BDeu) { if (test == null) throw new IllegalArgumentException("Test not set."); Pc search = new Pc(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) { if (test == null) throw new IllegalArgumentException("Test not set."); Cpc search = new Cpc(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) { if (test == null) throw new IllegalArgumentException("Test not set."); PcLocal search = new PcLocal(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCMax) { if (test == null) throw new IllegalArgumentException("Test not set."); PcMax search = new PcMax(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); - } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGS) { + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); + } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) { if (score == null) throw new IllegalArgumentException("Score not set."); - Fgs search = new Fgs(score); + Fges search = new Fges(score); search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed()); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); - } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGS2) { + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); + } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES2) { if (score == null) throw new IllegalArgumentException("Score not set."); - Fgs search = new Fgs(score); + Fges search = new Fges(score); search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed()); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) { if (test == null) throw new IllegalArgumentException("Test not set."); Fci search = new Fci(test); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/Comparison2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/Comparison2.java index 59ef6f534e..9e9d5e9931 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/Comparison2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/Comparison2.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.bayes.BayesPm; import edu.cmu.tetrad.bayes.MlBayesIm; import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.graph.Node; @@ -121,7 +122,7 @@ public static ComparisonResult compare(ComparisonParameters params) { if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) { trueDag = GraphUtils.randomGraphRandomForwardEdges( nodes, 0, params.getNumEdges(), 10, 10, 10, false, true); - trueDag = TimeSeriesUtils.GraphToLagGraph(trueDag); + trueDag = TimeSeriesUtils.graphToLagGraph(trueDag); System.out.println("Creating Time Lag Graph : " + trueDag); } /***************************/ @@ -139,28 +140,28 @@ public static ComparisonResult compare(ComparisonParameters params) { if (test == null) throw new IllegalArgumentException("Test not set."); Pc search = new Pc(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) { if (test == null) throw new IllegalArgumentException("Test not set."); Cpc search = new Cpc(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) { if (test == null) throw new IllegalArgumentException("Test not set."); PcLocal search = new PcLocal(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCMax) { if (test == null) throw new IllegalArgumentException("Test not set."); PcMax search = new PcMax(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); - } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGS) { + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); + } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) { if (score == null) throw new IllegalArgumentException("Score not set."); - Fgs search = new Fgs(score); + Fges search = new Fges(score); //search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed()); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) { if (test == null) throw new IllegalArgumentException("Test not set."); Fci search = new Fci(test); @@ -228,7 +229,7 @@ public static ComparisonResult compare(ComparisonParameters params) { if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) { trueDag = GraphUtils.randomGraphRandomForwardEdges( nodes, 0, params.getNumEdges(), 10, 10, 10, false, true); - trueDag = TimeSeriesUtils.GraphToLagGraph(trueDag); + trueDag = TimeSeriesUtils.graphToLagGraph(trueDag); System.out.println("Creating Time Lag Graph : " + trueDag); } /***************************/ @@ -411,28 +412,28 @@ else if (params.getScore() == ComparisonParameters.ScoreType.BDeu) { if (test == null) throw new IllegalArgumentException("Test not set."); Pc search = new Pc(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) { if (test == null) throw new IllegalArgumentException("Test not set."); Cpc search = new Cpc(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) { if (test == null) throw new IllegalArgumentException("Test not set."); PcLocal search = new PcLocal(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCMax) { if (test == null) throw new IllegalArgumentException("Test not set."); PcMax search = new PcMax(test); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); - } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGS) { + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); + } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) { if (score == null) throw new IllegalArgumentException("Score not set."); - Fgs search = new Fgs(score); + Fges search = new Fges(score); //search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed()); result.setResultGraph(search.search()); - result.setCorrectResult(SearchGraphUtils.patternForDag(trueDag)); + result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag))); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) { if (test == null) throw new IllegalArgumentException("Test not set."); Fci search = new Fci(test); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ComparisonParameters.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ComparisonParameters.java index 612d1f804d..2856f9d30e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ComparisonParameters.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ComparisonParameters.java @@ -131,9 +131,9 @@ public void setAlgorithm(Algorithm algorithm) { resultType = ResultType.Pattern; } else if (algorithm == Algorithm.CPC) { resultType = ResultType.Pattern; - } else if (algorithm == Algorithm.FGS) { + } else if (algorithm == Algorithm.FGES) { resultType = ResultType.Pattern; - } else if (algorithm == Algorithm.FGS2) { + } else if (algorithm == Algorithm.FGES2) { resultType = ResultType.Pattern; } else if (algorithm == Algorithm.PCLocal) { resultType = ResultType.Pattern; @@ -323,5 +323,5 @@ public enum DataType {Continuous, Discrete} public enum ResultType {Pattern, PAG} public enum IndependenceTestType {FisherZ, ChiSquare} public enum ScoreType {SemBic, BDeu} - public enum Algorithm {PC, CPC, FGS, FGS2, PCLocal, PCMax, FCI, GFCI, TsFCI} + public enum Algorithm {PC, CPC, FGES, FGES2, PCLocal, PCMax, FCI, GFCI, TsFCI} } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ComparisonScript.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ComparisonScript.java index 12a8e8e48b..383d124743 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ComparisonScript.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ComparisonScript.java @@ -51,7 +51,7 @@ private void runFromSimulation() { /** add algorithm to compare to the list algList. comment out those you don't want to consider. **/ //algList.add(ComparisonParameters.Algorithm.PC); - //algList.add(ComparisonParameters.Algorithm.FGS); + //algList.add(ComparisonParameters.Algorithm.FGES); //algList.add(ComparisonParameters.Algorithm.FCI); algList.add(ComparisonParameters.Algorithm.TsFCI); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ExploreComparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ExploreComparison.java index 392190cb9d..791037217c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ExploreComparison.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/ExploreComparison.java @@ -13,7 +13,7 @@ public class ExploreComparison { private void runFromSimulation() { ComparisonParameters params = new ComparisonParameters(); params.setDataType(ComparisonParameters.DataType.Continuous); - params.setAlgorithm(ComparisonParameters.Algorithm.FGS2); + params.setAlgorithm(ComparisonParameters.Algorithm.FGES2); // params.setIndependenceTest(ComparisonParameters.IndependenceTestType.FisherZ); params.setScore(ComparisonParameters.ScoreType.SemBic); // params.setOneEdgeFaithfulnessAssumed(false); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/PerformanceTests.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/PerformanceTests.java index 628c58cdf5..3670f7235d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/PerformanceTests.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/PerformanceTests.java @@ -342,8 +342,8 @@ public void testPcMax(int numVars, double edgeFactor, int numCases, double alpha out.close(); } - public void testFgs(int numVars, double edgeFactor, int numCases, double penaltyDiscount) { - init(new File("long.fgs." + numVars + "." + edgeFactor + "." + penaltyDiscount + ".txt"), "Tests performance of the FGS algorithm"); + public void testFges(int numVars, double edgeFactor, int numCases, double penaltyDiscount) { + init(new File("long.fges." + numVars + "." + edgeFactor + "." + penaltyDiscount + ".txt"), "Tests performance of the FGES algorithm"); long time1 = System.currentTimeMillis(); @@ -378,7 +378,7 @@ public void testFgs(int numVars, double edgeFactor, int numCases, double penalty SemBicScore semBicScore = new SemBicScore(cov); semBicScore.setPenaltyDiscount(penaltyDiscount); - Fgs pcStable = new Fgs(semBicScore); + Fges pcStable = new Fges(semBicScore); Graph estPattern = pcStable.search(); @@ -393,9 +393,9 @@ public void testFgs(int numVars, double edgeFactor, int numCases, double penalty out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms"); out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms"); - out.println("Elapsed (running FGS) " + (time4 - time3) + " ms"); + out.println("Elapsed (running FGES) " + (time4 - time3) + " ms"); - out.println("Total elapsed (cov + FGS) " + (time4 - time2) + " ms"); + out.println("Total elapsed (cov + FGES) " + (time4 - time2) + " ms"); final Graph truePattern = SearchGraphUtils.patternForDag(dag); @@ -742,7 +742,7 @@ public void testGfci(int numVars, double edgeFactor) { fci.setVerbose(false); fci.setMaxPathLength(maxPathLength); - fci.setMaxIndegree(depth); + fci.setMaxDegree(depth); fci.setFaithfulnessAssumed(false); fci.setCompleteRuleSetUsed(true); Graph outGraph = fci.search(); @@ -764,15 +764,15 @@ public void testGfci(int numVars, double edgeFactor) { out.close(); } - public void testFgsComparisonContinuous(int numVars, double edgeFactor, int numCases, int numRuns) { - testFgs(numVars, edgeFactor, numCases, numRuns, true); + public void testFgesComparisonContinuous(int numVars, double edgeFactor, int numCases, int numRuns) { + testFges(numVars, edgeFactor, numCases, numRuns, true); } - public void testFgsComparisonDiscrete(int numVars, double edgeFactor, int numCases, int numRuns) { - testFgs(numVars, edgeFactor, numCases, numRuns, false); + public void testFgesComparisonDiscrete(int numVars, double edgeFactor, int numCases, int numRuns) { + testFges(numVars, edgeFactor, numCases, numRuns, false); } - private void testFgs(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) { + private void testFges(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) { out.println(new Date()); // RandomUtil.getInstance().setSeed(4828384343999L); @@ -789,7 +789,7 @@ private void testFgs(int numVars, double edgeFactor, int numCases, int numRuns, List elapsedTimes = new ArrayList<>(); if (continuous) { - init(new File("fgs.comparison.continuous" + numVars + "." + (int) (edgeFactor * numVars) + + init(new File("fges.comparison.continuous" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns); out.println("Num vars = " + numVars); out.println("Num edges = " + (int) (numVars * edgeFactor)); @@ -799,7 +799,7 @@ private void testFgs(int numVars, double edgeFactor, int numCases, int numRuns, out.println(); } else { - init(new File("fgs.comparison.discrete" + numVars + "." + (int) (edgeFactor * numVars) + + init(new File("fges.comparison.discrete" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns); out.println("Num vars = " + numVars); out.println("Num edges = " + (int) (numVars * edgeFactor)); @@ -875,25 +875,25 @@ private void testFgs(int numVars, double edgeFactor, int numCases, int numRuns, score.setPenaltyDiscount(penaltyDiscount); System.out.println(new Date()); - System.out.println("\nStarting FGS"); + System.out.println("\nStarting FGES"); long timea = System.currentTimeMillis(); - Fgs fgs = new Fgs(score); -// fgs.setVerbose(false); - fgs.setNumPatternsToStore(0); - fgs.setOut(System.out); - fgs.setFaithfulnessAssumed(faithfulness); - fgs.setCycleBound(-1); + Fges fges = new Fges(score); +// fges.setVerbose(false); + fges.setNumPatternsToStore(0); + fges.setOut(System.out); + fges.setFaithfulnessAssumed(faithfulness); + fges.setCycleBound(-1); long timeb = System.currentTimeMillis(); - estPattern = fgs.search(); + estPattern = fges.search(); long timec = System.currentTimeMillis(); - out.println("Time for FGS constructor " + (timeb - timea) + " ms"); - out.println("Time for FGS search " + (timec - timea) + " ms"); + out.println("Time for FGES constructor " + (timeb - timea) + " ms"); + out.println("Time for FGES search " + (timec - timea) + " ms"); out.println(); out.flush(); @@ -918,32 +918,32 @@ private void testFgs(int numVars, double edgeFactor, int numCases, int numRuns, score.setSamplePrior(1); System.out.println(new Date()); - System.out.println("\nStarting FGS"); + System.out.println("\nStarting FGES"); long timea = System.currentTimeMillis(); - Fgs fgs = new Fgs(score); -// fgs.setVerbose(false); - fgs.setNumPatternsToStore(0); - fgs.setOut(System.out); - fgs.setFaithfulnessAssumed(faithfulness); - fgs.setCycleBound(-1); + Fges fges = new Fges(score); +// fges.setVerbose(false); + fges.setNumPatternsToStore(0); + fges.setOut(System.out); + fges.setFaithfulnessAssumed(faithfulness); + fges.setCycleBound(-1); long timeb = System.currentTimeMillis(); - estPattern = fgs.search(); + estPattern = fges.search(); long timec = System.currentTimeMillis(); out.println("Time consructing BDeu score " + (timea - time3) + " ms"); - out.println("Time for FGS constructor " + (timeb - timea) + " ms"); - out.println("Time for FGS search " + (timec - timea) + " ms"); + out.println("Time for FGES constructor " + (timeb - timea) + " ms"); + out.println("Time for FGES search " + (timec - timea) + " ms"); out.println(); elapsed = timec - timea; } - System.out.println("Done with FGS"); + System.out.println("Done with FGES"); System.out.println(new Date()); @@ -1023,15 +1023,15 @@ private void testFgs(int numVars, double edgeFactor, int numCases, int numRuns, out.close(); } - public void testFgsMbComparisonContinuous(int numVars, double edgeFactor, int numCases, int numRuns) { - testFgsMb(numVars, edgeFactor, numCases, numRuns, true); + public void testFgesMbComparisonContinuous(int numVars, double edgeFactor, int numCases, int numRuns) { + testFgesMb(numVars, edgeFactor, numCases, numRuns, true); } - public void testFgsMbComparisonDiscrete(int numVars, double edgeFactor, int numCases, int numRuns) { - testFgsMb(numVars, edgeFactor, numCases, numRuns, false); + public void testFgesMbComparisonDiscrete(int numVars, double edgeFactor, int numCases, int numRuns) { + testFgesMb(numVars, edgeFactor, numCases, numRuns, false); } - private void testFgsMb(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) { + private void testFgesMb(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) { double penaltyDiscount = 4.0; int structurePrior = 10; @@ -1072,11 +1072,11 @@ private void testFgsMb(int numVars, double edgeFactor, int numCases, int numRuns Graph estPattern; long elapsed; - FgsMb2 fgs; + FgesMb2 fges; List vars; if (continuous) { - init(new File("FgsMb.comparison.continuous" + numVars + "." + (int) (edgeFactor * numVars) + + init(new File("FgesMb.comparison.continuous" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns); out.println("Num vars = " + numVars); out.println("Num edges = " + (int) (numVars * edgeFactor)); @@ -1118,17 +1118,17 @@ private void testFgsMb(int numVars, double edgeFactor, int numCases, int numRuns score.setPenaltyDiscount(penaltyDiscount); System.out.println(new Date()); - System.out.println("\nStarting FGS-MB"); - - fgs = new FgsMb2(score); - fgs.setVerbose(false); - fgs.setNumPatternsToStore(0); - fgs.setOut(System.out); -// fgs.setHeuristicSpeedup(faithfulness); - fgs.setMaxIndegree(maxIndegree); - fgs.setCycleBound(-1); + System.out.println("\nStarting FGES-MB"); + + fges = new FgesMb2(score); + fges.setVerbose(false); + fges.setNumPatternsToStore(0); + fges.setOut(System.out); +// fges.setHeuristicSpeedup(faithfulness); + fges.setMaxIndegree(maxIndegree); + fges.setCycleBound(-1); } else { - init(new File("FgsMb.comparison.discrete" + numVars + "." + (int) (edgeFactor * numVars) + + init(new File("FgesMb.comparison.discrete" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns); out.println("Num vars = " + numVars); out.println("Num edges = " + (int) (numVars * edgeFactor)); @@ -1161,22 +1161,22 @@ private void testFgsMb(int numVars, double edgeFactor, int numCases, int numRuns score.setSamplePrior(samplePrior); System.out.println(new Date()); - System.out.println("\nStarting FGS"); + System.out.println("\nStarting FGES"); long time4 = System.currentTimeMillis(); - fgs = new FgsMb2(score); - fgs.setVerbose(false); - fgs.setNumPatternsToStore(0); - fgs.setOut(System.out); -// fgs.setHeuristicSpeedup(faithfulness); - fgs.setMaxIndegree(maxIndegree); - fgs.setCycleBound(-1); + fges = new FgesMb2(score); + fges.setVerbose(false); + fges.setNumPatternsToStore(0); + fges.setOut(System.out); +// fges.setHeuristicSpeedup(faithfulness); + fges.setMaxIndegree(maxIndegree); + fges.setCycleBound(-1); long timeb = System.currentTimeMillis(); out.println("Time consructing BDeu score " + (time4 - time3) + " ms"); - out.println("Time for FGS-MB constructor " + (timeb - time4) + " ms"); + out.println("Time for FGES-MB constructor " + (timeb - time4) + " ms"); out.println(); } @@ -1189,7 +1189,7 @@ private void testFgsMb(int numVars, double edgeFactor, int numCases, int numRuns System.out.println("Target = " + target); long timea = System.currentTimeMillis(); - estPattern = fgs.search(target); + estPattern = fges.search(target); long timed = System.currentTimeMillis(); @@ -1208,10 +1208,10 @@ private void testFgsMb(int numVars, double edgeFactor, int numCases, int numRuns long timec = System.currentTimeMillis(); - out.println("Time for FGS-MB search " + (timec - timea) + " ms"); + out.println("Time for FGES-MB search " + (timec - timea) + " ms"); out.println(); - System.out.println("Done with FGS"); + System.out.println("Done with FGES"); System.out.println(new Date()); @@ -1390,7 +1390,7 @@ public void testGFciComparison(int numVars, double edgeFactor, int numCases, int GFci fci = new GFci(independenceTest, score); // TFci fci = new TFci(independenceTest); // fci.setVerbose(false); - fci.setMaxIndegree(depth); + fci.setMaxDegree(depth); fci.setMaxPathLength(maxPathLength); // fci.setPossibleDsepSearchDone(possibleDsepDone); fci.setCompleteRuleSetUsed(completeRuleSetUsed); @@ -1690,12 +1690,12 @@ public void testComparePcVersions(int numVars, double edgeFactor, int numLatents // // final IndTestFisherZ independenceTestGFci = new IndTestFisherZ(cov, alphaGFci); // -// out6.println("FCI.FGS.PAG"); +// out6.println("FCI.FGES.PAG"); // // GFci GFci = new GFci(independenceTestGFci); // GFci.setVerbose(false); // GFci.setAlpha(penaltyDiscount); -// GFci.setMaxIndegree(depth); +// GFci.setMaxDegree(depth); // GFci.setMaxPathLength(maxPathLength); // GFci.setPossibleDsepSearchDone(true); // GFci.setCompleteRuleSetUsed(true); @@ -1714,7 +1714,7 @@ public void testComparePcVersions(int numVars, double edgeFactor, int numLatents // // PC pc = new PC(independencePc); // pc.setVerbose(false); -// pc.setMaxIndegree(depth); +// pc.setMaxDegree(depth); // // Graph pattern = pc.search(); // @@ -2370,40 +2370,40 @@ else if (args.length == 5) { performanceTests.testCpcStable(numVars, edgeFactor, numCases, alpha); break; } - case "TestFgsComparisonContinuous": { + case "TestFgesComparisonContinuous": { final int numVars = Integer.parseInt(args[1]); final double edgeFactor = Double.parseDouble(args[2]); final int numCases = Integer.parseInt(args[3]); final int numRuns = Integer.parseInt(args[4]); - performanceTests.testFgsComparisonContinuous(numVars, edgeFactor, numCases, numRuns); + performanceTests.testFgesComparisonContinuous(numVars, edgeFactor, numCases, numRuns); break; } - case "TestFgsComparisonDiscrete": { + case "TestFgesComparisonDiscrete": { final int numVars = Integer.parseInt(args[1]); final double edgeFactor = Double.parseDouble(args[2]); final int numCases = Integer.parseInt(args[3]); final int numRuns = Integer.parseInt(args[4]); - performanceTests.testFgsComparisonDiscrete(numVars, edgeFactor, numCases, numRuns); + performanceTests.testFgesComparisonDiscrete(numVars, edgeFactor, numCases, numRuns); break; } - case "TestFgsMbComparisonContinuous": { + case "TestFgesMbComparisonContinuous": { final int numVars = Integer.parseInt(args[1]); final double edgeFactor = Double.parseDouble(args[2]); final int numCases = Integer.parseInt(args[3]); final int numRuns = Integer.parseInt(args[4]); - performanceTests.testFgsMbComparisonContinuous(numVars, edgeFactor, numCases, numRuns); + performanceTests.testFgesMbComparisonContinuous(numVars, edgeFactor, numCases, numRuns); break; } - case "TestFgsMbComparisonDiscrete": { + case "TestFgesMbComparisonDiscrete": { final int numVars = Integer.parseInt(args[1]); final double edgeFactor = Double.parseDouble(args[2]); final int numCases = Integer.parseInt(args[3]); final int numRuns = Integer.parseInt(args[4]); - performanceTests.testFgsMbComparisonDiscrete(numVars, edgeFactor, numCases, numRuns); + performanceTests.testFgesMbComparisonDiscrete(numVars, edgeFactor, numCases, numRuns); break; } default: @@ -2428,7 +2428,7 @@ else if (args.length == 5) { // performanceTests.testPcStable(20000, 1, 1000, .00001); performanceTests.testPcMax(5000, 1, 1000, .0001); // performanceTests.testPcMax(5000, 5, 1000, .0001); -// performanceTests.testFgs(5000, 5, 1000, 4); +// performanceTests.testFges(5000, 5, 1000, 4); // performanceTests.testPcStable(10000, 1, 1000, .0001); // performanceTests.testPcMax(10000, 1, 1000, .0001); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/PerformanceTestsDan.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/PerformanceTestsDan.java index 71c4fbffae..41f43d53cd 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/PerformanceTestsDan.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/PerformanceTestsDan.java @@ -21,7 +21,6 @@ package edu.cmu.tetrad.performance; -import edu.cmu.tetrad.algcomparison.score.SemBicScore; import edu.cmu.tetrad.data.CovarianceMatrix; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.data.DataUtils; @@ -33,7 +32,6 @@ import edu.cmu.tetrad.search.Pc; import edu.cmu.tetrad.sem.SemIm; import edu.cmu.tetrad.sem.SemPm; -import edu.cmu.tetrad.util.RandomUtil; import java.io.File; import java.io.FileNotFoundException; @@ -181,7 +179,7 @@ private void testIdaOutputForDan() { GFci gFci = new GFci(independenceTestGFci, scoreGfci); gFci.setVerbose(false); - gFci.setMaxIndegree(depth); + gFci.setMaxDegree(depth); gFci.setMaxPathLength(maxPathLength); // gFci.setPossibleDsepSearchDone(true); gFci.setCompleteRuleSetUsed(true); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/readme_ComparisonScript.txt b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/readme_ComparisonScript.txt index 1150ae961d..3460bf65e5 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/readme_ComparisonScript.txt +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/performance/readme_ComparisonScript.txt @@ -57,7 +57,7 @@ For mode 3): In ComparisonScript.java set parameters.isNoData(true) (it is false by default). You can change numTrials to change the number of random graphs to try, the default is 100. -In this mode you want to see perfect performance everywhere: perfect precision and recall, etc. If you see any deviation from this, there must be a mistake somewhere. Note that this procedure compares, for example, the output of PC or FGS with a true graph specified by DagToPattern. FCI or GFCI would be compared with the output of DagToPag. If you want to examine an algorithm which searches for something other than a standard Pattern or PAG, you have to add the appropriate method in Comparison2.java. Just search for “DagToPag” and you’ll see where it goes. +In this mode you want to see perfect performance everywhere: perfect precision and recall, etc. If you see any deviation from this, there must be a mistake somewhere. Note that this procedure compares, for example, the output of PC or FGES with a true graph specified by DagToPattern. FCI or GFCI would be compared with the output of DagToPag. If you want to examine an algorithm which searches for something other than a standard Pattern or PAG, you have to add the appropriate method in Comparison2.java. Just search for “DagToPag” and you’ll see where it goes. Note also that some algorithms are quite slow running directly on graphs. Start with small variable sets. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/AdLeafTree.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/AdLeafTree.java index 7d721f55f3..74a28d7c41 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/AdLeafTree.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/AdLeafTree.java @@ -102,6 +102,47 @@ public int compare(DiscreteVariable o1, DiscreteVariable o2) { return rows; } + /** + * Finds the set of indices into the leaves of the tree for the given variables. + * Counts are the sizes of the index sets. + * + * @param A A list of discrete variables. + * @return The list of index sets of the first variable varied by the second variable, + * and so on, to the last variable. + */ + public List>> getCellLeaves(List A, DiscreteVariable B) { + Collections.sort(A, new Comparator() { + + @Override + public int compare(DiscreteVariable o1, DiscreteVariable o2) { + return Integer.compare(nodesHash.get(o1), nodesHash.get(o2)); + } + }); + + if (baseCase == null) { + Vary vary = new Vary(); + this.baseCase = new ArrayList<>(); + baseCase.add(vary); + } + + List varies = baseCase; + + for (DiscreteVariable v : A) { + varies = getVaries(varies, nodesHash.get(v)); + } + + List>> rows = new ArrayList<>(); + + for (Vary vary : varies) { + for (int i = 0; i < vary.getNumCategories(); i++) { + Vary subvary = vary.getSubvary(nodesHash.get(B), i); + rows.add(subvary.getRows()); + } + } + + return rows; + } + private List getVaries(List varies, int v) { List _varies = new ArrayList<>(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BdeuScoreImages.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BdeuScoreImages.java index d75c872f39..ce2ee36c77 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BdeuScoreImages.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BdeuScoreImages.java @@ -34,7 +34,7 @@ import java.util.List; /** - * Implements the continuous BIC score for FGS. + * Implements the continuous BIC score for FGES. * * @author Joseph Ramsey */ diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BffGes.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BffGes.java index 53fc82bbaa..7877e4dac9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BffGes.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BffGes.java @@ -497,7 +497,7 @@ private boolean validInsert(Node x, Node y, Set subset, Graph graph) { List naYXT = new LinkedList<>(subset); naYXT.addAll(findNaYX(x, y, graph)); - return isClique(naYXT, graph) && isSemiDirectedBlocked(x, y, naYXT, graph, new HashSet()); + return GraphUtils.isClique(naYXT, graph) && isSemiDirectedBlocked(x, y, naYXT, graph, new HashSet()); } @@ -508,7 +508,7 @@ private static boolean validDelete(Node x, Node y, Set h, Graph graph) { List naYXH = findNaYX(x, y, graph); naYXH.removeAll(h); - return isClique(naYXH, graph); + return GraphUtils.isClique(naYXH, graph); } /** @@ -644,21 +644,6 @@ private boolean validSetByKnowledge(Node x, Node y, Set subset, // return score1 - score2; // } - /** - * @return true iif the given set forms a clique in the given graph. - */ - private static boolean isClique(List set, Graph graph) { - List setv = new LinkedList<>(set); - for (int i = 0; i < setv.size() - 1; i++) { - for (int j = i + 1; j < setv.size(); j++) { - if (!graph.isAdjacentTo(setv.get(i), setv.get(j))) { - return false; - } - } - } - return true; - } - /** * Verifies if every semidirected path from y to x contains a node in naYXT. */ diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/CcdMax.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/CcdMax.java new file mode 100644 index 0000000000..81a28a00f3 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/CcdMax.java @@ -0,0 +1,569 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph // +// Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.algcomparison.independence.FisherZ; +import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.util.*; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.RecursiveTask; + +/** + * This is an optimization of the CCD (Cyclic Causal Discovery) algorithm by Thomas Richardson. + * + * @author Joseph Ramsey + */ +public final class CcdMax implements GraphSearch { + private final IndependenceTest independenceTest; + private int depth = -1; + private boolean applyOrientAwayFromCollider = false; + private long elapsed = 0; + private IKnowledge knowledge = new Knowledge2(); + private boolean useHeuristic = true; + private int maxPathLength = 3; + private boolean useOrientTowardDConnections = true; + private boolean orientVisibleFeedbackLoops = true; + private boolean doColliderOrientations = true; + + public CcdMax(IndependenceTest test) { + if (test == null) throw new NullPointerException(); + this.independenceTest = test; + } + + //======================================== PUBLIC METHODS ====================================// + + /** + * Searches for a PAG satisfying the description in Thomas Richardson (1997), dissertation, + * Carnegie Mellon University. Uses a simplification of that algorithm. + */ + public Graph search() { + SepsetMap map = new SepsetMap(); + System.out.println("FAS"); + Graph graph = fastAdjacencySearch(); + SearchGraphUtils.pcOrientbk(knowledge, graph, graph.getNodes()); + System.out.println("Two shield constructs"); + + if (orientVisibleFeedbackLoops) { + orientTwoShieldConstructs(graph); + } + + System.out.println("Max P collider orientation"); + + if (doColliderOrientations) { + final OrientCollidersMaxP orientCollidersMaxP = new OrientCollidersMaxP(independenceTest); + orientCollidersMaxP.setUseHeuristic(useHeuristic); + orientCollidersMaxP.setMaxPathLength(maxPathLength); + orientCollidersMaxP.orient(graph); + } + + orientAwayFromArrow(graph); + + System.out.println("Toward D-connection"); + + if (useOrientTowardDConnections) { + orientTowardDConnection(graph, map); + } + + System.out.println("Done"); + return graph; + } + + /** + * @return The depth of search for the Fast Adjacency Search. + */ + public int getDepth() { + return depth; + } + + /** + * @param depth The depth of search for the Fast Adjacency Search. + */ + public void setDepth(int depth) { + this.depth = depth; + } + + /** + * @return The elapsed time in milliseconds. + */ + public long getElapsedTime() { + return elapsed; + } + + /** + * @param applyOrientAwayFromCollider True if the orient away from collider rule should be + * applied. + */ + public void setApplyOrientAwayFromCollider(boolean applyOrientAwayFromCollider) { + this.applyOrientAwayFromCollider = applyOrientAwayFromCollider; + } + + //======================================== PRIVATE METHODS ====================================// + + private Graph fastAdjacencySearch() { + long start = System.currentTimeMillis(); + + FasStableConcurrent fas = new FasStableConcurrent(null, independenceTest); + fas.setDepth(getDepth()); + fas.setKnowledge(knowledge); + fas.setVerbose(false); + fas.setRecordSepsets(false); + Graph graph = fas.search(); + + long stop = System.currentTimeMillis(); + this.elapsed = stop - start; + + return new EdgeListGraph(graph); + } + + // Orient feedback loops and a few extra directed edges. + private void orientTwoShieldConstructs(Graph graph) { + TetradLogger.getInstance().log("info", "\nStep E"); + + for (Node c : graph.getNodes()) { + List adj = graph.getAdjacentNodes(c); + + for (int i = 0; i < adj.size(); i++) { + Node a = adj.get(i); + + for (int j = i + 1; j < adj.size(); j++) { + Node b = adj.get(j); + if (a == b) continue; + if (graph.isAdjacentTo(a, b)) continue; + + for (Node d : adj) { + if (d == a || d == b) continue; + + if (graph.isAdjacentTo(d, a) && graph.isAdjacentTo(d, b)) { + if (sepset(graph, a, b, set(), set(c, d)) != null) { + if ((graph.getEdges().size() == 2 || Edges.isDirectedEdge(graph.getEdge(c, d)))) { + continue; + } + + if (sepset(graph, a, b, set(c, d), set()) != null) { + orientCollider(graph, a, c, b); + orientCollider(graph, a, d, b); + addFeedback(graph, c, d); + } + } + } + } + } + } + } + } + + private void orientTowardDConnection(Graph graph, SepsetMap map) { + + EDGE: + for (Edge edge : graph.getEdges()) { + if (!Edges.isUndirectedEdge(edge)) continue; + + Set surround = new HashSet<>(); + Node b = edge.getNode1(); + Node c = edge.getNode2(); + surround.add(b); + + for (int i = 1; i < 3; i++) { + for (Node z : new HashSet<>(surround)) { + surround.addAll(graph.getAdjacentNodes(z)); + } + } + + surround.remove(b); + surround.remove(c); + surround.removeAll(graph.getAdjacentNodes(b)); + surround.removeAll(graph.getAdjacentNodes(c)); + boolean orient = false; + boolean agree = true; + + for (Node a : surround) { +// List sepsetax = map.get(a, b); +// List sepsetay = map.get(a, c); + + List sepsetax = maxPSepset(a, b, graph).getCond(); + List sepsetay = maxPSepset(a, c, graph).getCond(); + + if (sepsetax == null) continue; + if (sepsetay == null) continue; + + if (!sepsetax.equals(sepsetay)) { + if (sepsetax.containsAll(sepsetay)) { + orient = true; + } else { + agree = false; + } + } + } + + if (orient && agree) { + addDirectedEdge(graph, c, b); + } + + for (Node a : surround) { + if (b == a) continue; + if (c == a) continue; + if (graph.getAdjacentNodes(b).contains(a)) continue; + if (graph.getAdjacentNodes(c).contains(a)) continue; + + List sepsetax = map.get(a, b); + List sepsetay = map.get(a, c); + + if (sepsetax == null) continue; + if (sepsetay == null) continue; + if (sepsetay.contains(b)) continue; + + if (!sepsetay.containsAll(sepsetax)) { + if (!independenceTest.isIndependent(a, b, sepsetay)) { + addDirectedEdge(graph, c, b); + continue EDGE; + } + } + } + } + } + + private void addDirectedEdge(Graph graph, Node a, Node b) { + graph.removeEdges(a, b); + graph.addDirectedEdge(a, b); + orientAwayFromArrow(graph, a, b); + } + + private void addFeedback(Graph graph, Node a, Node b) { + graph.removeEdges(a, b); + graph.addEdge(Edges.directedEdge(a, b)); + graph.addEdge(Edges.directedEdge(b, a)); + } + + private void orientCollider(Graph graph, Node a, Node b, Node c) { + if (wouldCreateBadCollider(graph, a, b)) return; + if (wouldCreateBadCollider(graph, c, b)) return; + if (graph.getEdges(a, b).size() > 1) return; + if (graph.getEdges(b, c).size() > 1) return; + graph.removeEdge(a, b); + graph.removeEdge(c, b); + graph.addDirectedEdge(a, b); + graph.addDirectedEdge(c, b); + } + + private void orientAwayFromArrow(Graph graph, Node a, Node b) { + if (!applyOrientAwayFromCollider) return; + + for (Node c : graph.getAdjacentNodes(b)) { + if (c == a) continue; + orientAwayFromArrowVisit(a, b, c, graph); + } + } + + private boolean wouldCreateBadCollider(Graph graph, Node x, Node y) { + for (Node z : graph.getAdjacentNodes(y)) { + if (x == z) continue; + + if ( graph.isDefCollider(x, y, z)) { + return true; + } + +// if (!graph.isAdjacentTo(z, y) && +// graph.getEndpoint(z, y) == Endpoint.ARROW +//// && +//// sepset(graph, x, z, set(), set(y)) == null +// ) { +// return true; +// } + } + + return false; + } + + public IKnowledge getKnowledge() { + return knowledge; + } + + public void setKnowledge(IKnowledge knowledge) { + this.knowledge = knowledge; + } + + public boolean isUseHeuristic() { + return useHeuristic; + } + + public void setUseHeuristic(boolean useHeuristic) { + this.useHeuristic = useHeuristic; + } + + public int getMaxPathLength() { + return maxPathLength; + } + + public void setMaxPathLength(int maxPathLength) { + this.maxPathLength = maxPathLength; + } + + public boolean isUseOrientTowardDConnections() { + return useOrientTowardDConnections; + } + + public void setUseOrientTowardDConnections(boolean useOrientTowardDConnections) { + this.useOrientTowardDConnections = useOrientTowardDConnections; + } + + public void setOrientVisibleFeedbackLoops(boolean orientVisibleFeedbackLoops) { + this.orientVisibleFeedbackLoops = orientVisibleFeedbackLoops; + } + + public boolean isOrientVisibleFeedbackLoops() { + return orientVisibleFeedbackLoops; + } + + public boolean isDoColliderOrientations() { + return doColliderOrientations; + } + + public void setDoColliderOrientations(boolean doColliderOrientations) { + this.doColliderOrientations = doColliderOrientations; + } + + private class Pair { + private List cond; + private double score; + + Pair(List cond, double score) { + this.cond = cond; + this.score = score; + } + + public List getCond() { + return cond; + } + + public double getScore() { + return score; + } + } + + private Pair maxPSepset(Node i, Node k, Graph graph) { + double _p = Double.POSITIVE_INFINITY; + List _v = null; + + List adji = graph.getAdjacentNodes(i); + List adjk = graph.getAdjacentNodes(k); + adji.remove(k); + adjk.remove(i); + + for (int d = 0; d <= Math.min((depth == -1 ? 1000 : depth), Math.max(adji.size(), adjk.size())); d++) { + if (d <= adji.size()) { + ChoiceGenerator gen = new ChoiceGenerator(adji.size(), d); + int[] choice; + + WHILE: + while ((choice = gen.next()) != null) { + List v2 = GraphUtils.asList(choice, adji); + + for (Node v : v2) { + if (isForbidden(i, k, v2)) continue WHILE; + } + + try { + getIndependenceTest().isIndependent(i, k, v2); + double p2 = getIndependenceTest().getScore(); + + if (p2 < _p) { + _p = p2; + _v = v2; + } + } catch (Exception e) { + e.printStackTrace(); + return new Pair(null, Double.POSITIVE_INFINITY); + } + } + } + + if (d <= adjk.size()) { + ChoiceGenerator gen = new ChoiceGenerator(adjk.size(), d); + int[] choice; + + while ((choice = gen.next()) != null) { + List v2 = GraphUtils.asList(choice, adjk); + + try { + getIndependenceTest().isIndependent(i, k, v2); + double p2 = getIndependenceTest().getScore(); + + if (p2 < _p) { + _p = p2; + _v = v2; + } + } catch (Exception e) { + e.printStackTrace(); + return new Pair(null, Double.POSITIVE_INFINITY); + } + } + } + } + + return new Pair(_v, _p); + } + + private boolean isForbidden(Node i, Node k, List v) { + for (Node w : v) { + if (knowledge.isForbidden(w.getName(), i.getName())) { + return true; + } + + if (knowledge.isForbidden(w.getName(), k.getName())) { + return true; + } + } + + return false; + } + + // Returns a sepset containing the nodes in 'containing' but not the nodes in 'notContaining', or + // null if there is no such sepset. + private List sepset(Graph graph, Node a, Node c, Set containing, Set notContaining) { + List adj = graph.getAdjacentNodes(a); + adj.addAll(graph.getAdjacentNodes(c)); + adj.remove(c); + adj.remove(a); + + for (int d = 0; d <= Math.min((depth == -1 ? 1000 : depth), Math.max(adj.size(), adj.size())); d++) { + if (d <= adj.size()) { + ChoiceGenerator gen = new ChoiceGenerator(adj.size(), d); + int[] choice; + + WHILE: + while ((choice = gen.next()) != null) { + Set v2 = GraphUtils.asSet(choice, adj); + v2.addAll(containing); + v2.removeAll(notContaining); + v2.remove(a); + v2.remove(c); + + if (isForbidden(a, c, new ArrayList<>(v2))) + + getIndependenceTest().isIndependent(a, c, new ArrayList<>(v2)); + double p2 = getIndependenceTest().getScore(); + + if (p2 < 0) { + return new ArrayList<>(v2); + } + } + } + } + + return null; + } + + private Set set(Node... n) { + Set S = new HashSet<>(); + Collections.addAll(S, n); + return S; + } + + private IndependenceTest getIndependenceTest() { + return independenceTest; + } + + private void orientAwayFromArrow(Graph graph) { + for (Edge edge : graph.getEdges()) { + Node n1 = edge.getNode1(); + Node n2 = edge.getNode2(); + + edge = graph.getEdge(n1, n2); + + if (edge.pointsTowards(n1)) { + orientAwayFromArrow(graph, n2, n1); + } else if (edge.pointsTowards(n2)) { + orientAwayFromArrow(graph, n1, n2); + } + } + } + + private void orientAwayFromArrowVisit(Node a, Node b, Node c, Graph graph) { + + // This shouldn't happen--a--b--c should be shielded. Checking just in case... + if (graph.getEdges(b, c).size() > 1) { + return; + } + + if (!Edges.isUndirectedEdge(graph.getEdge(b, c))) { + return; + } + + if (graph.isAdjacentTo(a, c)) { + return; + } + +// if (sepset(graph, a, c, set(), set(b)) != null) { +// return; +// } + + if (wouldCreateBadCollider(graph, b, c)) { + return; + } + + addDirectedEdge(graph, b, c); + + List undirectedEdges = new ArrayList<>(); + + for (Node d : graph.getAdjacentNodes(c)) { + if (d == b) continue; + Edge e = graph.getEdge(c, d); + if (Edges.isUndirectedEdge(e)) undirectedEdges.add(e); + } + + for (Node d : graph.getAdjacentNodes(c)) { + if (d == b) continue; + orientAwayFromArrowVisit(b, c, d, graph); + } + + boolean allOriented = true; + + for (Edge e : undirectedEdges) { + Node d = Edges.traverse(c, e); + Edge f = graph.getEdge(c, d); + + if (!f.pointsTowards(d)) { + allOriented = false; + break; + } + } + + if (!allOriented) { + for (Edge e : undirectedEdges) { + Node d = Edges.traverse(c, e); + Edge f = graph.getEdge(c, d); + + graph.removeEdge(f); + graph.addEdge(e); + } + } + } +} + + + + + + diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditionalGaussianLikelihood.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditionalGaussianLikelihood.java index 6e981184c2..14b70c3ce4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditionalGaussianLikelihood.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditionalGaussianLikelihood.java @@ -26,20 +26,26 @@ import edu.cmu.tetrad.data.DiscreteVariable; import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.util.TetradMatrix; +import edu.cmu.tetrad.util.TetradVector; import org.apache.commons.math3.stat.correlation.Covariance; import java.util.*; +import static java.lang.Math.log; + /** - * Implements a conditional Gaussian BIC score for FGS. + * Implements a conditional Gaussian likelihood. Please note that this this likelihood will be maximal only if the + * the continuous variables are jointly Gaussian conditional on the discrete variables; in all other cases, it will + * be less than maximal. For an algorithm like FGES this is fine. * * @author Joseph Ramsey */ public class ConditionalGaussianLikelihood { + // The data set. May contain continuous and/or discrete variables. private DataSet dataSet; - // The variables of the continuousData set. + // The variables of the data set. private List variables; // Indices of variables. @@ -48,8 +54,40 @@ public class ConditionalGaussianLikelihood { // Continuous data only. private double[][] continuousData; - // Discrete data only. - private int[][] discreteData; + //The AD Tree used to count discrete cells. + private AdLeafTree adTree; + + // True if by assumption the denominator for problems like P(C | X) = P(X | C) P(C) / P(X) is mixed. + private boolean denominatorMixed = true; + + // Multiplier on degrees of freedom for the continuous portion of those degrees. + private double penaltyDiscount = 1; + + /** + * A return value for a likelihood--returns a likelihood value and the degrees of freedom + * for it. + */ + public class Ret { + private double lik; + private int dof; + + private Ret(double lik, int dof) { + this.lik = lik; + this.dof = dof; + } + + public double getLik() { + return lik; + } + + public int getDof() { + return dof; + } + + public String toString() { + return "lik = " + lik + " dof = " + dof; + } + } /** * Constructs the score using a covariance matrix. @@ -63,7 +101,6 @@ public ConditionalGaussianLikelihood(DataSet dataSet) { this.variables = dataSet.getVariables(); continuousData = new double[dataSet.getNumColumns()][]; - discreteData = new int[dataSet.getNumColumns()][]; for (int j = 0; j < dataSet.getNumColumns(); j++) { Node v = dataSet.getVariable(j); @@ -76,14 +113,6 @@ public ConditionalGaussianLikelihood(DataSet dataSet) { } continuousData[j] = col; - } else if (v instanceof DiscreteVariable) { - int[] col = new int[dataSet.getNumRows()]; - - for (int i = 0; i < dataSet.getNumRows(); i++) { - col[i] = dataSet.getInt(i, j); - } - - discreteData[j] = col; } } @@ -94,40 +123,18 @@ public ConditionalGaussianLikelihood(DataSet dataSet) { nodesHash.put(v, j); } - this.adTree = AdTrees.getAdLeafTree(dataSet);// new AdLeafTree(dataSet); + this.adTree = AdTrees.getAdLeafTree(dataSet); } - private int getDof2(int i, int[] parents) { - Node target = variables.get(i); - - int dof2; - - List X = new ArrayList<>(); - List A = new ArrayList<>(); - - for (int parent1 : parents) { - Node parent = variables.get(parent1); - - if (parent instanceof ContinuousVariable) { - X.add((ContinuousVariable) parent); - } else { - A.add((DiscreteVariable) parent); - } - } - - if (target instanceof ContinuousVariable) { - dof2 = f(A) * g(X); - } else if (target instanceof DiscreteVariable) { - List b = Collections.singletonList((DiscreteVariable) target); - dof2 = f(A) * (f(b) - 1) + f(A) * f(b) * h(X); - } else { - throw new IllegalStateException(); - } - - return dof2; - } - - public Ret getLikelihoodRatio(int i, int[] parents) { + /** + * Returns the likelihood of variable i conditional on the given parents, assuming the continuous variables + * index by i or by the parents are jointly Gaussian conditional on the discrete comparison. + * + * @param i The index of the conditioned variable. + * @param parents The indices of the conditioning variables. + * @return The likelihood. + */ + public Ret getLikelihood(int i, int[] parents) { Node target = variables.get(i); List X = new ArrayList<>(); @@ -152,153 +159,220 @@ public Ret getLikelihoodRatio(int i, int[] parents) { APlus.add((DiscreteVariable) target); } - Ret ret1 = getJointLikelihood(XPlus, APlus); - Ret ret2 = getJointLikelihood(X, A); + if (target instanceof DiscreteVariable && !X.isEmpty() && denominatorMixed) { + return likelihoodMixed(X, A, (DiscreteVariable) target); + } else { + Ret ret1 = likelihoodJoint(XPlus, APlus); + Ret ret2 = likelihoodJoint(X, A); + return new Ret(ret1.getLik() - ret2.getLik(), ret1.getDof() - ret2.getDof()); + } + } - double lik = ret1.getLik() - ret2.getLik(); - int dof = ret1.getDof() - ret2.getDof(); -// int dof = getDof2(i, parents); - return new Ret(lik, dof); + public void setDenominatorMixed(boolean denominatorMixed) { + this.denominatorMixed = denominatorMixed; } - // The likelihood of the joint over all of these variables, continuous and discrete. - private Ret getJointLikelihood(List X, List A) { - int p = X.size(); + public double getPenaltyDiscount() { + return penaltyDiscount; + } + + public void setPenaltyDiscount(double penaltyDiscount) { + this.penaltyDiscount = penaltyDiscount; + } -// List> cells = getCellsOriginal(A); - List> cells = adTree.getCellLeaves(A); - int[] continuousCols = new int[p]; - for (int j = 0; j < p; j++) continuousCols[j] = nodesHash.get(X.get(j)); + // The likelihood of the joint over all of these variables, assuming conditional Gaussian, + // continuous and discrete. + private Ret likelihoodJoint(List X, List A) { + int k = X.size(); + + int[] continuousCols = new int[k]; + for (int j = 0; j < k; j++) continuousCols[j] = nodesHash.get(X.get(j)); int N = dataSet.getNumRows(); - double lik = 0; + + double c1 = 0, c2 = 0; + + List> cells = adTree.getCellLeaves(A); for (List cell : cells) { - int r = cell.size(); + int a = cell.size(); + if (a == 0) continue; if (A.size() > 0) { - if (r > 0) { - double prob = r / (double) N; - lik += r * Math.log(prob); - } + c1 += a * multinomialLikelihood(a, N); } if (X.size() > 0) { - if (r > 3 * p) { - TetradMatrix subset = new TetradMatrix(r, p); + double v; - for (int i = 0; i < r; i++) { - for (int j = 0; j < p; j++) { - subset.set(i, j, continuousData[continuousCols[j]][cell.get(i)]); - } + try { + + // Determinant will be zero if data are linearly dependent. + if (a <= continuousCols.length) { + throw new IllegalArgumentException(); } - TetradMatrix Sigma = new TetradMatrix(new Covariance(subset.getRealMatrix(), - false).getCovarianceMatrix()); - double det = Sigma.det(); - lik -= 0.5 * r * Math.log(det); - } + TetradMatrix cov = cov(getSubsample(continuousCols, cell)); + v = gaussianLikelihood(k, cov); - lik -= 0.5 * r * p * (1.0 + Math.log(2.0 * Math.PI)); + // Double check. + if (Double.isInfinite(v)) { + throw new IllegalArgumentException(); + } + + c2 += a * v; + } catch (Exception e) { + // No contribution. + } } } - int dof = f(A) * h(X) + f(A); + final double lnL = c1 + c2; + int p = (int) getPenaltyDiscount(); + + // Only count dof for continuous cells that contributed to the likelihood calculation. + final int dof = f(A) * p * h(X) + f(A); + return new Ret(lnL, dof); + } - return new Ret(lik, dof); + private double multinomialLikelihood(int a, int N) { + return log(a / (double) N); } - private List> getCellsOriginal(List A) { - int d = A.size(); + // One record. + private double gaussianLikelihood(int k, TetradMatrix sigma) { + return -0.5 * (log(sigma.det()) - k - k * log(2.0 * Math.PI)); + } - // For each combination of values for the A guys extract a subset of the data. - int[] discreteCols = new int[d]; - int[] dims = new int[d]; - int n = dataSet.getNumRows(); + // For cases like P(C | X). This is a ratio of joints, but if the numerator is conditional Gaussian, + // the denominator is a mixture of Gaussians. + private Ret likelihoodMixed(List X, List A, DiscreteVariable B) { + final int k = X.size(); + final double g = Math.pow(2.0 * Math.PI, k); - for (int i = 0; i < d; i++) discreteCols[i] = nodesHash.get(A.get(i)); - for (int i = 0; i < d; i++) dims[i] = A.get(i).getNumCategories(); + int[] continuousCols = new int[k]; + for (int j = 0; j < k; j++) continuousCols[j] = nodesHash.get(X.get(j)); + double lnL = 0.0; - List> cells = new ArrayList<>(); - for (int i = 0; i < f(A); i++) { - cells.add(new ArrayList()); - } + int N = dataSet.getNumRows(); + + List>> cells = adTree.getCellLeaves(A, B); + + TetradMatrix defaultCov = null; + + for (List> mycells : cells) { + List x = new ArrayList<>(); + List sigmas = new ArrayList<>(); + List inv = new ArrayList<>(); + List mu = new ArrayList<>(); + + for (List cell : mycells) { + TetradMatrix subsample = getSubsample(continuousCols, cell); - int[] values = new int[A.size()]; + try { - for (int i = 0; i < n; i++) { - for (int j = 0; j < A.size(); j++) { - values[j] = discreteData[discreteCols[j]][i]; + // Determinant will be zero if data are linearly dependent. + if (mycells.size() <= continuousCols.length) throw new IllegalArgumentException(); + + TetradMatrix cov = cov(subsample); + TetradMatrix covinv = cov.inverse(); + + if (defaultCov == null) { + defaultCov = cov; + } + + x.add(subsample); + sigmas.add(cov); + inv.add(covinv); + mu.add(means(subsample)); + } catch (Exception e) { + // No contribution. + } } - int rowIndex = getRowIndex(values, dims); - cells.get(rowIndex).add(i); - } + double[] factors = new double[x.size()]; - return cells; - } + for (int u = 0; u < x.size(); u++) { + factors[u] = Math.pow(g * sigmas.get(u).det(), -0.5); + } - private AdLeafTree adTree; + double[] a = new double[x.size()]; - public class Ret { - private double lik; - private int dof; + for (int u = 0; u < x.size(); u++) { + for (int i = 0; i < x.get(u).rows(); i++) { + for (int v = 0; v < x.size(); v++) { + final TetradVector xm = x.get(u).getRow(i).minus(mu.get(v)); + a[v] = prob(factors[v], inv.get(v), xm); + } - public Ret(double lik, int dof) { - this.lik = lik; - this.dof = dof; - } + double num = a[u] * p(x, u, N); + double denom = 0.0; - public double getLik() { - return lik; - } + for (int v = 0; v < x.size(); v++) { + denom += a[v] * (p(x, v, N)); + } - public int getDof() { - return dof; + lnL += log(num) - log(denom); + } + } } - public String toString() { return "lik = " + lik + " dof = " + dof;}; - } + int p = (int) getPenaltyDiscount(); - private int f(List A) { - int f = 1; + // Only count dof for continuous cells that contributed to the likelihood calculation. + int dof = f(A) * B.getNumCategories() + f(A) * p * h(X); + return new Ret(lnL, dof); + } - for (DiscreteVariable V : A) { - f *= V.getNumCategories(); - } + private double p(List x, int u, double N) { + return x.get(u).rows() / N; + } - return f; + private TetradMatrix cov(TetradMatrix x) { + return new TetradMatrix(new Covariance(x.getRealMatrix(), true).getCovarianceMatrix()); } - private int g(List X) { - return X.size() + 1; + private double prob(Double factor, TetradMatrix inv, TetradVector x) { + return factor * Math.exp(-0.5 * inv.times(x).dotProduct(x)); } - private int h(List X) { - int p = X.size(); - return p * (p + 1) / 2; + // Calculates the means of the columns of x. + private TetradVector means(TetradMatrix x) { + return x.sum(1).scalarMult(1.0 / x.rows()); } - private int j(List A) { - int v = 1; + // Subsample of the continuous variables conditioning on the given cell. + private TetradMatrix getSubsample(int[] continuousCols, List cell) { + int a = cell.size(); + TetradMatrix subset = new TetradMatrix(a, continuousCols.length); - for (DiscreteVariable a : A) { - v *= a.getNumCategories() - 1; + for (int i = 0; i < a; i++) { + for (int j = 0; j < continuousCols.length; j++) { + subset.set(i, j, continuousData[continuousCols[j]][cell.get(i)]); + } } - return v; + return subset; } - public int getRowIndex(int[] values, int[] dims) { - int rowIndex = 0; + // Degrees of freedom for a discrete distribution is the product of the number of categories for each + // variable. + private int f(List A) { + int f = 1; - for (int i = 0; i < dims.length; i++) { - rowIndex *= dims[i]; - rowIndex += values[i]; + for (DiscreteVariable V : A) { + f *= V.getNumCategories(); } - return rowIndex; + return f; + } + + // Degrees of freedom for a multivariate Gaussian distribution is p * (p + 1) / 2, where p is the number + // of variables. This is the number of unique entries in the covariance matrix over X. + private int h(List X) { + int p = X.size(); + return p * (p + 1) / 2; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditionalGaussianScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditionalGaussianScore.java index 062587b3e4..d9e44c66e0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditionalGaussianScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ConditionalGaussianScore.java @@ -27,7 +27,7 @@ import java.util.*; /** - * Implements a conditional Gaussian BIC score for FGS. + * Implements a conditional Gaussian BIC score for FGES. * * @author Joseph Ramsey */ @@ -41,6 +41,9 @@ public class ConditionalGaussianScore implements Score { // Likelihood function private ConditionalGaussianLikelihood likelihood; + private double penaltyDiscount = 2; + private boolean denominatorMixed = true; + /** * Constructs the score using a covariance matrix. */ @@ -59,15 +62,16 @@ public ConditionalGaussianScore(DataSet dataSet) { * Calculates the sample likelihood and BIC score for i given its parents in a simple SEM model */ public double localScore(int i, int... parents) { - ConditionalGaussianLikelihood.Ret ret = likelihood.getLikelihoodRatio(i, parents); + likelihood.setDenominatorMixed(denominatorMixed); + likelihood.setPenaltyDiscount(penaltyDiscount); - int N = dataSet.getNumRows(); + ConditionalGaussianLikelihood.Ret ret = likelihood.getLikelihood(i, parents); + int N = dataSet.getNumRows(); double lik = ret.getLik(); int k = ret.getDof(); - double prior = getStructurePrior(parents); - return 2.0 * lik - k * Math.log(N) + prior; + return 2.0 * lik - k * Math.log(N); } private double getStructurePrior(int[] parents) { @@ -146,6 +150,18 @@ public Node getVariable(String targetName) { public int getMaxDegree() { return (int) Math.ceil(Math.log(dataSet.getNumRows())); } + + public double getPenaltyDiscount() { + return penaltyDiscount; + } + + public void setPenaltyDiscount(double penaltyDiscount) { + this.penaltyDiscount = penaltyDiscount; + } + + public void setDenominatorMixed(boolean denominatorMixed) { + this.denominatorMixed = denominatorMixed; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DMSearch.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DMSearch.java index 2c6c66d3cc..1051c3b30e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DMSearch.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DMSearch.java @@ -33,7 +33,7 @@ public class DMSearch { private int minDiscount = 4; //If true, use GES, else use PC. - private boolean useFgs = true; + private boolean useFges = true; //Lets the user select a subset of the inputs in the dataset to search over. //If not subseting, should be set to the entire input set. @@ -112,8 +112,8 @@ public void setDiscount(double discount) { this.gesDiscount = discount; } - public void setUseFgs(boolean set) { - this.useFgs = set; + public void setUseFges(boolean set) { + this.useFges = set; } @@ -158,11 +158,11 @@ public Graph search() { Graph pattern = new EdgeListGraph(); - if (useFgs) { + if (useFges) { Score score = new SemBicScore(cov); - Fgs fgs = new Fgs(score); + Fges fges = new Fges(score); - pattern = recursiveFgs(pattern, knowledge, this.gesDiscount, getMinDepth(), data, inputString); + pattern = recursiveFges(pattern, knowledge, this.gesDiscount, getMinDepth(), data, inputString); } else { this.cov = new CovarianceMatrixOnTheFly(data); // PC pc = new PC(new IndTestFisherZ(cov, this.alphaPC)); @@ -624,7 +624,7 @@ private boolean allEqual(SortedSet set1, SortedSet set2) { } // Uses previous runs of GES as new knowledge for a additional runs of GES with lower penalty discounts. - private Graph recursiveFgs(Graph previousGES, Knowledge2 knowledge, double penalty, double minPenalty, DataSet data, Set inputString) { + private Graph recursiveFges(Graph previousGES, Knowledge2 knowledge, double penalty, double minPenalty, DataSet data, Set inputString) { for (Edge edge : previousGES.getEdges()) { knowledge.setRequired(edge.getNode1().getName(), edge.getNode2().getName()); @@ -634,12 +634,12 @@ private Graph recursiveFgs(Graph previousGES, Knowledge2 knowledge, double penal SemBicScore score = new SemBicScore(cov); score.setPenaltyDiscount(penalty); - Fgs fgs = new Fgs(score); - fgs.setKnowledge(knowledge); -// fgs.setMaxIndegree(this.gesDepth); -// fgs.setIgnoreLinearDependent(true); + Fges fges = new Fges(score); + fges.setKnowledge(knowledge); +// fges.setMaxIndegree(this.gesDepth); +// fges.setIgnoreLinearDependent(true); - Graph pattern = fgs.search(); + Graph pattern = fges.search(); //Saves GES output in case is needed. File file = new File("src/edu/cmu/tetradproj/amurrayw/ges_output_" + penalty + "_.txt"); @@ -656,7 +656,7 @@ private Graph recursiveFgs(Graph previousGES, Knowledge2 knowledge, double penal if (penalty > minPenalty) { applyDmSearch(pattern, inputString, penalty); - return (recursiveFgs(pattern, knowledge, penalty - 1, minPenalty, data, inputString)); + return (recursiveFges(pattern, knowledge, penalty - 1, minPenalty, data, inputString)); } else { applyDmSearch(pattern, inputString, penalty); return (pattern); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FasStableConcurrent.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FasStableConcurrent.java index 61f390d6e4..b3d89dbd8b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FasStableConcurrent.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FasStableConcurrent.java @@ -100,7 +100,7 @@ public class FasStableConcurrent implements IFas { */ private PrintStream out = System.out; - int chunk = 50; + private int chunk = 100; private boolean recordSepsets = true; @@ -333,16 +333,14 @@ protected Boolean compute() { return true; } else { - List tasks = new ArrayList<>(); - final int mid = (to + from) / 2; Depth0Task left = new Depth0Task(chunk, from, mid); - tasks.add(left); Depth0Task right = new Depth0Task(chunk, mid, to); - tasks.add(right); - invokeAll(tasks); + left.fork(); + right.compute(); + left.join(); return true; } @@ -442,6 +440,7 @@ protected Boolean compute() { EDGE: for (Node y : adjx) { List _adjx = new ArrayList<>(adjx); + _adjx.remove(y); List ppx = possibleParents(x, _adjx, knowledge); @@ -449,6 +448,7 @@ protected Boolean compute() { ChoiceGenerator cg = new ChoiceGenerator(ppx.size(), depth); int[] choice; + COND: while ((choice = cg.next()) != null) { List condSet = GraphUtils.asList(choice, ppx); @@ -472,13 +472,6 @@ protected Boolean compute() { getSepsets().set(x, y, condSet); } - // This creates a bottleneck for the parallel search. -// if (verbose) { -// TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFact(x, y, condSet) + " p = " + -// nf.format(test.getPValue())); -// out.println(SearchLogUtils.independenceFactMsg(x, y, condSet, test.getPValue())); -// } - continue EDGE; } } @@ -488,16 +481,14 @@ protected Boolean compute() { return true; } else { - List tasks = new ArrayList<>(); - final int mid = (to + from) / 2; DepthTask left = new DepthTask(chunk, from, mid); - tasks.add(left); DepthTask right = new DepthTask(chunk, mid, to); - tasks.add(right); - invokeAll(tasks); + left.fork(); + right.compute(); + left.join(); return true; } @@ -608,6 +599,7 @@ public boolean isRecordSepsets() { public void setRecordSepsets(boolean recordSepsets) { this.recordSepsets = recordSepsets; } + } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FasStableConcurrent2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FasStableConcurrentFdr.java similarity index 68% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/FasStableConcurrent2.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/FasStableConcurrentFdr.java index 88de1291ff..2002d0439b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FasStableConcurrent2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FasStableConcurrentFdr.java @@ -26,6 +26,7 @@ import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.util.ChoiceGenerator; import edu.cmu.tetrad.util.ForkJoinPoolInstance; +import edu.cmu.tetrad.util.StatUtils; import edu.cmu.tetrad.util.TetradLogger; import java.io.PrintStream; @@ -50,7 +51,7 @@ * * @author Joseph Ramsey. */ -public class FasStableConcurrent2 implements IFas { +public class FasStableConcurrentFdr implements IFas { /** * The independence test. This should be appropriate to the types @@ -102,22 +103,23 @@ public class FasStableConcurrent2 implements IFas { */ private PrintStream out = System.out; - int chunk = 50; + private int chunk = 100; + private boolean recordSepsets = true; //==========================CONSTRUCTORS=============================// /** * Constructs a new FastAdjacencySearch. wd */ - public FasStableConcurrent2(IndependenceTest test) { + public FasStableConcurrentFdr(IndependenceTest test) { this.test = test; } /** * Constructs a new FastAdjacencySearch. */ - public FasStableConcurrent2(Graph initialGraph, IndependenceTest test) { + public FasStableConcurrentFdr(Graph initialGraph, IndependenceTest test) { this.test = test; this.initialGraph = initialGraph; } @@ -160,29 +162,21 @@ public Graph search() { adjacencies.put(node, new HashSet()); } - double alpha = test.getAlpha(); - for (double _alpha = 0.9; _alpha > alpha; _alpha /= 2.0) { - System.out.println("_alpha = " + _alpha); - searchAtDepth0(nodes, test, adjacencies); - - test.setAlpha(_alpha); - - boolean didIt = false; - - for (int d = didIt ? 1 : 0; d <= _depth; d++) { - boolean more; + for (int d = 0; d <= _depth; d++) { + boolean more; + if (d == 0) { + more = searchAtDepth0(nodes, test, adjacencies); + } else { more = searchAtDepth(nodes, test, adjacencies, d); + } - if (!more) { - break; - } + if (!more) { + break; } } - test.setAlpha(alpha); - if (verbose) { out.println("Finished with search, constructing Graph..."); } @@ -267,6 +261,9 @@ private boolean searchAtDepth0(final List nodes, final IndependenceTest te } final List empty = Collections.emptyList(); + final Map scoredEdges = new ConcurrentSkipListMap<>(); + + final List sorted = new ArrayList<>(); class Depth0Task extends RecursiveTask { private int chunk; @@ -301,13 +298,77 @@ protected Boolean compute() { } } - boolean independent; + try { + test.isIndependent(x, y, empty); + } catch (Exception e) { + e.printStackTrace(); + } + + numIndependenceTests++; + + final double pValue = test.getPValue(); + + sorted.add(pValue); + } + } + + return true; + } else { + final int mid = (to + from) / 2; + + Depth0Task left = new Depth0Task(chunk, from, mid); + Depth0Task right = new Depth0Task(chunk, mid, to); + + left.fork(); + right.compute(); + left.join(); + + return true; + } + } + } + + pool.invoke(new Depth0Task(chunk, 0, nodes.size())); + Collections.sort(sorted); + final double cutoff = StatUtils.fdrCutoff(test.getAlpha(), sorted, false, true); + + class Depth0Task2 extends RecursiveTask { + private int chunk; + private int from; + private int to; + + public Depth0Task2(int chunk, int from, int to) { + this.chunk = chunk; + this.from = from; + this.to = to; + } + + @Override + protected Boolean compute() { + if (to - from <= chunk) { + for (int i = from; i < to; i++) { + if (verbose) { + if ((i + 1) % 1000 == 0) System.out.println("i = " + (i + 1)); + } + + final Node x = nodes.get(i); + + for (int j = 0; j < i; j++) { + final Node y = nodes.get(j); + + if (initialGraph != null) { + Node x2 = initialGraph.getNode(x.getName()); + Node y2 = initialGraph.getNode(y.getName()); + + if (!initialGraph.isAdjacentTo(x2, y2)) { + continue; + } + } try { - independent = test.isIndependent(x, y, empty); + test.isIndependent(x, y, empty); } catch (Exception e) { e.printStackTrace(); - independent = true; } numIndependenceTests++; @@ -315,19 +376,10 @@ protected Boolean compute() { boolean noEdgeRequired = knowledge.noEdgeRequired(x.getName(), y.getName()); - if (independent && noEdgeRequired) { - if (!sepsets.isReturnEmptyIfNotSet()) { + if (test.getPValue() > cutoff && noEdgeRequired) { + if (recordSepsets && !sepsets.isReturnEmptyIfNotSet()) { getSepsets().set(x, y, empty); } - - // This creates a bottleneck for the parallel search. -// if (verbose) { -// TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFact(x, y, empty) + " p = " + -// nf.format(test.getPValue())); -// -// out.println(SearchLogUtils.independenceFact(x, y, empty) + " p = " + -// nf.format(test.getPValue())); -// } } else if (!forbiddenEdge(x, y)) { adjacencies.get(x).add(y); adjacencies.get(y).add(x); @@ -342,23 +394,21 @@ protected Boolean compute() { return true; } else { - List tasks = new ArrayList<>(); - final int mid = (to + from) / 2; Depth0Task left = new Depth0Task(chunk, from, mid); - tasks.add(left); Depth0Task right = new Depth0Task(chunk, mid, to); - tasks.add(right); - invokeAll(tasks); + left.fork(); + right.compute(); + left.join(); return true; } } } - pool.invoke(new Depth0Task(chunk, 0, nodes.size())); + pool.invoke(new Depth0Task2(chunk, 0, nodes.size())); return freeDegree(nodes, adjacencies) > 0; } @@ -411,7 +461,8 @@ private int freeDegree(List nodes, Map> adjacencies) { // return false; // } - private boolean searchAtDepth(final List nodes, final IndependenceTest test, final Map> adjacencies, + private boolean searchAtDepth(final List nodes, final IndependenceTest test, + final Map> adjacencies, final int depth) { if (verbose) { @@ -425,6 +476,8 @@ private boolean searchAtDepth(final List nodes, final IndependenceTest tes adjacenciesCopy.put(node, new HashSet<>(adjacencies.get(node))); } + final List sorted = new ArrayList<>(); + class DepthTask extends RecursiveTask { private int chunk; private int from; @@ -450,6 +503,10 @@ protected Boolean compute() { EDGE: for (Node y : adjx) { + if (!existsShortPath(x, y, 3, adjacencies)) { + continue; + } + List _adjx = new ArrayList<>(adjx); _adjx.remove(y); List ppx = possibleParents(x, _adjx, knowledge); @@ -474,19 +531,89 @@ protected Boolean compute() { knowledge.noEdgeRequired(x.getName(), y.getName()); if (independent && noEdgeRequired) { - adjacencies.get(x).remove(y); - adjacencies.get(y).remove(x); + sorted.add(test.getPValue()); + continue EDGE; + } + } + } + } + } + + return true; + } else { + final int mid = (to + from) / 2; - getSepsets().set(x, y, condSet); + DepthTask left = new DepthTask(chunk, from, mid); + DepthTask right = new DepthTask(chunk, mid, to); - // This creates a bottleneck for the parallel search. -// if (verbose) { -// TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFact(x, y, condSet) + " p = " + -// nf.format(test.getPValue())); -// out.println(SearchLogUtils.independenceFactMsg(x, y, condSet, test.getPValue())); -// } + left.fork(); + right.compute(); + left.join(); - continue EDGE; + return true; + } + } + } + + pool.invoke(new DepthTask(chunk, 0, nodes.size())); + + Collections.sort(sorted); + final double cutoff = StatUtils.fdrCutoff(test.getAlpha(), sorted, false, true); + + System.out.println(); + + class DepthTask2 extends RecursiveTask { + private int chunk; + private int from; + private int to; + + public DepthTask2(int chunk, int from, int to) { + this.chunk = chunk; + this.from = from; + this.to = to; + } + + @Override + protected Boolean compute() { + if (to - from <= chunk) { + for (int i = from; i < to; i++) { + if (verbose) { + if ((i + 1) % 1000 == 0) System.out.println("i = " + (i + 1)); + } + + Node x = nodes.get(i); + + List adjx = new ArrayList<>(adjacenciesCopy.get(x)); + + EDGE: + for (Node y : adjx) { + List _adjx = new ArrayList<>(adjx); + _adjx.remove(y); + List ppx = possibleParents(x, _adjx, knowledge); + + if (ppx.size() >= depth) { + ChoiceGenerator cg = new ChoiceGenerator(ppx.size(), depth); + int[] choice; + + while ((choice = cg.next()) != null) { + List condSet = GraphUtils.asList(choice, ppx); + + try { + numIndependenceTests++; + test.isIndependent(x, y, condSet); + + if (test.getPValue() > cutoff) { + adjacencies.get(x).remove(y); + adjacencies.get(y).remove(x); + + if (recordSepsets) { + getSepsets().set(x, y, condSet); + } + + continue EDGE; + } + } catch (Exception e) { + e.printStackTrace(); } } } @@ -495,23 +622,21 @@ protected Boolean compute() { return true; } else { - List tasks = new ArrayList<>(); - final int mid = (to + from) / 2; DepthTask left = new DepthTask(chunk, from, mid); - tasks.add(left); DepthTask right = new DepthTask(chunk, mid, to); - tasks.add(right); - invokeAll(tasks); + left.fork(); + right.compute(); + left.join(); return true; } } } - pool.invoke(new DepthTask(chunk, 0, nodes.size())); + pool.invoke(new DepthTask2(chunk, 0, nodes.size())); if (verbose) { System.out.println("Done with depth"); @@ -604,6 +729,53 @@ public void setOut(PrintStream out) { public PrintStream getOut() { return out; } + + /** + * True if sepsets should be recorded. This is not necessary for all algorithms. + */ + public boolean isRecordSepsets() { + return recordSepsets; + } + + public void setRecordSepsets(boolean recordSepsets) { + this.recordSepsets = recordSepsets; + } + + private boolean existsShortPath(Node x, Node z, int bound, final Map> adjacencies) { + Queue Q = new LinkedList<>(); + Set V = new HashSet<>(); + Q.offer(x); + V.add(x); + Node e = null; + int distance = 0; + + while (!Q.isEmpty()) { + Node t = Q.remove(); + + if (e == t) { + e = null; + distance++; + if (distance > (bound == -1 ? 1000 : bound)) return false; + } + + for (Node c : adjacencies.get(t)) { + if (c == null) continue; +// if (t == y && c == z && distance > 2) continue; + if (t != x && t != z && c == z) return true; + + if (!V.contains(c)) { + V.add(c); + Q.offer(c); + + if (e == null) { + e = c; + } + } + } + } + + return false; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java index 1fe48d1eed..16c172af00 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java @@ -232,7 +232,7 @@ public Graph search(IFas fas) { // // System.out.println("Starting possible dsep search"); // PossibleDsepFci possibleDSep = new PossibleDsepFci(graph, independenceTest); -// possibleDSep.setMaxIndegree(getPossibleDsepDepth()); +// possibleDSep.setMaxDegree(getPossibleDsepDepth()); // possibleDSep.setKnowledge(getKnowledge()); // possibleDSep.setMaxPathLength(maxPathLength); // this.sepsets.addAll(possibleDSep.search()); @@ -260,6 +260,7 @@ public Graph search(IFas fas) { fciOrient.setKnowledge(knowledge); fciOrient.ruleR0(graph); fciOrient.doFinalOrientation(graph); + graph.setPag(true); return graph; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java index 87fab63694..df1910108e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java @@ -21,17 +21,18 @@ package edu.cmu.tetrad.search; -import edu.cmu.tetrad.data.ICovarianceMatrix; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; import edu.cmu.tetrad.data.KnowledgeEdge; import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.util.ChoiceGenerator; +import edu.cmu.tetrad.util.DepthChoiceGenerator; +import edu.cmu.tetrad.util.ForkJoinPoolInstance; import edu.cmu.tetrad.util.TetradLogger; import java.util.*; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.RecursiveTask; /** @@ -43,10 +44,8 @@ * done by extending doFinalOrientation() with methods for Zhang's rules R5-R10 which implements the augmented search. * (By a remark of Zhang's, the rule applications can be staged in this way.) * - * @author Erin Korber, June 2004 - * @author Alex Smith, December 2008 * @author Joseph Ramsey - * @author Choh-Man Teng + * @author Vineet Rhagu */ public final class FciMax implements GraphSearch { @@ -106,12 +105,15 @@ public final class FciMax implements GraphSearch { * True iff verbose output should be printed. */ private boolean verbose = false; - private Graph trueDag; - private ConcurrentMap hashIndices; - private ICovarianceMatrix covarianceMatrix; - private double penaltyDiscount = 2; - private SepsetMap possibleDsepSepsets = new SepsetMap(); + + /** + * An initial graph if there is one. + */ private Graph initialGraph; + + /** + * Max path length for the possible dsep search. + */ private int possibleDsepDepth = -1; @@ -127,34 +129,6 @@ public FciMax(IndependenceTest independenceTest) { this.independenceTest = independenceTest; this.variables.addAll(independenceTest.getVariables()); - buildIndexing(independenceTest.getVariables()); - } - - /** - * Constructs a new FCI search for the given independence test and background knowledge and a list of variables to - * search over. - */ - public FciMax(IndependenceTest independenceTest, List searchVars) { - if (independenceTest == null || knowledge == null) { - throw new NullPointerException(); - } - - this.independenceTest = independenceTest; - this.variables.addAll(independenceTest.getVariables()); - - Set remVars = new HashSet<>(); - for (Node node1 : this.variables) { - boolean search = false; - for (Node node2 : searchVars) { - if (node1.getName().equals(node2.getName())) { - search = true; - } - } - if (!search) { - remVars.add(node1); - } - } - this.variables.removeAll(remVars); } //========================PUBLIC METHODS==========================// @@ -177,14 +151,9 @@ public long getElapsedTime() { } public Graph search() { - return search(getIndependenceTest().getVariables()); - } - - public Graph search(List nodes) { FasStableConcurrent fas = new FasStableConcurrent(initialGraph, getIndependenceTest()); fas.setVerbose(verbose); return search(fas); -// return search(new Fas(getIndependenceTest())); } public void setInitialGraph(Graph initialGraph) { @@ -204,20 +173,11 @@ public Graph search(IFas fas) { graph.reorientAllWith(Endpoint.CIRCLE); -// SepsetProducer sp = new SepsetsPossibleDsep(graph, independenceTest, knowledge, depth, maxPathLength); SepsetProducer sp = new SepsetsMaxPValuePossDsep(graph, independenceTest, null, depth, maxPathLength); - SepsetProducer sp2 = new SepsetsMaxScore(graph, independenceTest, null, depth); // The original FCI, with or without JiJi Zhang's orientation rules - // // Optional step: Possible Dsep. (Needed for correctness but very time consuming.) if (isPossibleDsepSearchDone()) { -// long time1 = System.currentTimeMillis(); -// new FciOrient(new SepsetsSet(this.sepsets, independenceTest)).ruleR0(graph); -// SepsetProducer sepsetProducer = new SepsetsSet(sepsets, independenceTest); - - addColliders(graph, sp2, knowledge); - - System.out.println("Possible dsep add colliders done"); + addColliders(graph); for (Edge edge : new ArrayList<>(graph.getEdges())) { Node x = edge.getNode1(); @@ -228,38 +188,14 @@ public Graph search(IFas fas) { if (sepset != null) { graph.removeEdge(x, y); sepsets.set(x, y, sepset); - System.out.println("Possible DSEP Removed " + x + "--- " + y + " sepset = " + sepset); } } - System.out.println("Possible dsep done"); - - -// long time2 = System.currentTimeMillis(); -// logger.log("info", "Step C: " + (time2 - time1) / 1000. + "s"); -// -// // Step FCI D. -// long time3 = System.currentTimeMillis(); -// -// System.out.println("Starting possible dsep search"); -// PossibleDsepFci possibleDSep = new PossibleDsepFci(graph, independenceTest); -// possibleDSep.setMaxIndegree(getPossibleDsepDepth()); -// possibleDSep.setKnowledge(getKnowledge()); -// possibleDSep.setMaxPathLength(maxPathLength); -// this.sepsets.addAll(possibleDSep.search()); -// long time4 = System.currentTimeMillis(); -// logger.log("info", "Step D: " + (time4 - time3) / 1000. + "s"); -// System.out.println("Starting possible dsep search"); - // Reorient all edges as o-o. graph.reorientAllWith(Endpoint.CIRCLE); - - System.out.println("Reoriented with circles"); } - addColliders(graph, sp2, knowledge); - - System.out.println("Added colliders again"); + addColliders(graph); // Step CI C (Zhang's step F3.) long time5 = System.currentTimeMillis(); @@ -273,109 +209,137 @@ public Graph search(IFas fas) { fciOrient.setMaxPathLength(maxPathLength); fciOrient.setKnowledge(knowledge); fciOrient.doFinalOrientation(graph); - - System.out.println("Final orientation done"); + graph.setPag(true); return graph; } - /** - * Step C of PC; orients colliders using specified sepset. That is, orients x *-* y *-* z as x *-> y <-* z just in - * case y is in Sepset({x, z}). - */ - public Map findCollidersUsingSepsets(SepsetProducer sepsetProducer, Graph graph, boolean verbose, IKnowledge knowledge) { - TetradLogger.getInstance().log("details", "Starting Collider Orientation:"); - Map colliders = new HashMap<>(); + private void addColliders(Graph graph) { + final Map scores = new ConcurrentHashMap<>(); List nodes = graph.getNodes(); - for (Node b : nodes) { - List adjacentNodes = graph.getAdjacentNodes(b); - - if (adjacentNodes.size() < 2) { - continue; + class Task extends RecursiveTask { + int from; + int to; + int chunk = 20; + List nodes; + Graph graph; + + public Task(List nodes, Graph graph, Map scores, int from, int to) { + this.nodes = nodes; + this.graph = graph; + this.from = from; + this.to = to; } - ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); - int[] combination; + @Override + protected Boolean compute() { + if (to - from <= chunk) { + for (int i = from; i < to; i++) { + doNode(graph, scores, nodes.get(i)); + } + + return true; + } else { + int mid = (to + from) / 2; - while ((combination = cg.next()) != null) { - Node a = adjacentNodes.get(combination[0]); - Node c = adjacentNodes.get(combination[1]); + Task left = new Task(nodes, graph, scores, from, mid); + Task right = new Task(nodes, graph, scores, mid, to); - // Skip triples that are shielded. - if (graph.isAdjacentTo(a, c)) { - continue; + left.fork(); + right.compute(); + left.join(); + + return true; } + } + } - List sepset = sepsetProducer.getSepset(a, c); + Task task = new Task(nodes, graph, scores, 0, nodes.size()); - if (sepset == null) continue; + ForkJoinPoolInstance.getInstance().getPool().invoke(task); -// - if (sepsetProducer.getScore() < getIndependenceTest().getAlpha()) continue; + List tripleList = new ArrayList<>(scores.keySet()); - if (!sepset.contains(b)) { - System.out.println("sepset = " + sepset + " b = " + b + " p = " + sepsetProducer.getScore()); + // Most independent ones first. + Collections.sort(tripleList, new Comparator() { - if (verbose) { - System.out.println("Collider orientation <" + a + ", " + b + ", " + c + "> sepset = " + sepset); - } + @Override + public int compare(Triple o1, Triple o2) { + return Double.compare(scores.get(o2), scores.get(o1)); + } + }); - IndependenceTest test2 = new IndTestDSep(trueDag); - SepsetProducer sp2 = new SepsetsMaxScore(graph, test2, null, depth); + for (Triple triple : tripleList) { + Node a = triple.getX(); + Node b = triple.getY(); + Node c = triple.getZ(); - System.out.println("Dsep sepset = " + sp2.getSepset(a, c)); + graph.setEndpoint(a, b, Endpoint.ARROW); + graph.setEndpoint(c, b, Endpoint.ARROW); + } + } - colliders.put(new Triple(a, b, c), sepsetProducer.getScore()); + private void doNode(Graph graph, Map scores, Node b) { + List adjacentNodes = graph.getAdjacentNodes(b); -// colliders.add(new Triple(a, b, c)); - TetradLogger.getInstance().log("colliderOrientations", SearchLogUtils.colliderOrientedMsg(a, b, c, sepset)); - } - } + if (adjacentNodes.size() < 2) { + return; } - TetradLogger.getInstance().log("details", "Finishing Collider Orientation."); + ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); + int[] combination; - System.out.println("Done finding colliders"); + while ((combination = cg.next()) != null) { + Node a = adjacentNodes.get(combination[0]); + Node c = adjacentNodes.get(combination[1]); - return colliders; - } + // Skip triples that are shielded. + if (graph.isAdjacentTo(a, c)) { + continue; + } - private void addColliders(Graph graph, final SepsetProducer sepsetProducer, IKnowledge knowledge) { - final Map collidersPs = findCollidersUsingSepsets(sepsetProducer, graph, verbose, knowledge); + List adja = graph.getAdjacentNodes(a); + double score = Double.POSITIVE_INFINITY; + List S = null; - final List colliders = new ArrayList<>(collidersPs.keySet()); + DepthChoiceGenerator cg2 = new DepthChoiceGenerator(adja.size(), -1); + int[] comb2; - for (Triple collider : colliders) { - if (collidersPs.get(collider) < getIndependenceTest().getAlpha()) continue; + while ((comb2 = cg2.next()) != null) { + List s = GraphUtils.asList(comb2, adja); + independenceTest.isIndependent(a, c, s); + double _score = independenceTest.getScore(); - Node a = collider.getX(); - Node b = collider.getY(); - Node c = collider.getZ(); + if (_score < score) { + score = _score; + S = s; + } + } -// if (!(isArrowpointAllowed(a, b, knowledge) && isArrowpointAllowed(c, b, knowledge))) { -// continue; -// } + List adjc = graph.getAdjacentNodes(c); -// if (!graph.getEdge(a, b).pointsTowards(a) && !graph.getEdge(b, c).pointsTowards(c)) { -// graph.removeEdge(a, b); -// graph.removeEdge(c, b); -// graph.addDirectedEdge(a, b); -// graph.addDirectedEdge(c, b); -// } + DepthChoiceGenerator cg3 = new DepthChoiceGenerator(adjc.size(), -1); + int[] comb3; - graph.setEndpoint(a, b, Endpoint.ARROW); - graph.setEndpoint(c, b, Endpoint.ARROW); - } - } + while ((comb3 = cg3.next()) != null) { + List s = GraphUtils.asList(comb3, adjc); + independenceTest.isIndependent(c, a, s); + double _score = independenceTest.getScore(); - private static List union(List nodes, Node a) { - List union = new ArrayList<>(nodes); - union.add(a); - return union; - } + if (_score < score) { + score = _score; + S = s; + } + } + // S actually has to be non-null here, but the compiler doesn't know that. + if (S != null && !S.contains(b)) { + scores.put(new Triple(a, b, c), score); + } + } + } public SepsetMap getSepsets() { return this.sepsets; @@ -453,27 +417,8 @@ public IndependenceTest getIndependenceTest() { return independenceTest; } - public void setTrueDag(Graph trueDag) { - this.trueDag = trueDag; - } - - public double getPenaltyDiscount() { - return penaltyDiscount; - } - - public void setPenaltyDiscount(double penaltyDiscount) { - this.penaltyDiscount = penaltyDiscount; - } - //===========================PRIVATE METHODS=========================// - private void buildIndexing(List nodes) { - this.hashIndices = new ConcurrentHashMap<>(); - for (Node node : nodes) { - this.hashIndices.put(node, variables.indexOf(node)); - } - } - /** * Orients according to background knowledge */ @@ -526,14 +471,6 @@ private void fciOrientbk(IKnowledge bk, Graph graph, List variables) { logger.log("info", "Finishing BK Orientation."); } - - public int getPossibleDsepDepth() { - return possibleDsepDepth; - } - - public void setPossibleDsepDepth(int possibleDsepDepth) { - this.possibleDsepDepth = possibleDsepDepth; - } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fgs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java similarity index 98% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fgs.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java index 4f2ebf979e..098f1b179e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fgs.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java @@ -52,7 +52,7 @@ * @author Ricardo Silva, Summer 2003 * @author Joseph Ramsey, Revisions 5/2015 */ -public final class Fgs implements GraphSearch, GraphScorer { +public final class Fges implements GraphSearch, GraphScorer { /** @@ -184,7 +184,7 @@ private enum Mode { * values in case of conditional independence. See Chickering (2002), * locally consistent scoring criterion. */ - public Fgs(Score score) { + public Fges(Score score) { if (score == null) throw new NullPointerException(); setScore(score); this.graph = new EdgeListGraphSingleConnections(getVariables()); @@ -856,6 +856,9 @@ private void fes() { continue; } + if (graph.getDegree(x) > maxDegree - 1) continue; + if (graph.getDegree(y) > maxDegree - 1) continue; + if (!arrow.getNaYX().equals(getNaYX(x, y))) { continue; } @@ -868,9 +871,6 @@ private void fes() { continue; } - if (graph.getDegree(x) > maxDegree - 1) continue; - if (graph.getDegree(y) > maxDegree - 1) continue; - Set T = arrow.getHOrT(); double bump = arrow.getBump(); @@ -1117,7 +1117,7 @@ private void calculateArrowsForward(Node a, Node b) { } Set naYX = getNaYX(a, b); - if (!isClique(naYX)) return; + if (!GraphUtils.isClique(naYX, this.graph)) return; List TNeighbors = getTNeighbors(a, b); @@ -1149,7 +1149,7 @@ private void calculateArrowsForward(Node a, Node b) { break FOR; } - if (!isClique(union)) continue; + if (!GraphUtils.isClique(union, this.graph)) continue; newCliques.add(union); double bump = insertEval(a, b, T, naYX, hashIndices); @@ -1552,7 +1552,7 @@ private boolean validInsert(Node x, Node y, Set T, Set naYX) { Set union = new HashSet<>(T); union.addAll(naYX); - boolean clique = isClique(union); + boolean clique = GraphUtils.isClique(union, this.graph); boolean noCycle = !existsUnblockedSemiDirectedPath(y, x, union, cycleBound); return clique && noCycle && !violatesKnowledge; } @@ -1574,7 +1574,7 @@ private boolean validDelete(Node x, Node y, Set H, Set naYX) { Set diff = new HashSet<>(naYX); diff.removeAll(H); - return isClique(diff) && !violatesKnowledge; + return GraphUtils.isClique(diff, this.graph) && !violatesKnowledge; } // Adds edges required by knowledge. @@ -1669,21 +1669,6 @@ private Set getNaYX(Node x, Node y) { return nayx; } - Set cliqueEdges = new HashSet<>(); - - // Returns true iif the given set forms a clique in the given graph. - private boolean isClique(Set nodes) { - List _nodes = new ArrayList<>(nodes); - for (int i = 0; i < _nodes.size() - 1; i++) { - for (int j = i + 1; j < _nodes.size(); j++) { - if (!graph.isAdjacentTo(_nodes.get(i), _nodes.get(j))) { - return false; - } - } - } - - return true; - } // Returns true if a path consisting of undirected and directed edges toward 'to' exists of // length at most 'bound'. Cycle checker in other words. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsMb.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesMb.java similarity index 95% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsMb.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesMb.java index 2133402e06..fdbc1ee4d1 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsMb.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesMb.java @@ -52,7 +52,7 @@ * @author Ricardo Silva, Summer 2003 * @author Joseph Ramsey, Revisions 5/2015 */ -public final class FgsMb { +public final class FgesMb { /** * Specification of forbidden and required edges. @@ -98,7 +98,7 @@ public final class FgsMb { /** * The score for discrete searches. */ - private Score fgsScore; + private Score fgesScore; /** * The logger for this class. The config needs to be set. @@ -171,14 +171,14 @@ public final class FgsMb { /** * The data set must either be all continuous or all discrete. */ - public FgsMb(Score score) { + public FgesMb(Score score) { if (verbose) { out.println("GES constructor"); } this.variables = score.getVariables(); - setFgsScore(score); + setFgesScore(score); this.graph = new EdgeListGraphSingleConnections(getVariables()); } @@ -216,7 +216,7 @@ public Graph search(List targets) { if (targets == null) throw new NullPointerException(); for (Node target : targets) { - if (!fgsScore.getVariables().contains(target)) throw new IllegalArgumentException( + if (!fgesScore.getVariables().contains(target)) throw new IllegalArgumentException( "Target is not one of the variables for the score." ); } @@ -226,7 +226,7 @@ public Graph search(List targets) { topGraphs.clear(); lookupArrows = new ConcurrentHashMap<>(); - final List nodes = new ArrayList<>(fgsScore.getVariables()); + final List nodes = new ArrayList<>(fgesScore.getVariables()); if (adjacencies != null) { adjacencies = GraphUtils.replaceNodes(adjacencies, nodes); @@ -338,14 +338,14 @@ public Graph getInitialGraph() { */ public void setInitialGraph(Graph initialGraph) { if (initialGraph != null) { - initialGraph = GraphUtils.replaceNodes(initialGraph, fgsScore.getVariables()); + initialGraph = GraphUtils.replaceNodes(initialGraph, fgesScore.getVariables()); if (verbose) { out.println("Initial graph variables: " + initialGraph.getNodes()); - out.println("Data set variables: " + fgsScore.getVariables()); + out.println("Data set variables: " + fgesScore.getVariables()); } - if (!new HashSet<>(initialGraph.getNodes()).equals(new HashSet<>(fgsScore.getVariables()))) { + if (!new HashSet<>(initialGraph.getNodes()).equals(new HashSet<>(fgesScore.getVariables()))) { throw new IllegalArgumentException("Variables aren't the same."); } } @@ -436,16 +436,16 @@ public void setParallelism(int numProcessors) { * True iff edges that cause linear dependence are ignored. */ public boolean isIgnoreLinearDependent() { - if (fgsScore instanceof SemBicScore) { - return ((SemBicScore) fgsScore).isIgnoreLinearDependent(); + if (fgesScore instanceof SemBicScore) { + return ((SemBicScore) fgesScore).isIgnoreLinearDependent(); } throw new UnsupportedOperationException("Operation supported only for SemBicScore."); } public void setIgnoreLinearDependent(boolean ignoreLinearDependent) { - if (fgsScore instanceof SemBicScore) { - ((SemBicScore) fgsScore).setIgnoreLinearDependent(ignoreLinearDependent); + if (fgesScore instanceof SemBicScore) { + ((SemBicScore) fgesScore).setIgnoreLinearDependent(ignoreLinearDependent); } else { throw new UnsupportedOperationException("Operation supported only for SemBicScore."); } @@ -464,8 +464,8 @@ public void setBoundGraph(Graph boundGraph) { * @deprecated Use the getters on the individual scores instead. */ public double getPenaltyDiscount() { - if (fgsScore instanceof ISemBicScore) { - return ((ISemBicScore) fgsScore).getPenaltyDiscount(); + if (fgesScore instanceof ISemBicScore) { + return ((ISemBicScore) fgesScore).getPenaltyDiscount(); } else { return 2.0; } @@ -475,8 +475,8 @@ public double getPenaltyDiscount() { * @deprecated Use the setters on the individual scores instead. */ public void setSamplePrior(double samplePrior) { - if (fgsScore instanceof LocalDiscreteScore) { - ((LocalDiscreteScore) fgsScore).setSamplePrior(samplePrior); + if (fgesScore instanceof LocalDiscreteScore) { + ((LocalDiscreteScore) fgesScore).setSamplePrior(samplePrior); } } @@ -484,8 +484,8 @@ public void setSamplePrior(double samplePrior) { * @deprecated Use the setters on the individual scores instead. */ public void setStructurePrior(double expectedNumParents) { - if (fgsScore instanceof LocalDiscreteScore) { - ((LocalDiscreteScore) fgsScore).setStructurePrior(expectedNumParents); + if (fgesScore instanceof LocalDiscreteScore) { + ((LocalDiscreteScore) fgesScore).setStructurePrior(expectedNumParents); } } @@ -495,20 +495,20 @@ public void setStructurePrior(double expectedNumParents) { * @deprecated Use the setters on the individual scores instead. */ public void setPenaltyDiscount(double penaltyDiscount) { - if (fgsScore instanceof ISemBicScore) { - ((ISemBicScore) fgsScore).setPenaltyDiscount(penaltyDiscount); + if (fgesScore instanceof ISemBicScore) { + ((ISemBicScore) fgesScore).setPenaltyDiscount(penaltyDiscount); } } //===========================PRIVATE METHODS========================// //Sets the discrete scoring function to use. - private void setFgsScore(Score fgsScore) { - this.fgsScore = fgsScore; + private void setFgesScore(Score fgesScore) { + this.fgesScore = fgesScore; this.variables = new ArrayList<>(); - List variables = fgsScore.getVariables(); + List variables = fgesScore.getVariables(); for (Node node : variables) { if (node.getNodeType() == NodeType.MEASURED) { @@ -549,7 +549,7 @@ private void calcDConnections(List targets) { int child = hashIndices.get(target); int parent = hashIndices.get(x); - double bump = fgsScore.localScoreDiff(parent, child); + double bump = fgesScore.localScoreDiff(parent, child); if (bump > 0) { dconn.addNode(x); @@ -562,7 +562,7 @@ private void calcDConnections(List targets) { int child2 = hashIndices.get(x); int parent2 = hashIndices.get(y); - double bump2 = fgsScore.localScoreDiff(parent2, child2); + double bump2 = fgesScore.localScoreDiff(parent2, child2); if (bump2 > 0) { dconn.addNode(y); @@ -598,7 +598,7 @@ private void addUnconditionalArrows(Node x, Node y, Set emptySet) { int child = hashIndices.get(y); int parent = hashIndices.get(x); - double bump = fgsScore.localScoreDiff(parent, child); + double bump = fgesScore.localScoreDiff(parent, child); if (boundGraph != null && !boundGraph.isAdjacentTo(x, y)) return; @@ -856,7 +856,7 @@ private void calculateArrowsForward(Node a, Node b) { } Set naYX = getNaYX(a, b); - if (!isClique(naYX)) return; + if (!GraphUtils.isClique(naYX, this.graph)) return; List TNeighbors = getTNeighbors(a, b); @@ -890,7 +890,7 @@ private void calculateArrowsForward(Node a, Node b) { break FOR; } - if (!isClique(union)) continue; + if (!GraphUtils.isClique(union, this.graph)) continue; newCliques.add(union); double bump = insertEval(a, b, T, naYX, hashIndices); @@ -1240,7 +1240,7 @@ private boolean validInsert(Node x, Node y, Set T, Set naYX) { Set union = new HashSet<>(T); union.addAll(naYX); - boolean clique = isClique(union); + boolean clique = GraphUtils.isClique(union, this.graph); boolean noCycle = !existsUnblockedSemiDirectedPath(y, x, union, cycleBound); return clique && noCycle && !violatesKnowledge; } @@ -1262,7 +1262,7 @@ private boolean validDelete(Node x, Node y, Set H, Set naYX) { Set diff = new HashSet<>(naYX); diff.removeAll(H); - return isClique(diff) && !violatesKnowledge; + return GraphUtils.isClique(diff, this.graph) && !violatesKnowledge; } // Adds edges required by knowledge. @@ -1357,20 +1357,6 @@ private Set getNaYX(Node x, Node y) { return nayx; } - // Returns true iif the given set forms a clique in the given graph. - private boolean isClique(Set nodes) { - List _nodes = new ArrayList<>(nodes); - for (int i = 0; i < _nodes.size() - 1; i++) { - for (int j = i + 1; j < _nodes.size(); j++) { - if (!graph.isAdjacentTo(_nodes.get(i), _nodes.get(j))) { - return false; - } - } - } - - return true; - } - // Returns true if a path consisting of undirected and directed edges toward 'to' exists of // length at most 'bound'. Cycle checker in other words. private boolean existsUnblockedSemiDirectedPath(Node from, Node to, Set cond, int bound) { @@ -1520,7 +1506,7 @@ public double scoreDag(Graph dag) { } int yIndex = hashIndices.get(y); - score += fgsScore.localScore(yIndex, parentIndices); + score += fgesScore.localScore(yIndex, parentIndices); } return score; @@ -1539,11 +1525,11 @@ private double scoreGraphChange(Node y, Set parents, parentIndices[count++] = hashIndices.get(parent); } - return fgsScore.localScoreDiff(hashIndices.get(x), yIndex, parentIndices); + return fgesScore.localScoreDiff(hashIndices.get(x), yIndex, parentIndices); } private List getVariables() { - return fgsScore.getVariables(); + return fgesScore.getVariables(); } // Stores the graph, if its score knocks out one of the top ones. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsMb2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesMb2.java similarity index 96% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsMb2.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesMb2.java index c21bfadc68..843590cc19 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsMb2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesMb2.java @@ -54,7 +54,7 @@ * @author Ricardo Silva, Summer 2003 * @author Joseph Ramsey, Revisions 5/2015 */ -public final class FgsMb2 { +public final class FgesMb2 { private List targets; @@ -110,7 +110,7 @@ private enum Mode { /** * The totalScore for discrete searches. */ - private Score fgsScore; + private Score fgesScore; /** * The logger for this class. The config needs to be set. @@ -193,9 +193,9 @@ private enum Mode { * values in case of conditional independence. See Chickering (2002), * locally consistent scoring criterion. */ - public FgsMb2(Score score) { + public FgesMb2(Score score) { if (score == null) throw new NullPointerException(); - setFgsScore(score); + setFgesScore(score); this.graph = new EdgeListGraphSingleConnections(getVariables()); } @@ -292,8 +292,8 @@ public Graph search(List targets) { if (targets == null) throw new NullPointerException(); for (Node target : targets) { - if (!fgsScore.getVariables().contains(target)) throw new IllegalArgumentException( - "Target is not one of the variables for the fgsScore." + if (!fgesScore.getVariables().contains(target)) throw new IllegalArgumentException( + "Target is not one of the variables for the fgesScore." ); } @@ -302,7 +302,7 @@ public Graph search(List targets) { topGraphs.clear(); lookupArrows = new ConcurrentHashMap<>(); - final List nodes = new ArrayList<>(fgsScore.getVariables()); + final List nodes = new ArrayList<>(fgesScore.getVariables()); if (adjacencies != null) { adjacencies = GraphUtils.replaceNodes(adjacencies, nodes); @@ -352,7 +352,7 @@ private void calcDConnections(List targets) { sortedArrows = new ConcurrentSkipListSet<>(); lookupArrows = new ConcurrentHashMap<>(); neighbors = new ConcurrentHashMap<>(); - final List nodes = fgsScore.getVariables(); + final List nodes = fgesScore.getVariables(); this.effectEdgesGraph = new EdgeListGraphSingleConnections(); @@ -363,14 +363,14 @@ private void calcDConnections(List targets) { final Set emptySet = new HashSet(); for (final Node target : targets) { - for (final Node x : fgsScore.getVariables()) { + for (final Node x : fgesScore.getVariables()) { if (targets.contains(x)) { continue; } int child = hashIndices.get(target); int parent = hashIndices.get(x); - double bump = fgsScore.localScoreDiff(parent, child); + double bump = fgesScore.localScoreDiff(parent, child); if (bump > 0) { synchronized (effectEdgesGraph) { @@ -388,7 +388,7 @@ public MbAboutNodeTask() { protected Boolean compute() { Queue tasks = new ArrayDeque<>(); - for (final Node y : fgsScore.getVariables()) { + for (final Node y : fgesScore.getVariables()) { if (x == y) continue; MbTask mbTask = new MbTask(x, y, target); @@ -439,7 +439,7 @@ protected Boolean compute() { int child2 = hashIndices.get(x); int parent2 = hashIndices.get(y); - double bump2 = fgsScore.localScoreDiff(parent2, child2); + double bump2 = fgesScore.localScoreDiff(parent2, child2); if (bump2 > 0) { synchronized (effectEdgesGraph) { @@ -472,7 +472,7 @@ private void addUnconditionalArrows(Node x, Node y, Set emptySet) { int child = hashIndices.get(y); int parent = hashIndices.get(x); - double bump = fgsScore.localScoreDiff(parent, child); + double bump = fgesScore.localScoreDiff(parent, child); if (boundGraph != null && !boundGraph.isAdjacentTo(x, y)) return; @@ -650,8 +650,8 @@ public void setBoundGraph(Graph boundGraph) { * @deprecated Use the getters on the individual scores instead. */ public double getPenaltyDiscount() { - if (fgsScore instanceof ISemBicScore) { - return ((ISemBicScore) fgsScore).getPenaltyDiscount(); + if (fgesScore instanceof ISemBicScore) { + return ((ISemBicScore) fgesScore).getPenaltyDiscount(); } else { return 2.0; } @@ -661,8 +661,8 @@ public double getPenaltyDiscount() { * @deprecated Use the setters on the individual scores instead. */ public void setSamplePrior(double samplePrior) { - if (fgsScore instanceof LocalDiscreteScore) { - ((LocalDiscreteScore) fgsScore).setSamplePrior(samplePrior); + if (fgesScore instanceof LocalDiscreteScore) { + ((LocalDiscreteScore) fgesScore).setSamplePrior(samplePrior); } } @@ -670,8 +670,8 @@ public void setSamplePrior(double samplePrior) { * @deprecated Use the setters on the individual scores instead. */ public void setStructurePrior(double expectedNumParents) { - if (fgsScore instanceof LocalDiscreteScore) { - ((LocalDiscreteScore) fgsScore).setStructurePrior(expectedNumParents); + if (fgesScore instanceof LocalDiscreteScore) { + ((LocalDiscreteScore) fgesScore).setStructurePrior(expectedNumParents); } } @@ -681,8 +681,8 @@ public void setStructurePrior(double expectedNumParents) { * @deprecated Use the setters on the individual scores instead. */ public void setPenaltyDiscount(double penaltyDiscount) { - if (fgsScore instanceof ISemBicScore) { - ((ISemBicScore) fgsScore).setPenaltyDiscount(penaltyDiscount); + if (fgesScore instanceof ISemBicScore) { + ((ISemBicScore) fgesScore).setPenaltyDiscount(penaltyDiscount); } } @@ -706,8 +706,8 @@ public void setMaxIndegree(int maxIndegree) { //===========================PRIVATE METHODS========================// //Sets the discrete scoring function to use. - private void setFgsScore(Score totalScore) { - this.fgsScore = totalScore; + private void setFgesScore(Score totalScore) { + this.fgesScore = totalScore; this.variables = new ArrayList<>(); @@ -719,7 +719,7 @@ private void setFgsScore(Score totalScore) { buildIndexing(totalScore.getVariables()); - this.maxIndegree = fgsScore.getMaxDegree(); + this.maxIndegree = fgesScore.getMaxDegree(); } final int[] count = new int[1]; @@ -771,7 +771,7 @@ protected Boolean compute() { int child = hashIndices.get(y); int parent = hashIndices.get(x); - double bump = fgsScore.localScoreDiff(parent, child); + double bump = fgesScore.localScoreDiff(parent, child); if (boundGraph != null && !boundGraph.isAdjacentTo(x, y)) continue; @@ -1320,7 +1320,7 @@ private void calculateArrowsForward(Node a, Node b) { } Set naYX = getNaYX(a, b); - if (!isClique(naYX)) return; + if (!GraphUtils.isClique(naYX, this.graph)) return; List TNeighbors = getTNeighbors(a, b); int _maxIndegree = maxIndegree == -1 ? 1000 : maxIndegree; @@ -1355,7 +1355,7 @@ private void calculateArrowsForward(Node a, Node b) { break FOR; } - if (!isClique(union)) continue; + if (!GraphUtils.isClique(union, this.graph)) continue; newCliques.add(union); double bump = insertEval(a, b, T, naYX, hashIndices); @@ -1757,7 +1757,7 @@ private boolean validInsert(Node x, Node y, Set T, Set naYX) { Set union = new HashSet<>(T); union.addAll(naYX); - boolean clique = isClique(union); + boolean clique = GraphUtils.isClique(union, this.graph); boolean noCycle = !existsUnblockedSemiDirectedPath(y, x, union, cycleBound); return clique && noCycle && !violatesKnowledge; } @@ -1779,7 +1779,7 @@ private boolean validDelete(Node x, Node y, Set H, Set naYX) { Set diff = new HashSet<>(naYX); diff.removeAll(H); - return isClique(diff) && !violatesKnowledge; + return GraphUtils.isClique(diff, this.graph) && !violatesKnowledge; } // Adds edges required by knowledge. @@ -1874,21 +1874,6 @@ private Set getNaYX(Node x, Node y) { return nayx; } - Set cliqueEdges = new HashSet<>(); - - // Returns true iif the given set forms a clique in the given graph. - private boolean isClique(Set nodes) { - List _nodes = new ArrayList<>(nodes); - for (int i = 0; i < _nodes.size() - 1; i++) { - for (int j = i + 1; j < _nodes.size(); j++) { - if (!graph.isAdjacentTo(_nodes.get(i), _nodes.get(j))) { - return false; - } - } - } - - return true; - } // Returns true if a path consisting of undirected and directed edges toward 'to' exists of // length at most 'bound'. Cycle checker in other words. @@ -2024,7 +2009,7 @@ public double scoreDag(Graph dag) { } int yIndex = hashIndices.get(y); - _score += fgsScore.localScore(yIndex, parentIndices); + _score += fgesScore.localScore(yIndex, parentIndices); } return _score; @@ -2043,7 +2028,7 @@ private double scoreGraphChange(Node y, Set parents, parentIndices[count++] = hashIndices.get(parent); } - return fgsScore.localScoreDiff(hashIndices.get(x), yIndex, parentIndices); + return fgesScore.localScoreDiff(hashIndices.get(x), yIndex, parentIndices); } private List getVariables() { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsOld.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesOld.java similarity index 96% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsOld.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesOld.java index 0713313012..ce37c5ab4a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsOld.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesOld.java @@ -50,7 +50,7 @@ * @author Ricardo Silva, Summer 2003 * @author Joseph Ramsey, Revisions 5/2015 */ -public final class FgsOld implements GraphSearch, GraphScorer { +public final class FgesOld implements GraphSearch, GraphScorer { /** * Specification of forbidden and required edges. @@ -96,7 +96,7 @@ public final class FgsOld implements GraphSearch, GraphScorer { /** * The score for discrete searches. */ - private Score fgsScore; + private Score fgesScore; /** * The logger for this class. The config needs to be set. @@ -167,7 +167,7 @@ public final class FgsOld implements GraphSearch, GraphScorer { * The data set must either be all continuous or all discrete. * @deprecated Construct a Score and pass it in instead. */ - public FgsOld(DataSet dataSet) { + public FgesOld(DataSet dataSet) { if (verbose) { out.println("GES constructor"); } @@ -175,8 +175,8 @@ public FgsOld(DataSet dataSet) { if (dataSet.isDiscrete()) { setScore(new BDeuScore(dataSet)); } else { - SemBicScore fgsScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet)); - setScore(fgsScore); + SemBicScore fgesScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet)); + setScore(fgesScore); } this.graph = new EdgeListGraphSingleConnections(getVariables()); @@ -190,7 +190,7 @@ public FgsOld(DataSet dataSet) { * Continuous case--where a covariance matrix is already available. * @deprecated Construct a Score and pass it in instead. */ - public FgsOld(ICovarianceMatrix covMatrix) { + public FgesOld(ICovarianceMatrix covMatrix) { if (verbose) { out.println("GES constructor"); } @@ -206,9 +206,9 @@ public FgsOld(ICovarianceMatrix covMatrix) { } } - public FgsOld(Score fgsScore) { - if (fgsScore == null) throw new NullPointerException(); - setScore(fgsScore); + public FgesOld(Score fgesScore) { + if (fgesScore == null) throw new NullPointerException(); + setScore(fgesScore); this.graph = new EdgeListGraphSingleConnections(getVariables()); } @@ -330,8 +330,8 @@ public long getElapsedTime() { * @deprecated Use the getters on the individual scores instead. */ public double getPenaltyDiscount() { - if (fgsScore instanceof ISemBicScore) { - return ((ISemBicScore) fgsScore).getPenaltyDiscount(); + if (fgesScore instanceof ISemBicScore) { + return ((ISemBicScore) fgesScore).getPenaltyDiscount(); } else { return 2.0; } @@ -342,8 +342,8 @@ public double getPenaltyDiscount() { * @deprecated Use the setters on the individual scores instead. */ public void setPenaltyDiscount(double penaltyDiscount) { - if (fgsScore instanceof ISemBicScore) { - ((ISemBicScore) fgsScore).setPenaltyDiscount(penaltyDiscount); + if (fgesScore instanceof ISemBicScore) { + ((ISemBicScore) fgesScore).setPenaltyDiscount(penaltyDiscount); } } @@ -497,16 +497,16 @@ public void setParallelism(int numProcessors) { * True iff edges that cause linear dependence are ignored. */ public boolean isIgnoreLinearDependent() { - if (fgsScore instanceof SemBicScore) { - return ((SemBicScore) fgsScore).isIgnoreLinearDependent(); + if (fgesScore instanceof SemBicScore) { + return ((SemBicScore) fgesScore).isIgnoreLinearDependent(); } throw new UnsupportedOperationException("Operation supported only for SemBicScore."); } public void setIgnoreLinearDependent(boolean ignoreLinearDependent) { - if (fgsScore instanceof SemBicScore) { - ((SemBicScore) fgsScore).setIgnoreLinearDependent(ignoreLinearDependent); + if (fgesScore instanceof SemBicScore) { + ((SemBicScore) fgesScore).setIgnoreLinearDependent(ignoreLinearDependent); } else { throw new UnsupportedOperationException("Operation supported only for SemBicScore."); } @@ -523,7 +523,7 @@ public void setBoundGraph(Graph boundGraph) { //Sets the discrete scoring function to use. private void setScore(Score score) { - this.fgsScore = score; + this.fgesScore = score; this.variables = new ArrayList<>(); @@ -593,9 +593,9 @@ protected Boolean compute() { int child = hashIndices.get(y); int parent = hashIndices.get(x); - double bump = fgsScore.localScoreDiff(parent, child, new int[]{}); + double bump = fgesScore.localScoreDiff(parent, child, new int[]{}); - if (isHeuristicSpeedup() && fgsScore.isEffectEdge(bump)) { + if (isHeuristicSpeedup() && fgesScore.isEffectEdge(bump)) { final Edge edge = Edges.undirectedEdge(x, y); if (boundGraph != null && !boundGraph.isAdjacentTo(edge.getNode1(), edge.getNode2())) continue; @@ -986,7 +986,7 @@ private void calculateArrowsForward(Node a, Node b) { } Set naYX = getNaYX(a, b); - if (!isClique(naYX)) return; + if (!GraphUtils.isClique(naYX, this.graph)) return; List TNeighbors = getTNeighbors(a, b); @@ -1020,7 +1020,7 @@ private void calculateArrowsForward(Node a, Node b) { break FOR; } - if (!isClique(union)) continue; + if (!GraphUtils.isClique(union, this.graph)) continue; newCliques.add(union); double bump = insertEval(a, b, T, naYX, hashIndices); @@ -1029,7 +1029,7 @@ private void calculateArrowsForward(Node a, Node b) { addArrow(a, b, naYX, T, bump); } - if (isHeuristicSpeedup() && union.isEmpty() && fgsScore.isEffectEdge(bump) && + if (isHeuristicSpeedup() && union.isEmpty() && fgesScore.isEffectEdge(bump) && !effectEdgesGraph.isAdjacentTo(a, b) && graph.getParents(b).isEmpty()) { effectEdgesGraph.addUndirectedEdge(a, b); } @@ -1155,8 +1155,8 @@ private void calculateArrowsBackward(Node a, Node b) { * @deprecated Use the setters on the individual scores instead. */ public void setSamplePrior(double samplePrior) { - if (fgsScore instanceof LocalDiscreteScore) { - ((LocalDiscreteScore) fgsScore).setSamplePrior(samplePrior); + if (fgesScore instanceof LocalDiscreteScore) { + ((LocalDiscreteScore) fgesScore).setSamplePrior(samplePrior); } } @@ -1164,8 +1164,8 @@ public void setSamplePrior(double samplePrior) { * @deprecated Use the setters on the individual scores instead. */ public void setStructurePrior(double expectedNumParents) { - if (fgsScore instanceof LocalDiscreteScore) { - ((LocalDiscreteScore) fgsScore).setStructurePrior(expectedNumParents); + if (fgesScore instanceof LocalDiscreteScore) { + ((LocalDiscreteScore) fgesScore).setStructurePrior(expectedNumParents); } } @@ -1434,7 +1434,7 @@ private boolean validInsert(Node x, Node y, Set T, Set naYX) { Set union = new HashSet<>(T); union.addAll(naYX); - boolean clique = isClique(union); + boolean clique = GraphUtils.isClique(union, this.graph); boolean noCycle = !existsUnblockedSemiDirectedPath(y, x, union, cycleBound); return clique && noCycle && !violatesKnowledge; } @@ -1456,7 +1456,7 @@ private boolean validDelete(Node x, Node y, Set H, Set naYX) { Set diff = new HashSet<>(naYX); diff.removeAll(H); - return isClique(diff) && !violatesKnowledge; + return GraphUtils.isClique(diff, this.graph) && !violatesKnowledge; } // Adds edges required by knowledge. @@ -1556,20 +1556,6 @@ private Set getNaYX(Node x, Node y) { return nayx; } - // Returns true iif the given set forms a clique in the given graph. - private boolean isClique(Set nodes) { - List _nodes = new ArrayList<>(nodes); - for (int i = 0; i < _nodes.size() - 1; i++) { - for (int j = i + 1; j < _nodes.size(); j++) { - if (!graph.isAdjacentTo(_nodes.get(i), _nodes.get(j))) { - return false; - } - } - } - - return true; - } - // Returns true if a path consisting of undirected and directed edges toward 'to' exists of // length at most 'bound'. Cycle checker in other words. private boolean existsUnblockedSemiDirectedPath(Node from, Node to, Set cond, int bound) { @@ -1701,7 +1687,7 @@ public double scoreDag(Graph dag) { } int yIndex = hashIndices.get(y); - score += fgsScore.localScore(yIndex, parentIndices); + score += fgesScore.localScore(yIndex, parentIndices); } return score; @@ -1720,7 +1706,7 @@ private double scoreGraphChange(Node y, Set parents, parentIndices[count++] = hashIndices.get(parent); } - return fgsScore.localScoreDiff(hashIndices.get(x), yIndex, parentIndices); + return fgesScore.localScoreDiff(hashIndices.get(x), yIndex, parentIndices); } private List getVariables() { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsOrienter.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesOrienter.java similarity index 98% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsOrienter.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesOrienter.java index e7fbee9883..23a872e03f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgsOrienter.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FgesOrienter.java @@ -50,7 +50,7 @@ * of the edges in the oriented graph * @author AJ Sedgewick, 5/2015 */ -public final class FgsOrienter implements GraphSearch, GraphScorer, Reorienter { +public final class FgesOrienter implements GraphSearch, GraphScorer, Reorienter { /** * The covariance matrix for continuous data. @@ -179,7 +179,7 @@ public final class FgsOrienter implements GraphSearch, GraphScorer, Reorienter { /** * The data set must either be all continuous or all discrete. */ - public FgsOrienter(DataSet dataSet) { + public FgesOrienter(DataSet dataSet) { out.println("GES constructor"); if (dataSet.isDiscrete()) { @@ -210,7 +210,7 @@ public FgsOrienter(DataSet dataSet) { /** * Continuous case--where a covariance matrix is already available. */ - public FgsOrienter(ICovarianceMatrix covMatrix) { + public FgesOrienter(ICovarianceMatrix covMatrix) { out.println("GES(orient) constructor"); setCovMatrix(covMatrix); @@ -883,7 +883,7 @@ private void calculateArrowsForward(final Node a, final Node b, final Graph grap // Necessary condition for it to be a clique later (after possible edge removals) is that it be a clique // now. - if (!isClique(union, graph)) continue; + if (!GraphUtils.isClique(union, graph)) continue; if (existsKnowledge()) { if (!validSetByKnowledge(b, s)) { @@ -992,7 +992,7 @@ private void calculateArrowsBackward(Node a, Node b, Graph graph) { Set diff = new HashSet<>(naYX); diff.removeAll(h); - if (!isClique(diff, graph)) continue; + if (!GraphUtils.isClique(diff, graph)) continue; if (existsKnowledge()) { if (!validSetByKnowledge(b, h)) { @@ -1278,7 +1278,7 @@ private boolean allNeighbors(Node y, Set union, Graph graph) { private boolean validDelete(Node y, Set h, Set naXY, Graph graph) { Set set = new HashSet<>(naXY); set.removeAll(h); - return isClique(set, graph) && allNeighbors(y, set, graph); + return GraphUtils.isClique(set, graph) && allNeighbors(y, set, graph); } // Adds edges required by knowledge. @@ -1379,20 +1379,6 @@ private static Set getNaYX(Node x, Node y, Graph graph) { return nayx; } - // Returns true iif the given set forms a clique in the given graph. - private static boolean isClique(Set nodes, Graph graph) { - List _nodes = new ArrayList<>(nodes); - for (int i = 0; i < _nodes.size() - 1; i++) { - for (int j = i + 1; j < _nodes.size(); j++) { - if (!graph.isAdjacentTo(_nodes.get(i), _nodes.get(j))) { - return false; - } - } - } - - return true; - } - // Returns true is a path consisting of undirected and directed edges toward 'to' exists of // length at most 'bound'. Cycle checker in other words. private boolean existsUnblockedSemiDirectedPath(Node from, Node to, Set cond, Graph G, int bound) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GCcd.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GCcd.java deleted file mode 100644 index 6079a2440e..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GCcd.java +++ /dev/null @@ -1,654 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// For information as to what this class does, see the Javadoc, below. // -// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // -// 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph // -// Ramsey, and Clark Glymour. // -// // -// This program is free software; you can redistribute it and/or modify // -// it under the terms of the GNU General Public License as published by // -// the Free Software Foundation; either version 2 of the License, or // -// (at your option) any later version. // -// // -// This program is distributed in the hope that it will be useful, // -// but WITHOUT ANY WARRANTY; without even the implied warranty of // -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // -// GNU General Public License for more details. // -// // -// You should have received a copy of the GNU General Public License // -// along with this program; if not, write to the Free Software // -// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // -/////////////////////////////////////////////////////////////////////////////// - -package edu.cmu.tetrad.search; - -import edu.cmu.tetrad.data.IKnowledge; -import edu.cmu.tetrad.graph.*; -import edu.cmu.tetrad.util.ChoiceGenerator; -import edu.cmu.tetrad.util.DepthChoiceGenerator; -import edu.cmu.tetrad.util.ForkJoinPoolInstance; -import edu.cmu.tetrad.util.TetradLogger; - -import java.util.*; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.RecursiveTask; - -/** - * This class provides the data structures and methods for carrying out the Cyclic Causal Discovery algorithm (CCD) - * described by Thomas Richardson and Peter Spirtes in Chapter 7 of Computation, Causation, & Discovery by Glymour and - * Cooper eds. The comments that appear below are keyed to the algorithm specification on pp. 269-271.

The search - * method returns an instance of a Graph but it also constructs two lists of node triples which represent the underlines - * and dotted underlines that the algorithm discovers. - * - * @author Frank C. Wimberly - * @author Joseph Ramsey - */ -public final class GCcd implements GraphSearch { - private IndependenceTest independenceTest; - private Score score; - private int depth = -1; - private IKnowledge knowledge; - private List nodes; - private boolean applyR1 = false; - private boolean verbose; - - public GCcd(IndependenceTest test, Score score) { - if (test == null) throw new NullPointerException(); - this.independenceTest = test; - this.score = score; - this.nodes = test.getVariables(); - } - - //======================================== PUBLIC METHODS ====================================// - - /** - * The search method assumes that the IndependenceTest provided to the constructor is a conditional independence - * oracle for the SEM (or Bayes network) which describes the causal structure of the population. The method returns - * a PAG instantiated as a Tetrad GaSearchGraph which represents the equivalence class of digraphs which are - * d-separation equivalent to the digraph of the underlying model (SEM or BN).

Although they are not returned - * by the search method it also computes two lists of triples which, respectively store the underlines and dotted - * underlines of the PAG. - */ - public Graph search() { - Map> supSepsets = new HashMap<>(); - - // Step A - Fgs fgs = new Fgs(score); - fgs.setVerbose(verbose); - fgs.setNumPatternsToStore(0); - fgs.setFaithfulnessAssumed(false); - Graph psi = fgs.search(); - -// SepsetProducer sepsets0 = new SepsetsGreedy(new EdgeListGraphSingleConnections(psi), -// independenceTest, null, -1); - SepsetProducer sepsets = new SepsetsMinScore(psi, independenceTest, -1); - - for (Edge edge : psi.getEdges()) { - Node a = edge.getNode1(); - Node c = edge.getNode2(); - - if (psi.isAdjacentTo(a, c)) { - if (sepsets.getSepset(a, c) != null) { - psi.removeEdge(a, c); - } - } - } - - psi.reorientAllWith(Endpoint.CIRCLE); - - stepB(psi); - stepC(psi, sepsets); - stepD(psi, sepsets, supSepsets); - stepE(supSepsets, psi); - stepF(psi, sepsets, supSepsets); - - return psi; - } - - private void orientAwayFromArrow(Graph graph) { - for (Edge edge : graph.getEdges()) { - Node n1 = edge.getNode1(); - Node n2 = edge.getNode2(); - - edge = graph.getEdge(n1, n2); - - if (edge.pointsTowards(n1)) { - orientAwayFromArrow(n2, n1, graph); - } else if (edge.pointsTowards(n2)) { - orientAwayFromArrow(n1, n2, graph); - } - } - } - - public IKnowledge getKnowledge() { - return knowledge; - } - - public int getDepth() { - return depth; - } - - public void setDepth(int depth) { - this.depth = depth; - } - - public void setVerbose(boolean verbose) { - this.verbose = verbose; - } - - public void setKnowledge(IKnowledge knowledge) { - if (knowledge == null) { - throw new NullPointerException(); - } - this.knowledge = knowledge; - } - - public long getElapsedTime() { - return 0; - } - - //======================================== PRIVATE METHODS ====================================// - - private void stepB(Graph psi) { - final Map colliders = new ConcurrentHashMap<>(); - final Map noncolliders = new ConcurrentHashMap<>(); - - List nodes = psi.getNodes(); - - class Task extends RecursiveTask { - private final Map colliders; - private final Map noncolliders; - private int from; - private int to; - private int chunk = 20; - private List nodes; - private Graph psi; - - public Task(List nodes, Graph graph, Map colliders, - Map noncolliders, int from, int to) { - this.nodes = nodes; - this.psi = graph; - this.from = from; - this.to = to; - this.colliders = colliders; - this.noncolliders = noncolliders; - } - - @Override - protected Boolean compute() { - if (to - from <= chunk) { - for (int i = from; i < to; i++) { - doNodeCollider(psi, colliders, noncolliders, nodes.get(i)); - } - - return true; - } else { - int mid = (to + from) / 2; - - Task left = new Task(nodes, psi, colliders, noncolliders, from, mid); - Task right = new Task(nodes, psi, colliders, noncolliders, mid, to); - - left.fork(); - right.compute(); - left.join(); - - return true; - } - } - } - - Task task = new Task(nodes, psi, colliders, noncolliders, 0, nodes.size()); - - ForkJoinPoolInstance.getInstance().getPool().invoke(task); - - List collidersList = new ArrayList<>(colliders.keySet()); - List noncollidersList = new ArrayList<>(noncolliders.keySet()); - - Collections.sort(collidersList, new Comparator() { - - @Override - public int compare(Triple o1, Triple o2) { - return -Double.compare(colliders.get(o2), colliders.get(o1)); - } - }); - - for (Triple triple : collidersList) { - Node a = triple.getX(); - Node b = triple.getY(); - Node c = triple.getZ(); - - if (!(psi.getEndpoint(b, a) == Endpoint.ARROW || psi.getEndpoint(b, c) == Endpoint.ARROW)) { - psi.removeEdge(a, b); - psi.removeEdge(c, b); - psi.addDirectedEdge(a, b); - psi.addDirectedEdge(c, b); - } - } - - for (Triple triple : noncollidersList) { - Node a = triple.getX(); - Node b = triple.getY(); - Node c = triple.getZ(); - - psi.addUnderlineTriple(a, b, c); - } - - orientAwayFromArrow(psi); - } - - private void doNodeCollider(Graph psi, Map colliders, Map noncolliders, Node b) { - List adjacentNodes = psi.getAdjacentNodes(b); - - if (adjacentNodes.size() < 2) { - return; - } - - ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); - int[] combination; - - while ((combination = cg.next()) != null) { - Node a = adjacentNodes.get(combination[0]); - Node c = adjacentNodes.get(combination[1]); - - // Skip triples that are shielded. - if (psi.isAdjacentTo(a, c)) { - continue; - } - - List adja = psi.getAdjacentNodes(a); - double score = Double.POSITIVE_INFINITY; - List S = null; - - DepthChoiceGenerator cg2 = new DepthChoiceGenerator(adja.size(), -1); - int[] comb2; - - while ((comb2 = cg2.next()) != null) { - List s = GraphUtils.asList(comb2, adja); - independenceTest.isIndependent(a, c, s); - double _score = independenceTest.getScore(); - - if (_score < score && _score < 0) { - score = _score; - S = s; - } - } - - List adjc = psi.getAdjacentNodes(c); - - DepthChoiceGenerator cg3 = new DepthChoiceGenerator(adjc.size(), -1); - int[] comb3; - - while ((comb3 = cg3.next()) != null) { - List s = GraphUtils.asList(comb3, adjc); - independenceTest.isIndependent(c, a, s); - double _score = independenceTest.getScore(); - - if (_score < score && _score < 0) { - score = _score; - S = s; - } - } - - // This could happen if there are undefined values and such. - if (S == null) { - continue; - } - - if (S.contains(b)) { - noncolliders.put(new Triple(a, b, c), score); - } else { - colliders.put(new Triple(a, b, c), score); - } - } - } - - private void stepC(Graph psi, SepsetProducer sepsets) { - TetradLogger.getInstance().log("info", "\nStep C"); - - EDGE: - for (Edge edge : psi.getEdges()) { - Node x = edge.getNode1(); - Node y = edge.getNode2(); - - // x and y are adjacent. - - List adjx = psi.getAdjacentNodes(x); - List adjy = psi.getAdjacentNodes(y); - - for (Node node : adjx) { - if (psi.getEdge(node, x).getProximalEndpoint(x) == Endpoint.ARROW - && psi.isUnderlineTriple(y, x, node)) { - continue EDGE; - } - } - - // Check each A - for (Node a : nodes) { - if (a == x) continue; - if (a == y) continue; - - //...A is not adjacent to X and A is not adjacent to Y... - if (adjx.contains(a)) continue; - if (adjy.contains(a)) continue; - - // Orientable... - if (!(psi.getEndpoint(y, x) == Endpoint.CIRCLE && - (psi.getEndpoint(x, y) == Endpoint.CIRCLE || psi.getEndpoint(x, y) == Endpoint.TAIL))) { - continue; - } - - if (wouldCreateBadCollider(x, y, psi)) { - continue; - } - - //...X is not in sepset... - List sepset = sepsets.getSepset(a, y); - - if (sepset == null) { - continue; - } - - if (sepset.contains(x)) continue; - - if (!sepsets.isIndependent(a, x, sepset)) { - psi.removeEdge(x, y); - psi.addDirectedEdge(y, x); - orientAwayFromArrow(y, x, psi); - break; - } - } - } - } - - private boolean wouldCreateBadCollider(Node x, Node y, Graph psi) { - for (Node z : psi.getAdjacentNodes(y)) { - if (x == z) continue; - if (psi.getEndpoint(x, y) != Endpoint.ARROW && psi.getEndpoint(z, y) == Endpoint.ARROW) return true; - } - - return false; - } - - private void stepD(Graph psi, SepsetProducer sepsets, final Map> supSepsets) { - Map> local = new HashMap<>(); - - for (Node node : psi.getNodes()) { - local.put(node, local(psi, node)); - } - - class Task extends RecursiveTask { - private Graph psi; - private SepsetProducer sepsets; - private Map> supSepsets; - private Map> local; - private int from; - private int to; - private int chunk = 20; - - public Task(Graph psi, SepsetProducer sepsets, Map> supSepsets, - Map> local, int from, int to) { - this.psi = psi; - this.sepsets = sepsets; - this.supSepsets = supSepsets; - this.local = local; - this.from = from; - this.to = to; - } - - @Override - protected Boolean compute() { - if (to - from <= chunk) { - for (int i = from; i < to; i++) { - doNodeStepD(psi, sepsets, supSepsets, local, nodes.get(i)); - } - - return true; - } else { - int mid = (to + from) / 2; - - Task left = new Task(psi, sepsets, supSepsets, local, from, mid); - Task right = new Task(psi, sepsets, supSepsets, local, mid, to); - - left.fork(); - right.compute(); - left.join(); - - return true; - } - } - } - - Task task = new Task(psi, sepsets, supSepsets, local, 0, nodes.size()); - - ForkJoinPoolInstance.getInstance().getPool().invoke(task); - } - - private void doNodeStepD(Graph psi, SepsetProducer sepsets, Map> supSepsets, - Map> local, Node b) { - List adj = psi.getAdjacentNodes(b); - - if (adj.size() < 2) { - return; - } - - ChoiceGenerator gen = new ChoiceGenerator(adj.size(), 2); - int[] choice; - - while ((choice = gen.next()) != null) { - List _adj = GraphUtils.asList(choice, adj); - Node a = _adj.get(0); - Node c = _adj.get(1); - - if (!psi.isDefCollider(a, b, c)) continue; - - List S = sepsets.getSepset(a, c); - if (S == null) continue; - ArrayList TT = new ArrayList<>(local.get(a)); - TT.removeAll(S); - TT.remove(b); - TT.remove(c); - - DepthChoiceGenerator gen2 = new DepthChoiceGenerator(TT.size(), -1); - int[] choice2; - - while ((choice2 = gen2.next()) != null) { - Set T = GraphUtils.asSet(choice2, TT); - Set B = new HashSet<>(T); - B.addAll(S); - B.add(b); - - if (sepsets.isIndependent(a, c, new ArrayList<>(B))) { - psi.addDottedUnderlineTriple(a, b, c); - supSepsets.put(new Triple(a, b, c), B); - break; - } - } - } - } - - private void stepE(Map> supSepset, Graph psi) { - TetradLogger.getInstance().log("info", "\nStep E"); - - for (Triple triple : psi.getDottedUnderlines()) { - Node a = triple.getX(); - Node b = triple.getY(); - Node c = triple.getZ(); - - List aAdj = psi.getAdjacentNodes(a); - - for (Node d : aAdj) { - if (d == b) continue; - - if (psi.getEndpoint(b, d) != Endpoint.CIRCLE) { - continue; - } - - if (supSepset.get(triple).contains(d)) { - - // Orient B*-oD as B*-D - psi.setEndpoint(b, d, Endpoint.TAIL); - } else { - if (psi.getEndpoint(d, b) == Endpoint.ARROW) { - continue; - } - - if (wouldCreateBadCollider(b, d, psi)) { - continue; - } - - // Or orient Bo-oD or B-oD as B->D... - psi.removeEdge(b, d); - psi.addDirectedEdge(b, d); - orientAwayFromArrow(b, d, psi); - } - } - - List cAdj = psi.getAdjacentNodes(c); - - for (Node d : cAdj) { - if (d == b) continue; - - if (psi.getEndpoint(b, d) != Endpoint.CIRCLE) { - continue; - } - - if (supSepset.get(triple).contains(d)) { - - // Orient B*-oD as B*-D - psi.setEndpoint(b, d, Endpoint.TAIL); - } else { - if (psi.getEndpoint(d, b) == Endpoint.ARROW) { - continue; - } - - if (wouldCreateBadCollider(b, d, psi)) { - continue; - } - - // Or orient Bo-oD or B-oD as B->D... - psi.removeEdge(b, d); - psi.addDirectedEdge(b, d); - orientAwayFromArrow(b, d, psi); - } - } - } - } - - private void stepF(Graph psi, SepsetProducer sepsets, Map> supSepsets) { - for (Triple triple : psi.getDottedUnderlines()) { - Node a = triple.getX(); - Node b = triple.getY(); - Node c = triple.getZ(); - - Set adj = new HashSet<>(psi.getAdjacentNodes(a)); - adj.addAll(psi.getAdjacentNodes(c)); - - for (Node d : adj) { - if (psi.getEndpoint(b, d) != Endpoint.CIRCLE) { - continue; - } - - if (psi.getEndpoint(d, b) == Endpoint.ARROW) { - continue; - } - - //...and D is not adjacent to both A and C in psi... - if (psi.isAdjacentTo(a, d) && psi.isAdjacentTo(c, d)) { - continue; - } - - //...and B and D are adjacent... - if (!psi.isAdjacentTo(b, d)) { - continue; - } - - Set supSepUnionD = new HashSet<>(); - supSepUnionD.add(d); - supSepUnionD.addAll(supSepsets.get(triple)); - List listSupSepUnionD = new ArrayList<>(supSepUnionD); - - if (wouldCreateBadCollider(b, d, psi)) { - continue; - } - - //If A and C are a pair of vertices d-connected given - //SupSepset union {D} then orient Bo-oD or B-oD - //as B->D in psi. - if (!sepsets.isIndependent(a, c, listSupSepUnionD)) { - psi.removeEdge(b, d); - psi.addDirectedEdge(b, d); - orientAwayFromArrow(b, d, psi); - } - } - } - } - - private List local(Graph psi, Node x) { - Set nodes = new HashSet<>(psi.getAdjacentNodes(x)); - - for (Node y : new HashSet<>(nodes)) { - for (Node z : psi.getAdjacentNodes(y)) { - if (psi.isDefCollider(x, y, z)) { - if (z != x) { - nodes.add(z); - } - } - } - } - - return new ArrayList<>(nodes); - } - - private void orientAwayFromArrow(Node a, Node b, Graph graph) { - if (!isApplyR1()) return; - - for (Node c : graph.getAdjacentNodes(b)) { - if (c == a) continue; - orientAwayFromArrowVisit(a, b, c, graph); - } - } - - private boolean orientAwayFromArrowVisit(Node a, Node b, Node c, Graph graph) { - if (!Edges.isNondirectedEdge(graph.getEdge(b, c))) { - return false; - } - - if (!(graph.isUnderlineTriple(a, b, c))) { - return false; - } - - - if (graph.getEdge(b, c).pointsTowards(b)) { - return false; - } - - graph.removeEdge(b, c); - graph.addDirectedEdge(b, c); - - for (Node d : graph.getAdjacentNodes(c)) { - if (d == b) return true; - - Edge bc = graph.getEdge(b, c); - - if (!orientAwayFromArrowVisit(b, c, d, graph)) { - graph.removeEdge(b, c); - graph.addEdge(bc); - } - } - - return true; - } - - public boolean isApplyR1() { - return applyR1; - } - - public void setApplyR1(boolean applyR1) { - this.applyR1 = applyR1; - } -} - - - - - - diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFci.java index edc7f7faa7..3e77f673c7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFci.java @@ -53,8 +53,8 @@ public final class GFci implements GraphSearch { // The maximum length for any discriminating path. -1 if unlimited; otherwise, a positive integer. private int maxPathLength = -1; - // The maxIndegree for the fast adjacency search. - private int maxIndegree = -1; + // The maxDegree for the fast adjacency search. + private int maxDegree = -1; // The logger to use. private TetradLogger logger = TetradLogger.getInstance(); @@ -101,20 +101,20 @@ public Graph search() { this.graph = new EdgeListGraphSingleConnections(nodes); - Fgs fgs = new Fgs(score); - fgs.setKnowledge(getKnowledge()); - fgs.setVerbose(verbose); - fgs.setNumPatternsToStore(0); - fgs.setFaithfulnessAssumed(faithfulnessAssumed); - fgs.setMaxDegree(maxIndegree); - fgs.setOut(out); - graph = fgs.search(); - Graph fgsGraph = new EdgeListGraphSingleConnections(graph); + Fges fges = new Fges(score); + fges.setKnowledge(getKnowledge()); + fges.setVerbose(verbose); + fges.setNumPatternsToStore(0); + fges.setFaithfulnessAssumed(faithfulnessAssumed); + fges.setMaxDegree(maxDegree); + fges.setOut(out); + graph = fges.search(); + Graph fgesGraph = new EdgeListGraphSingleConnections(graph); - sepsets = new SepsetsGreedy(fgsGraph, independenceTest, null, maxIndegree); + sepsets = new SepsetsGreedy(fgesGraph, independenceTest, null, maxDegree); for (Node b : nodes) { - List adjacentNodes = fgsGraph.getAdjacentNodes(b); + List adjacentNodes = fgesGraph.getAdjacentNodes(b); if (adjacentNodes.size() < 2) { continue; @@ -127,7 +127,7 @@ public Graph search() { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); - if (graph.isAdjacentTo(a, c) && fgsGraph.isAdjacentTo(a, c)) { + if (graph.isAdjacentTo(a, c) && fgesGraph.isAdjacentTo(a, c)) { if (sepsets.getSepset(a, c) != null) { graph.removeEdge(a, c); } @@ -135,7 +135,7 @@ public Graph search() { } } - modifiedR0(fgsGraph); + modifiedR0(fgesGraph); FciOrient fciOrient = new FciOrient(sepsets); fciOrient.setVerbose(verbose); @@ -151,6 +151,8 @@ public Graph search() { elapsedTime = time2 - time1; + graph.setPag(true); + return graph; } @@ -160,26 +162,26 @@ public long getElapsedTime() { } /** - * @param maxIndegree The maximum indegree of the output graph. + * @param maxDegree The maximum indegree of the output graph. */ - public void setMaxIndegree(int maxIndegree) { - if (maxIndegree < -1) { + public void setMaxDegree(int maxDegree) { + if (maxDegree < -1) { throw new IllegalArgumentException( - "Depth must be -1 (unlimited) or >= 0: " + maxIndegree); + "Depth must be -1 (unlimited) or >= 0: " + maxDegree); } - this.maxIndegree = maxIndegree; + this.maxDegree = maxDegree; } /** * Returns The maximum indegree of the output graph. */ - public int getMaxIndegree() { - return maxIndegree; + public int getMaxDegree() { + return maxDegree; } // Due to Spirtes. - public void modifiedR0(Graph fgsGraph) { + public void modifiedR0(Graph fgesGraph) { graph.reorientAllWith(Endpoint.CIRCLE); fciOrientbk(knowledge, graph, graph.getNodes()); @@ -199,10 +201,10 @@ public void modifiedR0(Graph fgsGraph) { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); - if (fgsGraph.isDefCollider(a, b, c)) { + if (fgesGraph.isDefCollider(a, b, c)) { graph.setEndpoint(a, b, Endpoint.ARROW); graph.setEndpoint(c, b, Endpoint.ARROW); - } else if (fgsGraph.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { + } else if (fgesGraph.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { List sepset = sepsets.getSepset(a, c); if (sepset != null && !sepset.contains(b)) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFciMax.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFciMax.java index 1d957a36b4..72cb0d89be 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFciMax.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFciMax.java @@ -106,21 +106,21 @@ public Graph search() { this.graph = new EdgeListGraphSingleConnections(nodes); - Fgs fgs = new Fgs(score); - fgs.setKnowledge(getKnowledge()); - fgs.setVerbose(verbose); - fgs.setNumPatternsToStore(0); - fgs.setFaithfulnessAssumed(faithfulnessAssumed); - fgs.setMaxDegree(maxDegree); - fgs.setOut(out); - graph = fgs.search(); - Graph fgsGraph = new EdgeListGraphSingleConnections(graph); - sepsets = new SepsetsGreedy(fgsGraph, independenceTest, null, maxDegree); + Fges fges = new Fges(score); + fges.setKnowledge(getKnowledge()); + fges.setVerbose(verbose); + fges.setNumPatternsToStore(0); + fges.setFaithfulnessAssumed(faithfulnessAssumed); + fges.setMaxDegree(maxDegree); + fges.setOut(out); + graph = fges.search(); + Graph fgesGraph = new EdgeListGraphSingleConnections(graph); + sepsets = new SepsetsGreedy(fgesGraph, independenceTest, null, maxDegree); graph.reorientAllWith(Endpoint.CIRCLE); for (Node b : nodes) { - List adjacentNodes = fgsGraph.getAdjacentNodes(b); + List adjacentNodes = fgesGraph.getAdjacentNodes(b); if (adjacentNodes.size() < 2) { continue; @@ -133,7 +133,7 @@ public Graph search() { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); - if (graph.isAdjacentTo(a, c) && fgsGraph.isAdjacentTo(a, c)) { + if (graph.isAdjacentTo(a, c) && fgesGraph.isAdjacentTo(a, c)) { if (sepsets.getSepset(a, c) != null) { graph.removeEdge(a, c); } @@ -141,10 +141,10 @@ public Graph search() { } } -// modifiedR0(fgsGraph); - sepsets = new SepsetsMinScore(fgsGraph, independenceTest, maxDegree); +// modifiedR0(fgesGraph); + sepsets = new SepsetsMinScore(fgesGraph, independenceTest, maxDegree); - addColliders(graph, fgsGraph); + addColliders(graph, fgesGraph); FciOrient fciOrient = new FciOrient(sepsets); fciOrient.setVerbose(verbose); @@ -192,7 +192,7 @@ public int getMaxDegree() { } // Due to Spirtes. - public void modifiedR0(Graph fgsGraph) { + public void modifiedR0(Graph fgesGraph) { graph.reorientAllWith(Endpoint.CIRCLE); fciOrientbk(knowledge, graph, graph.getNodes()); @@ -212,10 +212,10 @@ public void modifiedR0(Graph fgsGraph) { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); - if (fgsGraph.isDefCollider(a, b, c)) { + if (fgesGraph.isDefCollider(a, b, c)) { graph.setEndpoint(a, b, Endpoint.ARROW); graph.setEndpoint(c, b, Endpoint.ARROW); - } else if (fgsGraph.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { + } else if (fgesGraph.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { List sepset = sepsets.getSepset(a, c); if (sepset != null && !sepset.contains(b)) { @@ -374,7 +374,7 @@ private void fciOrientbk(IKnowledge knowledge, Graph graph, List variables logger.log("info", "Finishing BK Orientation."); } - private void addColliders(Graph graph, final Graph fgsGraph) { + private void addColliders(Graph graph, final Graph fgesGraph) { List nodes = graph.getNodes(); class Task extends RecursiveTask { @@ -396,7 +396,7 @@ public Task(List nodes, Graph graph, int from, int to) { protected Boolean compute() { if (to - from <= chunk) { for (int i = from; i < to; i++) { - doNode(graph, fgsGraph, nodes.get(i)); + doNode(graph, fgesGraph, nodes.get(i)); } return true; @@ -420,7 +420,7 @@ protected Boolean compute() { ForkJoinPoolInstance.getInstance().getPool().invoke(task); } - private void doNode(Graph graph, Graph fgsGraph, Node b) { + private void doNode(Graph graph, Graph fgesGraph, Node b) { List adjacentNodes = graph.getAdjacentNodes(b); if (adjacentNodes.size() < 2) { @@ -434,7 +434,7 @@ private void doNode(Graph graph, Graph fgsGraph, Node b) { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); -// if (fgsGraph.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { +// if (fgesGraph.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { // S actually has to be non-null here, but the compiler doesn't know that. List S = sepsets.getSepset(a, c); if (S != null && !S.contains(b)) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GPc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GPc.java index 4a70f0d116..53fac38b61 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GPc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GPc.java @@ -117,7 +117,7 @@ public final class GPc implements GraphSearch { private SepsetProducer sepsets; private long elapsedTime; - private int fgsDepth = -1; + private int fgesDepth = -1; //============================CONSTRUCTORS============================// @@ -152,29 +152,29 @@ public Graph search() { setScore(); } - Fgs fgs = new Fgs(score); - fgs.setKnowledge(getKnowledge()); - fgs.setVerbose(verbose); - fgs.setNumPatternsToStore(0); -// fgs.setHeuristicSpeedup(heuristicSpeedup); -// fgs.setMaxIndegree(fgsDepth); - graph = fgs.search(); + Fges fges = new Fges(score); + fges.setKnowledge(getKnowledge()); + fges.setVerbose(verbose); + fges.setNumPatternsToStore(0); +// fges.setHeuristicSpeedup(heuristicSpeedup); +// fges.setMaxDegree(fgesDepth); + graph = fges.search(); - Graph fgsGraph = new EdgeListGraphSingleConnections(graph); + Graph fgesGraph = new EdgeListGraphSingleConnections(graph); -// System.out.println("GFCI: FGS done"); +// System.out.println("GFCI: FGES done"); - sepsets = new SepsetsGreedy(fgsGraph, independenceTest, null, maxIndegree); -// ((SepsetsGreedy) sepsets).setMaxIndegree(3); -// sepsets = new SepsetsConservative(fgsGraph, independenceTest, null, maxIndegree); -// sepsets = new SepsetsConservativeMajority(fgsGraph, independenceTest, null, maxIndegree); -// sepsets = new SepsetsMaxPValue(fgsGraph, independenceTest, null, maxIndegree); -// sepsets = new SepsetsMinScore(fgsGraph, independenceTest, null, maxIndegree); + sepsets = new SepsetsGreedy(fgesGraph, independenceTest, null, maxIndegree); +// ((SepsetsGreedy) sepsets).setMaxDegree(3); +// sepsets = new SepsetsConservative(fgesGraph, independenceTest, null, maxIndegree); +// sepsets = new SepsetsConservativeMajority(fgesGraph, independenceTest, null, maxIndegree); +// sepsets = new SepsetsMaxPValue(fgesGraph, independenceTest, null, maxIndegree); +// sepsets = new SepsetsMinScore(fgesGraph, independenceTest, null, maxIndegree); // // System.out.println("GFCI: Look inside triangles starting"); for (Node b : nodes) { - List adjacentNodes = fgsGraph.getAdjacentNodes(b); + List adjacentNodes = fgesGraph.getAdjacentNodes(b); if (adjacentNodes.size() < 2) { continue; @@ -187,7 +187,7 @@ public Graph search() { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); - if (graph.isAdjacentTo(a, c) && fgsGraph.isAdjacentTo(a, c)) { + if (graph.isAdjacentTo(a, c) && fgesGraph.isAdjacentTo(a, c)) { if (sepsets.getSepset(a, c) != null) { graph.removeEdge(a, c); } @@ -202,8 +202,8 @@ public Graph search() { // Node x = edge.getNode1(); // Node y = edge.getNode2(); // -// List adjx = fgsGraph.getAdjacentNodes(x); -// List adjy = fgsGraph.getAdjacentNodes(y); +// List adjx = fgesGraph.getAdjacentNodes(x); +// List adjy = fgesGraph.getAdjacentNodes(y); // // adjx.remove(y); // adjy.remove(x); @@ -249,8 +249,8 @@ public Graph search() { //// Node a = edge.getNode1(); //// Node c = edge.getNode2(); //// -//// Set x = new HashSet<>(fgsGraph.getAdjacentNodes(a)); -//// x.retainAll(fgsGraph.getAdjacentNodes(c)); +//// Set x = new HashSet<>(fgesGraph.getAdjacentNodes(a)); +//// x.retainAll(fgesGraph.getAdjacentNodes(c)); //// //// if (!x.isEmpty()) { //// if (sepsets.getSepset(a, c) != null) { @@ -259,7 +259,7 @@ public Graph search() { //// } //// } - modifiedR0(fgsGraph); + modifiedR0(fgesGraph); MeekRules rules = new MeekRules(); rules.setAggressivelyPreventCycles(false); @@ -327,7 +327,7 @@ public void setMaxIndegree(int maxIndegree) { } // Due to Spirtes. - public void modifiedR0(Graph fgsGraph) { + public void modifiedR0(Graph fgesGraph) { graph.reorientAllWith(Endpoint.TAIL); pcOrientBk(knowledge, graph, graph.getNodes()); @@ -347,10 +347,10 @@ public void modifiedR0(Graph fgsGraph) { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); - if (fgsGraph.isDefCollider(a, b, c)) { + if (fgesGraph.isDefCollider(a, b, c)) { graph.setEndpoint(a, b, Endpoint.ARROW); graph.setEndpoint(c, b, Endpoint.ARROW); - } else if (fgsGraph.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { + } else if (fgesGraph.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { List sepset = sepsets.getSepset(a, c); if (sepset != null && !sepset.contains(b)) { @@ -533,12 +533,12 @@ public void setStructurePrior(double structurePrior) { this.structurePrior = structurePrior; } - public int getFgsDepth() { - return fgsDepth; + public int getFgesDepth() { + return fgesDepth; } - public void setFgsDepth(int fgsDepth) { - this.fgsDepth = fgsDepth; + public void setFgesDepth(int fgesDepth) { + this.fgesDepth = fgesDepth; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestConditionalGaussianLRT.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestConditionalGaussianLRT.java index cc528643c9..c979f14d6b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestConditionalGaussianLRT.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestConditionalGaussianLRT.java @@ -47,6 +47,9 @@ public class IndTestConditionalGaussianLRT implements IndependenceTest { // Likelihood function private ConditionalGaussianLikelihood likelihood; + private double pValue = Double.NaN; + private boolean denominatorMixed = true; + private double penaltyDiscount; public IndTestConditionalGaussianLRT(DataSet data, double alpha) { this.data = data; @@ -76,6 +79,8 @@ public IndependenceTest indTestSubset(List vars) { * getVariableNames(). */ public boolean isIndependent(Node x, Node y, List z) { + likelihood.setDenominatorMixed(denominatorMixed); + int _x = nodesHash.get(x); int _y = nodesHash.get(y); @@ -90,15 +95,25 @@ public boolean isIndependent(Node x, Node y, List z) { list2[i] = _z; } - ConditionalGaussianLikelihood.Ret ret1 = likelihood.getLikelihoodRatio(_x, list1); - ConditionalGaussianLikelihood.Ret ret2 = likelihood.getLikelihoodRatio(_x, list2); + ConditionalGaussianLikelihood.Ret ret1 = likelihood.getLikelihood(_x, list1); + ConditionalGaussianLikelihood.Ret ret2 = likelihood.getLikelihood(_x, list2); double lik = ret1.getLik() - ret2.getLik(); double dof = ret1.getDof() - ret2.getDof(); -// if (dof <= 1) dof = 1; + if (dof <= 0) { + dof = 1; +// throw new IllegalArgumentException("DOF must be >= 1"); + } + + double p = 0; + try { + p = 1.0 - new ChiSquaredDistribution(dof).cumulativeProbability(2.0 * lik); + } catch (Exception e) { + e.printStackTrace(); + } - double p = 1.0 - new ChiSquaredDistribution(dof).cumulativeProbability(2.0 * lik); + this.pValue = p; return p > alpha; } @@ -127,7 +142,7 @@ public boolean isDependent(Node x, Node y, Node... z) { * not meaningful for tis test. */ public double getPValue() { - return Double.NaN; + return this.pValue; } /** @@ -173,14 +188,14 @@ public boolean determines(List z, Node y) { * @throws UnsupportedOperationException if there is no significance level. */ public double getAlpha() { - return Double.NaN; + return alpha; } /** * Sets the significance level. */ public void setAlpha(double alpha) { - + this.alpha = alpha; } public DataSet getData() { @@ -209,7 +224,7 @@ public List getCovMatrices() { @Override public double getScore() { - return getPValue(); + return getAlpha() - getPValue(); } /** @@ -219,4 +234,16 @@ public String toString() { NumberFormat nf = new DecimalFormat("0.0000"); return "Multinomial Logistic Regression, alpha = " + nf.format(getAlpha()); } + + public void setDenominatorMixed(boolean denominatorMixed) { + this.denominatorMixed = denominatorMixed; + } + + public void setPenaltyDiscount(double penaltyDiscount) { + likelihood.setPenaltyDiscount(penaltyDiscount); + } + + public double getPenaltyDiscount() { + return penaltyDiscount; + } } \ No newline at end of file diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestScore.java index f0aa81241d..ab9604f2bd 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestScore.java @@ -192,7 +192,7 @@ public DataModel getData() { } public ICovarianceMatrix getCov() { - throw new UnsupportedOperationException(); + return ((SemBicScore) score).getCovariances(); } public List getDataSets() { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MeekRules.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MeekRules.java index 46980db3a4..900e2fa186 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MeekRules.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MeekRules.java @@ -372,7 +372,7 @@ private void direct(Node a, Node c, Graph graph) { // Adding last works, checking for c or not. Adding first works, but when it is // checked whether directStack already contains it it seems to produce one in - // 3000 trial error for FGS. Do not understand this yet. + // 3000 trial error for FGES. Do not understand this yet. directStack.addLast(c); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixedBicScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixedBicScore.java index af48b94830..0f9a2bead9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixedBicScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MixedBicScore.java @@ -34,7 +34,7 @@ import static java.lang.Math.sqrt; /** - * Implements the continuous BIC score for FGS. + * Implements the continuous BIC score for FGES. * * @author Joseph Ramsey */ @@ -53,7 +53,7 @@ public class MixedBicScore implements Score { private double penaltyDiscount = 2.0; // True if linear dependencies should return NaN for the score, and hence be - // ignored by FGS + // ignored by FGES private boolean ignoreLinearDependent = false; // The printstream output should be sent to. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/OrientCollidersMaxP.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/OrientCollidersMaxP.java new file mode 100644 index 0000000000..0c07b9a1a1 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/OrientCollidersMaxP.java @@ -0,0 +1,411 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph // +// Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.data.IKnowledge; +import edu.cmu.tetrad.data.Knowledge2; +import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.util.ChoiceGenerator; +import edu.cmu.tetrad.util.DepthChoiceGenerator; +import edu.cmu.tetrad.util.ForkJoinPoolInstance; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.RecursiveTask; + +/** + * This is an optimization of the CCD (Cyclic Causal Discovery) algorithm by Thomas Richardson. + * + * @author Joseph Ramsey + */ +public final class OrientCollidersMaxP { + private final IndependenceTest independenceTest; + private int depth = -1; + private long elapsed = 0; + private IKnowledge knowledge = new Knowledge2(); + private boolean useHeuristic = true; + private int maxPathLength = 3; + + public OrientCollidersMaxP(IndependenceTest test) { + if (test == null) throw new NullPointerException(); + this.independenceTest = test; + } + + //======================================== PUBLIC METHODS ====================================// + + /** + * Searches for a PAG satisfying the description in Thomas Richardson (1997), dissertation, + * Carnegie Mellon University. Uses a simplification of that algorithm. + */ + public void orient(Graph graph) { + addColliders(graph); + } + + /** + * @return The depth of search for the Fast Adjacency Search. + */ + public int getDepth() { + return depth; + } + + /** + * @param depth The depth of search for the Fast Adjacency Search. + */ + public void setDepth(int depth) { + this.depth = depth; + } + + /** + * @return The elapsed time in milliseconds. + */ + public long getElapsedTime() { + return elapsed; + } + + //======================================== PRIVATE METHODS ====================================// + + private void addColliders(Graph graph) { + final Map scores = new ConcurrentHashMap<>(); + + List nodes = graph.getNodes(); + + class Task extends RecursiveTask { + int from; + int to; + int chunk = 20; + List nodes; + Graph graph; + + public Task(List nodes, Graph graph, Map scores, int from, int to) { + this.nodes = nodes; + this.graph = graph; + this.from = from; + this.to = to; + } + + @Override + protected Boolean compute() { + if (to - from <= chunk) { + for (int i = from; i < to; i++) { + doNode(graph, scores, nodes.get(i)); + } + + return true; + } else { + int mid = (to + from) / 2; + + Task left = new Task(nodes, graph, scores, from, mid); + Task right = new Task(nodes, graph, scores, mid, to); + + left.fork(); + right.compute(); + left.join(); + + return true; + } + } + } + + Task task = new Task(nodes, graph, scores, 0, nodes.size()); + + ForkJoinPoolInstance.getInstance().getPool().invoke(task); + + List tripleList = new ArrayList<>(scores.keySet()); + + // Most independent ones first. + Collections.sort(tripleList, new Comparator() { + + @Override + public int compare(Triple o1, Triple o2) { + return Double.compare(scores.get(o2), scores.get(o1)); + } + }); + + for (Triple triple : tripleList) { + Node a = triple.getX(); + Node b = triple.getY(); + Node c = triple.getZ(); + + if (!(graph.getEndpoint(b, a) == Endpoint.ARROW || graph.getEndpoint(b, c) == Endpoint.ARROW)) { +// graph.setEndpoint(a, b, Endpoint.ARROW); +// graph.setEndpoint(c, b, Endpoint.ARROW); + orientCollider(graph, a, b, c); + } + } + } + + private void doNode(Graph graph, Map scores, Node b) { + List adjacentNodes = graph.getAdjacentNodes(b); + + if (adjacentNodes.size() < 2) { + return; + } + + ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); + int[] combination; + + while ((combination = cg.next()) != null) { + Node a = adjacentNodes.get(combination[0]); + Node c = adjacentNodes.get(combination[1]); + + // Skip triples that are shielded. + if (graph.isAdjacentTo(a, c)) { + continue; + } + + if (useHeuristic) { + if (existsShortPath(a, c, maxPathLength, graph)) { + testColliderMaxP(graph, scores, a, b, c); + } else { + testColliderHeuristic(graph, scores, a, b, c); + } + } else { + testColliderMaxP(graph, scores, a, b, c); + } + } + } + + private void testColliderMaxP(Graph graph, Map scores, Node a, Node b, Node c) { + List adja = graph.getAdjacentNodes(a); + double score = Double.POSITIVE_INFINITY; + List S = null; + + DepthChoiceGenerator cg2 = new DepthChoiceGenerator(adja.size(), -1); + int[] comb2; + + while ((comb2 = cg2.next()) != null) { + List s = GraphUtils.asList(comb2, adja); + independenceTest.isIndependent(a, c, s); + double _score = independenceTest.getScore(); + + if (_score < score) { + score = _score; + S = s; + } + } + + List adjc = graph.getAdjacentNodes(c); + + DepthChoiceGenerator cg3 = new DepthChoiceGenerator(adjc.size(), -1); + int[] comb3; + + while ((comb3 = cg3.next()) != null) { + List s = GraphUtils.asList(comb3, adjc); + independenceTest.isIndependent(c, a, s); + double _score = independenceTest.getScore(); + + if (_score < score) { + score = _score; + S = s; + } + } + + // S actually has to be non-null here, but the compiler doesn't know that. + if (S != null && !S.contains(b)) { + scores.put(new Triple(a, b, c), score); + } + } + + private void testColliderHeuristic(Graph graph, Map colliders, Node a, Node b, Node c) { + if (knowledge.isForbidden(a.getName(), b.getName())) { + return; + } + + if (knowledge.isForbidden(c.getName(), b.getName())) { + return; + } + + independenceTest.isIndependent(a, c); + double s1 = independenceTest.getScore(); + independenceTest.isIndependent(a, c, b); + double s2 = independenceTest.getScore(); + + boolean mycollider2 = s2 > s1; + + // Skip triples that are shielded. + if (graph.isAdjacentTo(a, c)) { + return; + } + + if (graph.getEdges(a, b).size() > 1 || graph.getEdges(b, c).size() > 1) { + return; + } + + if (mycollider2) { + colliders.put(new Triple(a, b, c), Math.abs(s2)); + } + } + + private void orientCollider(Graph graph, Node a, Node b, Node c) { + if (wouldCreateBadCollider(graph, a, b)) return; + if (wouldCreateBadCollider(graph, c, b)) return; + if (graph.getEdges(a, b).size() > 1) return; + if (graph.getEdges(b, c).size() > 1) return; + graph.removeEdge(a, b); + graph.removeEdge(c, b); + graph.addDirectedEdge(a, b); + graph.addDirectedEdge(c, b); + } + + private boolean wouldCreateBadCollider(Graph graph, Node x, Node y) { + for (Node z : graph.getAdjacentNodes(y)) { + if (x == z) continue; + + if (!graph.isAdjacentTo(x, z) && + graph.getEndpoint(z, y) == Endpoint.ARROW && + sepset(graph, x, z, set(), set(y)) == null) { + return true; + } + } + + return false; + } + + public IKnowledge getKnowledge() { + return knowledge; + } + + public void setKnowledge(IKnowledge knowledge) { + this.knowledge = knowledge; + } + + private boolean isForbidden(Node i, Node k, List v) { + for (Node w : v) { + if (knowledge.isForbidden(w.getName(), i.getName())) { + return true; + } + + if (knowledge.isForbidden(w.getName(), k.getName())) { + return true; + } + } + + return false; + } + + // Returns a sepset containing the nodes in 'containing' but not the nodes in 'notContaining', or + // null if there is no such sepset. + private List sepset(Graph graph, Node a, Node c, Set containing, Set notContaining) { + List adj = graph.getAdjacentNodes(a); + adj.addAll(graph.getAdjacentNodes(c)); + adj.remove(c); + adj.remove(a); + + for (int d = 0; d <= Math.min((depth == -1 ? 1000 : depth), Math.max(adj.size(), adj.size())); d++) { + if (d <= adj.size()) { + ChoiceGenerator gen = new ChoiceGenerator(adj.size(), d); + int[] choice; + + WHILE: + while ((choice = gen.next()) != null) { + Set v2 = GraphUtils.asSet(choice, adj); + v2.addAll(containing); + v2.removeAll(notContaining); + v2.remove(a); + v2.remove(c); + + if (isForbidden(a, c, new ArrayList<>(v2))) + + getIndependenceTest().isIndependent(a, c, new ArrayList<>(v2)); + double p2 = getIndependenceTest().getScore(); + + if (p2 < 0) { + return new ArrayList<>(v2); + } + } + } + } + + return null; + } + + private Set set(Node... n) { + Set S = new HashSet<>(); + Collections.addAll(S, n); + return S; + } + + private IndependenceTest getIndependenceTest() { + return independenceTest; + } + + // Returns true if there is an undirected path from x to either y or z within the given number of steps. + private boolean existsShortPath(Node x, Node z, int bound, Graph graph) { + Queue Q = new LinkedList<>(); + Set V = new HashSet<>(); + Q.offer(x); + V.add(x); + Node e = null; + int distance = 0; + + while (!Q.isEmpty()) { + Node t = Q.remove(); + + if (e == t) { + e = null; + distance++; + if (distance > (bound == -1 ? 1000 : bound)) return false; + } + + for (Node u : graph.getAdjacentNodes(t)) { + Edge edge = graph.getEdge(t, u); + Node c = Edges.traverse(t, edge); + if (c == null) continue; +// if (t == y && c == z && distance > 2) continue; + if (c == z && distance > 2) return true; + + if (!V.contains(c)) { + V.add(c); + Q.offer(c); + + if (e == null) { + e = u; + } + } + } + } + + return false; + } + + public boolean isUseHeuristic() { + return useHeuristic; + } + + public void setUseHeuristic(boolean useHeuristic) { + this.useHeuristic = useHeuristic; + } + + public int getMaxPathLength() { + return maxPathLength; + } + + public void setMaxPathLength(int maxPathLength) { + this.maxPathLength = maxPathLength; + } +} + + + + + + diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcMax.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcMax.java index 9b717cf5ac..7581e8327e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcMax.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcMax.java @@ -21,16 +21,17 @@ package edu.cmu.tetrad.search; -import edu.cmu.tetrad.algcomparison.Comparison; -import edu.cmu.tetrad.algcomparison.statistic.Statistics; import edu.cmu.tetrad.data.IKnowledge; import edu.cmu.tetrad.data.Knowledge2; -import edu.cmu.tetrad.graph.*; -import edu.cmu.tetrad.util.*; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.util.TetradLogger; -import java.util.*; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.RecursiveTask; +import java.util.HashSet; +import java.util.List; +import java.util.Set; /** * Implements a modification of the the PC ("Peter/Clark") algorithm, as specified in Chapter 6 of @@ -82,6 +83,8 @@ public class PcMax implements GraphSearch { * True if verbose output should be printed. */ private boolean verbose = false; + private boolean useHeuristic; + private int maxPathLength; //=============================CONSTRUCTORS==========================// @@ -190,7 +193,10 @@ public Graph search(List nodes) { SearchGraphUtils.pcOrientbk(knowledge, graph, nodes); - addColliders(graph); + final OrientCollidersMaxP orientCollidersMaxP = new OrientCollidersMaxP(independenceTest); + orientCollidersMaxP.orient(graph); + orientCollidersMaxP.setUseHeuristic(useHeuristic); + orientCollidersMaxP.setMaxPathLength(maxPathLength); MeekRules rules = new MeekRules(); rules.setKnowledge(knowledge); @@ -246,132 +252,20 @@ public void setVerbose(boolean verbose) { this.verbose = verbose; } - private void addColliders(Graph graph) { - final Map scores = new ConcurrentHashMap<>(); - - List nodes = graph.getNodes(); - - class Task extends RecursiveTask { - int from; - int to; - int chunk = 20; - List nodes; - Graph graph; - - public Task(List nodes, Graph graph, Map scores, int from, int to) { - this.nodes = nodes; - this.graph = graph; - this.from = from; - this.to = to; - } - - @Override - protected Boolean compute() { - if (to - from <= chunk) { - for (int i = from; i < to; i++) { - doNode(graph, scores, nodes.get(i)); - } - - return true; - } else { - int mid = (to + from) / 2; - - Task left = new Task(nodes, graph, scores, from, mid); - Task right = new Task(nodes, graph, scores, mid, to); - - left.fork(); - right.compute(); - left.join(); - - return true; - } - } - } - - Task task = new Task(nodes, graph, scores, 0, nodes.size()); - - ForkJoinPoolInstance.getInstance().getPool().invoke(task); - - List tripleList = new ArrayList<>(scores.keySet()); - - Collections.sort(tripleList, new Comparator() { - - @Override - public int compare(Triple o1, Triple o2) { - return -Double.compare(scores.get(o2), scores.get(o1)); - } - }); - - for (Triple triple : tripleList) { - Node a = triple.getX(); - Node b = triple.getY(); - Node c = triple.getZ(); - - if (!(graph.getEndpoint(b, a) == Endpoint.ARROW || graph.getEndpoint(b, c) == Endpoint.ARROW)) { - graph.setEndpoint(a, b, Endpoint.ARROW); - graph.setEndpoint(c, b, Endpoint.ARROW); - } - } + public void setUseHeuristic(boolean useHeuristic) { + this.useHeuristic = useHeuristic; } - private void doNode(Graph graph, Map scores, Node b) { - List adjacentNodes = graph.getAdjacentNodes(b); - - if (adjacentNodes.size() < 2) { - return; - } - - ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); - int[] combination; - - while ((combination = cg.next()) != null) { - Node a = adjacentNodes.get(combination[0]); - Node c = adjacentNodes.get(combination[1]); - - // Skip triples that are shielded. - if (graph.isAdjacentTo(a, c)) { - continue; - } - - List adja = graph.getAdjacentNodes(a); - double score = Double.POSITIVE_INFINITY; - List S = null; - - DepthChoiceGenerator cg2 = new DepthChoiceGenerator(adja.size(), -1); - int[] comb2; - - while ((comb2 = cg2.next()) != null) { - List s = GraphUtils.asList(comb2, adja); - independenceTest.isIndependent(a, c, s); - double _score = independenceTest.getScore(); - - if (_score < score) { - score = _score; - S = s; - } - } - - List adjc = graph.getAdjacentNodes(c); - - DepthChoiceGenerator cg3 = new DepthChoiceGenerator(adjc.size(), -1); - int[] comb3; - - while ((comb3 = cg3.next()) != null) { - List s = GraphUtils.asList(comb3, adjc); - independenceTest.isIndependent(c, a, s); - double _score = independenceTest.getScore(); + public boolean isUseHeuristic() { + return useHeuristic; + } - if (_score < score) { - score = _score; - S = s; - } - } + public void setMaxPathLength(int maxPathLength) { + this.maxPathLength = maxPathLength; + } - // S actually has to be non-null here, but the compiler doesn't know that. - if (S != null && !S.contains(b)) { - scores.put(new Triple(a, b, c), score); - } - } + public int getMaxPathLength() { + return maxPathLength; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Rfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Rfci.java index db879dfa40..2a0942e58e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Rfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Rfci.java @@ -216,6 +216,8 @@ public Graph search(IFas fas, List nodes) { logger.log("info", "Elapsed time adjacency search = " + (stop1 - start1) / 1000L + "s"); logger.log("info", "Elapsed time orientation search = " + (stop2 - start2) / 1000L + "s"); + graph.setPag(true); + return graph; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Score.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Score.java index b5f595ea68..fb1effdf05 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Score.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Score.java @@ -26,7 +26,7 @@ import java.util.List; /** - * Interface for a score suitable for FGS + * Interface for a score suitable for FGES */ public interface Score { double localScore(int node, int...parents); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java index 807fdb66d0..5510f4d5a8 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java @@ -30,7 +30,6 @@ import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.*; -import java.util.concurrent.RecursiveTask; /** * Graph utilities for search algorithm. Lots of orientation method, for instance. @@ -3184,7 +3183,7 @@ public static Graph reorient(Graph graph, DataModel dataModel, IKnowledge knowle dataSets.add(_dataModel); } - Fgs images = new Fgs(new SemBicScoreImages(dataSets)); + Fges images = new Fges(new SemBicScoreImages(dataSets)); images.setBoundGraph(graph); images.setKnowledge(knowledge); @@ -3202,7 +3201,7 @@ public static Graph reorient(Graph graph, DataModel dataModel, IKnowledge knowle throw new NullPointerException(); } - Fgs ges = new Fgs(score); + Fges ges = new Fges(score); ges.setBoundGraph(graph); ges.setKnowledge(knowledge); @@ -3211,7 +3210,7 @@ public static Graph reorient(Graph graph, DataModel dataModel, IKnowledge knowle ICovarianceMatrix cov = (CovarianceMatrix) dataModel; Score score = new SemBicScore(cov); - Fgs ges = new Fgs(score); + Fges ges = new Fges(score); ges.setBoundGraph(graph); ges.setKnowledge(knowledge); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore.java index d214b02e59..79380c2127 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore.java @@ -35,7 +35,7 @@ import java.util.Set; /** - * Implements the continuous BIC score for FGS. + * Implements the continuous BIC score for FGES. * * @author Joseph Ramsey */ @@ -54,7 +54,7 @@ public class SemBicScore implements Score { private double penaltyDiscount = 2.0; // True if linear dependencies should return NaN for the score, and hence be - // ignored by FGS + // ignored by FGES private boolean ignoreLinearDependent = false; // The printstream output should be sent to. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore2.java index 89475d4758..5ea1c55d77 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore2.java @@ -34,7 +34,7 @@ import java.util.Set; /** - * Implements the continuous BIC score for FGS. + * Implements the continuous BIC score for FGES. * * @author Joseph Ramsey */ @@ -55,7 +55,7 @@ public class SemBicScore2 implements Score { private double penaltyDiscount = 2.0; // True if linear dependencies should return NaN for the score, and hence be - // ignored by FGS + // ignored by FGES private boolean ignoreLinearDependent = false; // The printstream output should be sent to. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScoreImages.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScoreImages.java index 4527933c70..de0dca3e67 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScoreImages.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScoreImages.java @@ -35,7 +35,7 @@ import java.util.List; /** - * Implements the continuous BIC score for FGS. + * Implements the continuous BIC score for FGES. * * @author Joseph Ramsey */ @@ -53,7 +53,7 @@ public class SemBicScoreImages implements ISemBicScore, Score { private double penaltyDiscount = 2.0; // True if linear dependencies should return NaN for the score, and hence be - // ignored by FGS + // ignored by FGES private boolean ignoreLinearDependent = false; // The printstream output should be sent to. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScoreImages2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScoreImages2.java new file mode 100644 index 0000000000..8b69aa4514 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScoreImages2.java @@ -0,0 +1,333 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph // +// Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.data.CovarianceMatrixOnTheFly; +import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.data.ICovarianceMatrix; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.util.DepthChoiceGenerator; +import edu.cmu.tetrad.util.TetradMatrix; +import edu.cmu.tetrad.util.TetradVector; + +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * Implements the continuous BIC score for FGES. + * + * @author Joseph Ramsey + */ +public class SemBicScoreImages2 implements Score { + + // The covariance matrix. + private List covariances; + + // The variables of the covariance matrix. + private List variables; + + // The sample size of the covariance matrix. + private int sampleSize; + + // The penalty penaltyDiscount. + private double penaltyDiscount = 2.0; + + // True if linear dependencies should return NaN for the score, and hence be + // ignored by FGES + private boolean ignoreLinearDependent = false; + + // The printstream output should be sent to. + private PrintStream out = System.out; + + // True if verbose output should be sent to out. + private boolean verbose = false; + private Set forbidden = new HashSet<>(); + + /** + * Constructs the score using a covariance matrix. + */ + public SemBicScoreImages2(List dataModels) { + if (dataModels == null) { + throw new NullPointerException(); + } + + this.penaltyDiscount = 2; + this.variables = dataModels.get(0).getVariables(); + + covariances = new ArrayList<>(); + + for (DataModel model : dataModels) { + if (model instanceof DataSet) { + DataSet dataSet = (DataSet) model; + + if (!dataSet.isContinuous()) { + throw new IllegalArgumentException("Datasets must be continuous."); + } + + CovarianceMatrixOnTheFly cov = new CovarianceMatrixOnTheFly(dataSet); + cov.setVariables(variables); + covariances.add(cov); + } else if (model instanceof ICovarianceMatrix) { + ((ICovarianceMatrix) model).setVariables(variables); + covariances.add((ICovarianceMatrix) model); + } else { + throw new IllegalArgumentException("Only continuous data sets and covariance matrices may be used as input."); + } + } + + this.sampleSize = covariances.get(0).getSampleSize(); + } + + /** + * Calculates the sample likelihood and BIC score for i given its parents in a simple SEM model + */ + public double localScore(int i, int... parents) { + for (int p : parents) if (forbidden.contains(p)) return Double.NaN; + double lik = 0.0; + + for (int k = 0; k < covariances.size(); k++) { + double residualVariance = getCovariances(k).getValue(i, i); + TetradMatrix covxx = getSelection1(getCovariances(k), parents); + + try { + TetradMatrix covxxInv = covxx.inverse(); + + TetradVector covxy = getSelection2(getCovariances(k), parents, i); + TetradVector b = covxxInv.times(covxy); + residualVariance -= covxy.dotProduct(b); + + if (residualVariance <= 0) { + if (isVerbose()) { + out.println("Nonpositive residual varianceY: resVar / varianceY = " + + (residualVariance / getCovariances(k).getValue(i, i))); + } + return Double.NaN; + } + + int cols = getCovariances(0).getDimension(); + double q = 2 / (double) cols; + lik += -sampleSize * Math.log(residualVariance); + } catch (Exception e) { + boolean removedOne = true; + + while (removedOne) { + List _parents = new ArrayList<>(); + for (int y = 0; y < parents.length; y++) _parents.add(parents[y]); + _parents.removeAll(forbidden); + parents = new int[_parents.size()]; + for (int y = 0; y < _parents.size(); y++) parents[y] = _parents.get(y); + removedOne = printMinimalLinearlyDependentSet(parents, getCovariances(k)); + } + + return Double.NaN; + } + } + + int p = parents.length; + double c = getPenaltyDiscount(); + return 2 * lik - c * (p + 1) * Math.log(covariances.size() * sampleSize); + } + + @Override + public double localScoreDiff(int x, int y, int[] z) { + return localScore(y, append(z, x)) - localScore(y, z); + } + + @Override + public double localScoreDiff(int x, int y) { + return localScore(y, x) - localScore(y); + } + + private int[] append(int[] parents, int extra) { + int[] all = new int[parents.length + 1]; + System.arraycopy(parents, 0, all, 0, parents.length); + all[parents.length] = extra; + return all; + } + + /** + * Specialized scoring method for a single parent. Used to speed up the effect edges search. + */ + + public double localScore(int i, int parent) { + return localScore(i, new int[]{parent}); + } + + /** + * Specialized scoring method for no parents. Used to speed up the effect edges search. + */ + public double localScore(int i) { + return localScore(i, new int[0]); + } + + /** + * True iff edges that cause linear dependence are ignored. + */ + public boolean isIgnoreLinearDependent() { + return ignoreLinearDependent; + } + + public void setIgnoreLinearDependent(boolean ignoreLinearDependent) { + this.ignoreLinearDependent = ignoreLinearDependent; + } + + public void setOut(PrintStream out) { + this.out = out; + } + + public double getPenaltyDiscount() { + return penaltyDiscount; + } + + private ICovarianceMatrix getCovariances(int i) { + return covariances.get(i); + } + + public int getSampleSize() { + return sampleSize; + } + + @Override + public boolean isEffectEdge(double bump) { + return bump > 0;//-0.25 * getPenaltyDiscount() * Math.log(sampleSize); + } + + public DataSet getDataSet() { + throw new UnsupportedOperationException(); + } + + public void setPenaltyDiscount(double penaltyDiscount) { + this.penaltyDiscount = penaltyDiscount; + } + + public boolean isVerbose() { + return verbose; + } + + public void setVerbose(boolean verbose) { + this.verbose = verbose; + } + + @Override + public List getVariables() { + return variables; + } + + @Override + public double getParameter1() { + return penaltyDiscount; + } + + @Override + public void setParameter1(double alpha) { + this.penaltyDiscount = alpha; + } + + // Calculates the BIC score. + private double score(double residualVariance, int n, double logn, int p, double c) { + int cols = getCovariances(0).getDimension(); + double q = 2 / (double) cols; + double bic = -n * Math.log(residualVariance) - c * (p + 1) * logn; + double structPrior = (p * Math.log(q) + (cols - p) * Math.log(1.0 - q)); + return bic;//+ structPrior; + } + + // Calculates the BIC score. +// private double score(double residualVariance, int n, double logn, int p, double c) { +// int cols = getDataSet().getNumColumns(); +// double q = 2 / (double) cols; +// +// return -n * Math.log(residualVariance) - c * (p + 1) * logn; +// +//// return -n * Math.log(residualVariance) - c * (p + 1) * logn + (p * Math.log(q) + (n - p) * Math.log(1.0 - q)); +// } + + + private TetradMatrix getSelection1(ICovarianceMatrix cov, int[] rows) { + return cov.getSelection(rows, rows); + } + + private TetradVector getSelection2(ICovarianceMatrix cov, int[] rows, int k) { + return cov.getSelection(rows, new int[]{k}).getColumn(0); + } + + // Prints a smallest subset of parents that causes a singular matrix exception. + private boolean printMinimalLinearlyDependentSet(int[] parents, ICovarianceMatrix cov) { + List _parents = new ArrayList<>(); + for (int p : parents) _parents.add(variables.get(p)); + + DepthChoiceGenerator gen = new DepthChoiceGenerator(_parents.size(), _parents.size()); + int[] choice; + + while ((choice = gen.next()) != null) { + int[] sel = new int[choice.length]; + List _sel = new ArrayList<>(); + for (int m = 0; m < choice.length; m++) { + sel[m] = parents[m]; + _sel.add(variables.get(sel[m])); + } + + TetradMatrix m = cov.getSelection(sel, sel); + + try { + m.inverse(); + } catch (Exception e2) { + forbidden.add(sel[0]); + out.println("### Linear dependence among variables: " + _sel); + out.println("### Removing " + _sel.get(0)); + return true; + } + } + + return false; + } + + public void setVariables(List variables) { + for (ICovarianceMatrix cov : covariances) { + cov.setVariables(variables); + } + this.variables = variables; + } + + @Override + public Node getVariable(String targetName) { + for (Node node : variables) { + if (node.getName().equals(targetName)) { + return node; + } + } + + return null; + } + + @Override + public int getMaxDegree() { + return (int) Math.ceil(Math.log(sampleSize)); + } +} + + + diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemGpScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemGpScore.java index 1ba7d33732..2d38e476d0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemGpScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemGpScore.java @@ -36,7 +36,7 @@ //import java.util.Set; // ///** -// * Implements the continuous BIC score for FGS. +// * Implements the continuous BIC score for FGES. // * // * @author Joseph Ramsey // */ @@ -54,7 +54,7 @@ // private double penaltyDiscount = 2.0; // // // True if linear dependencies should return NaN for the score, and hence be -// // ignored by FGS +// // ignored by FGES // private boolean ignoreLinearDependent = false; // // // The printstream output should be sent to. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SepsetMap.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SepsetMap.java index 16d688d4c4..2c520a54fc 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SepsetMap.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SepsetMap.java @@ -96,12 +96,12 @@ public void setPValue(Node x, Node y, double p) { } /** - * Retrieves the sepset previously set for {x, y}, or null if no such set was previously set. + * Retrieves the sepset previously set for {a, b}, or null if no such set was previously set. */ - public List get(Node x, Node y) { + public List get(Node a, Node b) { Set pair = new HashSet<>(2); - pair.add(x); - pair.add(y); + pair.add(a); + pair.add(b); if (correlations != null && !correlations.contains(pair)) { return Collections.emptyList(); @@ -147,7 +147,6 @@ public boolean equals(Object o) { return sepsets.equals(_sepset.sepsets); } - /** * Adds semantic checks to the default deserialization method. This method must have the standard signature for a * readObject method, and the body of the method must begin with "s.defaultReadObject();". Other than that, any @@ -155,9 +154,6 @@ public boolean equals(Object o) { * this form may be added to any class, even if Tetrad sessions were previously saved out using a version of the * class that didn't include it. (That's what the "s.defaultReadObject();" is for. See J. Bloch, Effective Java, for * help. - * - * @throws java.io.IOException - * @throws ClassNotFoundException */ private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SepsetsMinScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SepsetsMinScore.java index 13703bfe11..94885dea92 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SepsetsMinScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SepsetsMinScore.java @@ -27,10 +27,11 @@ import edu.cmu.tetrad.util.ChoiceGenerator; import java.util.List; +import java.util.Set; /** * One is often faced with the following problem. We start by estimating the adjacencies using - * FGS followed by FAS. But if X is not adjacent to Z in the resulting graph, and you ask for + * FGES followed by FAS. But if X is not adjacent to Z in the resulting graph, and you ask for * a sepset, one must be given. So we return a conditioning set that minimizes a score (is as * close to independence as possible, or independent). * @@ -42,6 +43,7 @@ public class SepsetsMinScore implements SepsetProducer { private int depth = 3; private double p = Double.NaN; private boolean verbose = false; + private boolean returnNullWhenIndep; public SepsetsMinScore(Graph graph, IndependenceTest independenceTest, int depth) { this.graph = graph; @@ -90,9 +92,16 @@ private List getMinSepset(Node i, Node k) { getIndependenceTest().isIndependent(i, k, v2); double p2 = getIndependenceTest().getScore(); - if (p2 < _p && p2 < 0) { - _p = p2; - _v = v2; + if (returnNullWhenIndep) { + if (p2 < _p && p2 < 0) { + _p = p2; + _v = v2; + } + } else { + if (p2 < _p) { + _p = p2; + _v = v2; + } } } } @@ -107,9 +116,16 @@ private List getMinSepset(Node i, Node k) { getIndependenceTest().isIndependent(i, k, v2); double p2 = getIndependenceTest().getScore(); - if (p2 < _p && p2 < 0) { - _p = p2; - _v = v2; + if (returnNullWhenIndep) { + if (p2 < _p && p2 < 0) { + _p = p2; + _v = v2; + } + } else { + if (p2 < _p) { + _p = p2; + _v = v2; + } } } } @@ -152,5 +168,84 @@ public boolean isVerbose() { public void setVerbose(boolean verbose) { this.verbose = verbose; } + + public void setReturnNullWhenIndep(boolean returnNullWhenIndep) { + this.returnNullWhenIndep = returnNullWhenIndep; + } + + public boolean isReturnNullWhenIndep() { + return returnNullWhenIndep; + } + + public List getSepset(Node a, Node y, Set inSet) { + return getMinSepset(a, y, inSet); + } + + private List getMinSepset(Node i, Node k, Set insSet) { + double _p = Double.POSITIVE_INFINITY; + List _v = null; + + List adji = graph.getAdjacentNodes(i); + List adjk = graph.getAdjacentNodes(k); + adji.remove(k); + adjk.remove(i); + + for (int d = 0; d <= Math.min((depth == -1 ? 1000 : depth), Math.max(adji.size(), adjk.size())); d++) { + if (d <= adji.size()) { + ChoiceGenerator gen = new ChoiceGenerator(adji.size(), d); + int[] choice; + + while ((choice = gen.next()) != null) { + List v2 = GraphUtils.asList(choice, adji); + + if (!insSet.containsAll(v2)) continue; + + getIndependenceTest().isIndependent(i, k, v2); + double p2 = getIndependenceTest().getScore(); + + if (returnNullWhenIndep) { + if (p2 < _p && p2 < 0) { + _p = p2; + _v = v2; + } + } else { + if (p2 < _p) { + _p = p2; + _v = v2; + } + } + } + } + + if (d <= adjk.size()) { + ChoiceGenerator gen = new ChoiceGenerator(adjk.size(), d); + int[] choice; + + while ((choice = gen.next()) != null) { + List v2 = GraphUtils.asList(choice, adjk); + + if (!insSet.containsAll(v2)) continue; + + getIndependenceTest().isIndependent(i, k, v2); + double p2 = getIndependenceTest().getScore(); + + if (returnNullWhenIndep) { + if (p2 < _p && p2 < 0) { + _p = p2; + _v = v2; + } + } else { + if (p2 < _p) { + _p = p2; + _v = v2; + } + } + } + } + } + + this.p = _p; + return _v; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ShiftSearch.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ShiftSearch.java index 70af4b4c06..3ae6e6fbbe 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ShiftSearch.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/ShiftSearch.java @@ -180,9 +180,9 @@ private List ensureNumRows(List dataSets, int numRows) { } private double getAvgBic(List dataSets) { - SemBicScoreImages fgsScore = new SemBicScoreImages(dataSets); - fgsScore.setPenaltyDiscount(c); - Fgs images = new Fgs(fgsScore); + SemBicScoreImages fgesScore = new SemBicScoreImages(dataSets); + fgesScore.setPenaltyDiscount(c); + Fges images = new Fges(fgesScore); images.setKnowledge(knowledge); images.search(); return -images.getModelScore() / dataSets.size(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Test.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Test.java index cfa3e4902d..494c02223e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Test.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Test.java @@ -21,49 +21,18 @@ package edu.cmu.tetrad.search; -import edu.cmu.tetrad.data.ContinuousVariable; -import edu.cmu.tetrad.data.DataSet; -import edu.cmu.tetrad.data.DiscreteVariable; -import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.util.CombinationIterator; -import edu.cmu.tetrad.util.TetradMatrix; -import org.apache.commons.math3.stat.correlation.Covariance; +import edu.cmu.tetrad.util.StatUtils; import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; import java.util.List; /** - * Implements a conditional Gaussian BIC score for FGS. + * Implements a conditional Gaussian BIC score for FGES. * * @author Joseph Ramsey */ public class Test { - // Calculates the log of a list of terms, where the argument consists of the logs of the terms. - private double logOfSum(List logs) { - - Collections.sort(logs, new Comparator() { - @Override - public int compare(Double o1, Double o2) { - return -Double.compare(o1, o2); - } - }); - - double sum = 0.0; - int N = logs.size() - 1; - double loga0 = logs.get(0); - - for (int i = 1; i <= N; i++) { - sum += Math.exp(logs.get(i) - loga0); - } - - sum += 1; - - return loga0 + Math.log(sum); - } - @org.junit.Test public void test() { double[] a = {.3, .03, .01}; @@ -80,7 +49,7 @@ public void test() { sum += _a; } - double logsum = logOfSum(logs); + double logsum = StatUtils.logsum(logs); System.out.println(Math.exp(logsum) + " " + sum); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TimeSeriesUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TimeSeriesUtils.java index 800bf4e13e..08de56ab55 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TimeSeriesUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TimeSeriesUtils.java @@ -171,7 +171,7 @@ public static VarResult structuralVar(DataSet timeSeries, int numLags) { throw new IllegalArgumentException("Mixed data set"); } - Fgs search = new Fgs(score); + Fges search = new Fges(score); search.setKnowledge(knowledge); Graph graph = search.search(); @@ -594,7 +594,7 @@ public int compare(Node o1, Node o2) { return laggedData; } - public static TimeLagGraph GraphToLagGraph(Graph _graph){ + public static TimeLagGraph graphToLagGraph(Graph _graph){ TimeLagGraph graph = new TimeLagGraph(); int numLags = 1; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsFci.java index 0a1e828404..458fbb2c80 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsFci.java @@ -250,7 +250,7 @@ public Graph search(IFas fas) { // // System.out.println("Starting possible dsep search"); // PossibleDsepFci possibleDSep = new PossibleDsepFci(graph, independenceTest); -// possibleDSep.setMaxIndegree(getPossibleDsepDepth()); +// possibleDSep.setMaxDegree(getPossibleDsepDepth()); // possibleDSep.setKnowledge(getKnowledge()); // possibleDSep.setMaxPathLength(maxPathLength); // this.sepsets.addAll(possibleDSep.search()); @@ -278,6 +278,9 @@ public Graph search(IFas fas) { fciOrient.setKnowledge(knowledge); fciOrient.ruleR0(graph); fciOrient.doFinalOrientation(graph); + + graph.setPag(true); + return graph; } @@ -501,8 +504,8 @@ private void removeSimilarPairs(final IndependenceTest test, Node x, Node y, Lis if (B.equals(tier_x.get(indx_comp)) && A.equals(tier_y.get(indy_comp))) continue; x1 = test.getVariable(A); y1 = test.getVariable(B); - //adjacencies.get(x1).remove(y1); - //adjacencies.get(y1).remove(x1); + //adjacencies.get(x1).remove(y1); + //adjacencies.get(y1).remove(x1); graph.removeEdge(x1,y1); System.out.println("removed edge between " + x1 + " and " + y1 + " because of structure knowledge"); List condSetAB = new ArrayList<>(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsFgs2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsFges2.java similarity index 98% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsFgs2.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsFges2.java index 8a2f44071d..3d51e53c7d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsFgs2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsFges2.java @@ -55,7 +55,7 @@ * @author Joseph Ramsey, Revisions 5/2015 * @author Daniel Malinsky */ -public final class TsFgs2 implements GraphSearch, GraphScorer { +public final class TsFges2 implements GraphSearch, GraphScorer { /** @@ -187,7 +187,7 @@ private enum Mode { * values in case of conditional independence. See Chickering (2002), * locally consistent scoring criterion. */ - public TsFgs2(Score score) { + public TsFges2(Score score) { if (score == null) throw new NullPointerException(); setScore(score); this.graph = new EdgeListGraphSingleConnections(getVariables()); @@ -1109,7 +1109,7 @@ private void calculateArrowsForward(Node a, Node b) { } Set naYX = getNaYX(a, b); - if (!isClique(naYX)) return; + if (!GraphUtils.isClique(naYX, this.graph)) return; List TNeighbors = getTNeighbors(a, b); int _maxIndegree = maxIndegree == -1 ? 1000 : maxIndegree; @@ -1144,7 +1144,7 @@ private void calculateArrowsForward(Node a, Node b) { break FOR; } - if (!isClique(union)) continue; + if (!GraphUtils.isClique(union, this.graph)) continue; newCliques.add(union); double bump = insertEval(a, b, T, naYX, hashIndices); @@ -1567,7 +1567,7 @@ private boolean validInsert(Node x, Node y, Set T, Set naYX) { Set union = new HashSet<>(T); union.addAll(naYX); - boolean clique = isClique(union); + boolean clique = GraphUtils.isClique(union, this.graph); boolean noCycle = !existsUnblockedSemiDirectedPath(y, x, union, cycleBound); return clique && noCycle && !violatesKnowledge; } @@ -1589,7 +1589,7 @@ private boolean validDelete(Node x, Node y, Set H, Set naYX) { Set diff = new HashSet<>(naYX); diff.removeAll(H); - return isClique(diff) && !violatesKnowledge; + return GraphUtils.isClique(diff, this.graph) && !violatesKnowledge; } // Adds edges required by knowledge. @@ -1684,22 +1684,6 @@ private Set getNaYX(Node x, Node y) { return nayx; } - Set cliqueEdges = new HashSet<>(); - - // Returns true iif the given set forms a clique in the given graph. - private boolean isClique(Set nodes) { - List _nodes = new ArrayList<>(nodes); - for (int i = 0; i < _nodes.size() - 1; i++) { - for (int j = i + 1; j < _nodes.size(); j++) { - if (!graph.isAdjacentTo(_nodes.get(i), _nodes.get(j))) { - return false; - } - } - } - - return true; - } - // Returns true if a path consisting of undirected and directed edges toward 'to' exists of // length at most 'bound'. Cycle checker in other words. private boolean existsUnblockedSemiDirectedPath(Node from, Node to, Set cond, int bound) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsGFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsGFci.java index ba760a4716..9cf028f306 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsGFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TsGFci.java @@ -152,27 +152,27 @@ public Graph search() { setScore(); } - TsFgs2 fgs = new TsFgs2(score); - fgs.setKnowledge(getKnowledge()); - fgs.setVerbose(verbose); - fgs.setNumPatternsToStore(0); - fgs.setFaithfulnessAssumed(faithfulnessAssumed); - graph = fgs.search(); - Graph fgsGraph = new EdgeListGraphSingleConnections(graph); - -// System.out.println("GFCI: FGS done"); - - sepsets = new SepsetsGreedy(fgsGraph, independenceTest, null, maxIndegree); -// ((SepsetsGreedy) sepsets).setMaxIndegree(3); -// sepsets = new SepsetsConservative(fgsGraph, independenceTest, null, maxIndegree); -// sepsets = new SepsetsConservativeMajority(fgsGraph, independenceTest, null, maxIndegree); -// sepsets = new SepsetsMaxPValue(fgsGraph, independenceTest, null, maxIndegree); -// sepsets = new SepsetsMinScore(fgsGraph, independenceTest, null, maxIndegree); + TsFges2 fges = new TsFges2(score); + fges.setKnowledge(getKnowledge()); + fges.setVerbose(verbose); + fges.setNumPatternsToStore(0); + fges.setFaithfulnessAssumed(faithfulnessAssumed); + graph = fges.search(); + Graph fgesGraph = new EdgeListGraphSingleConnections(graph); + +// System.out.println("GFCI: FGES done"); + + sepsets = new SepsetsGreedy(fgesGraph, independenceTest, null, maxIndegree); +// ((SepsetsGreedy) sepsets).setMaxDegree(3); +// sepsets = new SepsetsConservative(fgesGraph, independenceTest, null, maxIndegree); +// sepsets = new SepsetsConservativeMajority(fgesGraph, independenceTest, null, maxIndegree); +// sepsets = new SepsetsMaxPValue(fgesGraph, independenceTest, null, maxIndegree); +// sepsets = new SepsetsMinScore(fgesGraph, independenceTest, null, maxIndegree); // // System.out.println("GFCI: Look inside triangles starting"); for (Node b : nodes) { - List adjacentNodes = fgsGraph.getAdjacentNodes(b); + List adjacentNodes = fgesGraph.getAdjacentNodes(b); if (adjacentNodes.size() < 2) { continue; @@ -185,7 +185,7 @@ public Graph search() { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); - if (graph.isAdjacentTo(a, c) && fgsGraph.isAdjacentTo(a, c)) { + if (graph.isAdjacentTo(a, c) && fgesGraph.isAdjacentTo(a, c)) { if (sepsets.getSepset(a, c) != null) { graph.removeEdge(a, c); /** removing similar edges to enforce repeating structure **/ @@ -202,17 +202,17 @@ public Graph search() { // Node a = edge.getNode1(); // Node c = edge.getNode2(); // -// Edge e = fgsGraph.getEdge(a, c); +// Edge e = fgesGraph.getEdge(a, c); // // if (e != null && e.isDirected()) { // // // Only the ones that are in triangles. -// Set _adj = new HashSet<>(fgsGraph.getAdjacentNodes(a)); -// _adj.retainAll(fgsGraph.getAdjacentNodes(c)); +// Set _adj = new HashSet<>(fgesGraph.getAdjacentNodes(a)); +// _adj.retainAll(fgesGraph.getAdjacentNodes(c)); // if (_adj.isEmpty()) continue; // // Node f = Edges.getDirectedEdgeHead(e); -// List adj = fgsGraph.getAdjacentNodes(f); +// List adj = fgesGraph.getAdjacentNodes(f); // adj.remove(Edges.getDirectedEdgeTail(e)); // // DepthChoiceGenerator gen = new DepthChoiceGenerator(adj.size(), adj.size()); @@ -231,9 +231,9 @@ public Graph search() { // System.out.println("GFCI: Look inside triangles done"); - modifiedR0(fgsGraph); + modifiedR0(fgesGraph); -// modifiedR0(fgsGraph, map); +// modifiedR0(fgesGraph, map); // System.out.println("GFCI: R0 done"); @@ -251,6 +251,8 @@ public Graph search() { elapsedTime = time2 - time1; + graph.setPag(true); + return graph; } @@ -305,7 +307,7 @@ public void setMaxIndegree(int maxIndegree) { } // Due to Spirtes. - public void modifiedR0(Graph fgsGraph) { + public void modifiedR0(Graph fgesGraph) { graph.reorientAllWith(Endpoint.CIRCLE); fciOrientbk(knowledge, graph, graph.getNodes()); @@ -325,7 +327,7 @@ public void modifiedR0(Graph fgsGraph) { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); - if (fgsGraph.isDefCollider(a, b, c)) { + if (fgesGraph.isDefCollider(a, b, c)) { graph.setEndpoint(a, b, Endpoint.ARROW); graph.setEndpoint(c, b, Endpoint.ARROW); /** orienting similar pairs to enforce repeating structure **/ @@ -333,7 +335,7 @@ public void modifiedR0(Graph fgsGraph) { orientSimilarPairs(graph, knowledge, c, b, Endpoint.ARROW); /** **/ - } else if (fgsGraph.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { + } else if (fgesGraph.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { List sepset = sepsets.getSepset(a, c); if (sepset != null && !sepset.contains(b)) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/WFgs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/WFges.java similarity index 93% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/WFgs.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/WFges.java index 23471a7286..dabc0c1cd0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/WFgs.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/WFges.java @@ -9,19 +9,19 @@ import java.util.*; /** - * "Whimsical"FGS. Handles mixed, continuous, and discrete data. + * "Whimsical"FGES. Handles mixed, continuous, and discrete data. * * @author Joseph Ramsey */ -public class WFgs implements GraphSearch { +public class WFges implements GraphSearch { private List searchVariables; private Map> variablesPerNode = new HashMap<>(); - private Fgs fgs; + private Fges fges; private double penaltyDiscount; private SemBicScore score; - public WFgs(DataSet data) { + public WFges(DataSet data) { if (data == null) throw new NullPointerException("Data was not provided."); this.searchVariables = data.getVariables(); @@ -42,7 +42,7 @@ public WFgs(DataSet data) { SemBicScore score = new SemBicScore(covariances); this.score = score; - this.fgs = new Fgs(score); + this.fges = new Fges(score); } private List expandVariable(DataSet dataSet, Node node) { @@ -77,7 +77,7 @@ private List expandVariable(DataSet dataSet, Node node) { public Graph search() { score.setPenaltyDiscount(penaltyDiscount); - Graph g = fgs.search(); + Graph g = fges.search(); Graph out = new EdgeListGraph(searchVariables); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/WGfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/WGfci.java index b20262b1ab..8a3a4cab63 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/WGfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/WGfci.java @@ -9,7 +9,7 @@ import java.util.*; /** - * "Whimsical"FGS. Handles mixed, continuous, and discrete data. + * "Whimsical"FGES. Handles mixed, continuous, and discrete data. * * @author Joseph Ramsey */ diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/mb/Mmhc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/mb/Mmhc.java index b42d6112bb..fce3754c5d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/mb/Mmhc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/mb/Mmhc.java @@ -27,7 +27,7 @@ import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.search.FgsOrienter; +import edu.cmu.tetrad.search.FgesOrienter; import edu.cmu.tetrad.search.GraphSearch; import edu.cmu.tetrad.search.IndependenceTest; @@ -104,7 +104,7 @@ public Graph search() { } } - FgsOrienter orienter = new FgsOrienter(data); + FgesOrienter orienter = new FgesOrienter(data); orienter.orient(graph); return graph; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/GeneralizedSemIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/GeneralizedSemIm.java index 0d8c40b2bf..0fce701878 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/GeneralizedSemIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/GeneralizedSemIm.java @@ -295,7 +295,8 @@ public DataSet simulateData(int sampleSize, boolean latentDataSaved) { // return simulateDataRecursive(sampleSize, latentDataSaved); // return simulateDataMinimizeSurface(sampleSize, latentDataSaved); - return simulateDataAvoidInfinity(sampleSize, latentDataSaved); +// return simulateDataAvoidInfinity(sampleSize, latentDataSaved); + return simulateDataFisher(sampleSize); // return simulateDataNSteps(sampleSize, latentDataSaved); } @@ -821,6 +822,134 @@ public Double getValue(String term) { } + /** + * Simulates data using the model of R. A. Fisher, for a linear model. Shocks are + * applied every so many steps. A data point is recorded before each shock is + * administered. If convergence happens before that number of steps has been reached, + * a data point is recorded and a new shock immediately applied. The model may be + * cyclic. If cyclic, all eigenvalues for the coefficient matrix must be less than 1, + * though this is not checked. Uses an interval between shocks of 50 and a convergence + * threshold of 1e-5. Uncorrelated Gaussian shocks are used. + * + * @param sampleSize The number of samples to be drawn. Must be a positive + * integer. + */ + public DataSet simulateDataFisher(int sampleSize) { + return simulateDataFisher(sampleSize, 50, 1e-5); + } + + /** + * Simulates data using the model of R. A. Fisher, for a linear model. Shocks are + * applied every so many steps. A data point is recorded before each shock is + * administered. If convergence happens before that number of steps has been reached, + * a data point is recorded and a new shock immediately applied. The model may be + * cyclic. If cyclic, all eigenvalues for the coefficient matrix must be less than 1, + * though this is not checked. + * + * @param sampleSize The number of samples to be drawn. + * @param intervalBetweenShocks External shock is applied every this many steps. + * Must be positive integer. + * @param epsilon The convergence criterion; |xi.t - xi.t-1| < epsilon. + */ + public DataSet simulateDataFisher(int sampleSize, int intervalBetweenShocks, + double epsilon) { + if (intervalBetweenShocks < 1) throw new IllegalArgumentException( + "Interval between shocks must be >= 1: " + intervalBetweenShocks); + if (epsilon <= 0.0) throw new IllegalArgumentException( + "Epsilon must be > 0: " + epsilon); + + final Map variableValues = new HashMap<>(); + + final Context context = new Context() { + public Double getValue(String term) { + Double value = parameterValues.get(term); + + if (value != null) { + return value; + } + + value = variableValues.get(term); + + if (value != null) { + return value; + } + + throw new IllegalArgumentException("No value recorded for '" + term + "'"); + } + }; + + final List variableNodes = pm.getVariableNodes(); + + double[] t1 = new double[variableNodes.size()]; + double[] t2 = new double[variableNodes.size()]; + double[] shocks = new double[variableNodes.size()]; + double[][] all = new double[variableNodes.size()][sampleSize]; + + // Do the simulation. + for (int row = 0; row < sampleSize; row++) { + for (int j = 0; j < t1.length; j++) { + Node error = pm.getErrorNode(variableNodes.get(j)); + + if (error == null) { + throw new NullPointerException(); + } + + Expression expression = pm.getNodeExpression(error); + double value = expression.evaluate(context); + + if (Double.isNaN(value)) { + throw new IllegalArgumentException("Undefined value for expression: " + expression); + } + + variableValues.put(error.getName(), value); + shocks[j] = value; + } + + for (int i = 0; i < intervalBetweenShocks; i++) { + for (int j = 0; j < t1.length; j++) { + t2[j] = shocks[j]; + Node node = variableNodes.get(j); + Expression expression = pm.getNodeExpression(node); + t2[j] = expression.evaluate(context); + variableValues.put(node.getName(), t2[j]); + } + + boolean converged = true; + + for (int j = 0; j < t1.length; j++) { + if (Math.abs(t2[j] - t1[j]) > epsilon) { + converged = false; + break; + } + } + + double[] t3 = t1; + t1 = t2; + t2 = t3; + + if (converged) { + break; + } + } + + for (int j = 0; j < t1.length; j++) { + all[j][row] = t1[j]; + } + } + + List continuousVars = new ArrayList<>(); + + for (Node node : variableNodes) { + final ContinuousVariable var = new ContinuousVariable(node.getName()); + var.setNodeType(node.getNodeType()); + continuousVars.add(var); + } + + BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(all), continuousVars); + return DataUtils.restrictToMeasured(boxDataSet); + } + + public TetradVector simulateOneRecord(TetradVector e) { final Map variableValues = new HashMap<>(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/LargeScaleSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/LargeScaleSimulation.java index 7dfc672089..ae82e82d23 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/LargeScaleSimulation.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/LargeScaleSimulation.java @@ -62,6 +62,7 @@ public final class LargeScaleSimulation { private boolean verbose = false; long seed = new Date().getTime(); private boolean alreadySetUp = false; + private boolean coefSymmetric = false; //=============================CONSTRUCTORS============================// @@ -242,7 +243,7 @@ public DataSet simulateDataReducedForm(int sampleSize) { * integer. */ public DataSet simulateDataFisher(int sampleSize) { - return simulateDataFisher(getUncorrelatedGaussianShocks(sampleSize), 50, 1e-5); + return simulateDataFisher(getSoCalledPoissonShocks(sampleSize), 50, 1e-5); } /** @@ -322,6 +323,63 @@ public DataSet simulateDataFisher(double[][] shocks, int intervalBetweenShocks, return DataUtils.restrictToMeasured(boxDataSet); } + public DataSet simulateDataFisher(int intervalBetweenShocks, int intervalBetweenRecordings, int sampleSize, double epsilon) { + if (intervalBetweenShocks < 1) throw new IllegalArgumentException( + "Interval between shocks must be >= 1: " + intervalBetweenShocks); + if (epsilon <= 0.0) throw new IllegalArgumentException( + "Epsilon must be > 0: " + epsilon); + + int size = variableNodes.size(); + + setupModel(size); + + double[] t1 = new double[variableNodes.size()]; + double[] t2 = new double[variableNodes.size()]; + double[][] all = new double[variableNodes.size()][sampleSize]; + + int s = 0; + int shockIndex = 0; + int recordingIndex = 0; + double[] shock = getUncorrelatedGaussianShocks(1)[0]; + + + while (s < sampleSize) { + if ((++shockIndex) % intervalBetweenShocks == 0) { + shock = getUncorrelatedGaussianShocks(1)[0]; + } + + if ((++recordingIndex) % intervalBetweenRecordings == 0) { + for (int j = 0; j < t1.length; j++) { + all[j][s] = t1[j]; + } + + s++; + } + + for (int j = 0; j < t1.length; j++) { + t2[j] = shock[j]; + for (int k = 0; k < parents[j].length; k++) { + t2[j] += t1[parents[j][k]] * coefs[j][k]; + } + } + + double[] t3 = t1; + t1 = t2; + t2 = t3; + } + + List continuousVars = new ArrayList<>(); + + for (Node node : getVariableNodes()) { + final ContinuousVariable var = new ContinuousVariable(node.getName()); + var.setNodeType(node.getNodeType()); + continuousVars.add(var); + } + + BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(all), continuousVars); + return DataUtils.restrictToMeasured(boxDataSet); + } + private void setupModel(int size) { if (alreadySetUp) return; @@ -361,7 +419,9 @@ private void setupModel(int size) { System.arraycopy(coefs, 0, newCoefs, 0, coefs.length); - newCoefs[newCoefs.length - 1] = edgeCoefDist.nextRandom(); + double coef = edgeCoefDist.nextRandom(); + if (coefSymmetric) coef = Math.abs(coef); + newCoefs[newCoefs.length - 1] = coef; this.parents[_head] = newParents; this.coefs[_head] = newCoefs; @@ -708,6 +768,10 @@ public double[][] getSoCalledPoissonShocks(int sampleSize) { return shocks; } + + public void setCoefSymmetric(boolean coefSymmetric) { + this.coefSymmetric = coefSymmetric; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/SemIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/SemIm.java index 64ea8b6249..54bd3810e6 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/SemIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/sem/SemIm.java @@ -288,7 +288,6 @@ public static List getParameterNames() { } // Types of scores that yield a chi square value when minimized. - // The Fgsl was a typo that unfortunately I had to keep for serialization. public enum ScoreType { Fml, Fgls } @@ -2500,7 +2499,7 @@ private double initialValue(Parameter parameter) { final double covLow = getParams().getDouble("covLow", 0.1); final double covHigh = getParams().getDouble("covHigh", 0.2); double value = new Split(covLow, covHigh).nextRandom(); - if (getParams().getBoolean("coefSymmetric", true)) { + if (getParams().getBoolean("covSymmetric", true)) { return value; } else { return Math.abs(value); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/session/SessionNode.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/session/SessionNode.java index 9401a4df98..21185886c3 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/session/SessionNode.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/session/SessionNode.java @@ -21,10 +21,13 @@ package edu.cmu.tetrad.session; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.NodeType; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.*; import javax.swing.*; +import java.beans.PropertyChangeListener; import java.io.IOException; import java.io.ObjectInputStream; import java.lang.reflect.Array; @@ -59,7 +62,7 @@ * @see SessionAdapter * @see SessionEvent */ -public class SessionNode implements TetradSerializable { +public class SessionNode implements Node, TetradSerializable { static final long serialVersionUID = 23L; /** @@ -292,7 +295,7 @@ public boolean addParent(SessionNode parent) { numModels[j] = parentClasses[j].length; } - if (isConsistentModelClass(modelClass, parentClasses, false)) { + if (isConsistentModelClass(modelClass, parentClasses, false, null)) { if (this.getModel() == null) { this.parents.add(parent); parent.addChild(this); @@ -314,6 +317,10 @@ public boolean addParent(SessionNode parent) { } public boolean isConsistentParent(SessionNode parent) { + return isConsistentParent(parent, null); + } + + public boolean isConsistentParent(SessionNode parent, List existingNodes) { if (this.parents.contains(parent)) { return false; } @@ -330,7 +337,7 @@ public boolean isConsistentParent(SessionNode parent) { Class[] thisClass = new Class[1]; if (getModel() != null) { - thisClass[0] = getModel().getClass(); + thisClass[0] = getModel().getClass(); } for (Class modelClass : getModel() != null ? thisClass : this.modelClasses) { @@ -347,7 +354,7 @@ public boolean isConsistentParent(SessionNode parent) { parentClasses[j] = node.getModelClasses(); } - if (isConsistentModelClass(modelClass, parentClasses, false)) { + if (isConsistentModelClass(modelClass, parentClasses, false, existingNodes)) { return true; } } @@ -390,7 +397,7 @@ public boolean addParent2(SessionNode parent) { numModels[j] = parentClasses[j].length; } - if (isConsistentModelClass(modelClass, parentClasses, false)) { + if (isConsistentModelClass(modelClass, parentClasses, false, null)) { if (this.getModel() == null) { this.parents.add(parent); parent.addChild(this); @@ -691,13 +698,13 @@ public final void setModelClasses(Class[] modelClasses) { } /** + * @param exact * @return those model classes among the possible model classes that are at * least consistent with the model class of the parent session nodes, in the * sense that possibly with the addition of more parent session nodes, and * assuming that the models of the parent session nodes are non-null, it is * possible to construct a model in one of the legal classes for this node * using the parent models as arguments to some constructor in that class. - * @param exact */ public Class[] getConsistentModelClasses(boolean exact) { List classes = new ArrayList<>(); @@ -724,7 +731,7 @@ public Class[] getConsistentModelClasses(boolean exact) { // If this model class is consistent, add it to the list. if (isConsistentModelClass(this.modelClasses[i], - parentModelClasses, exact)) { + parentModelClasses, exact, null)) { classes.add(modelClasses[i]); } } @@ -1008,6 +1015,26 @@ public Object[] getModelConstructorArguments(Class modelClass) { return null; } + @Override + public String getName() { + return null; + } + + @Override + public void setName(String name) { + + } + + @Override + public NodeType getNodeType() { + return null; + } + + @Override + public void setNodeType(NodeType nodeType) { + + } + /** * Prints out the name of the session node. */ @@ -1015,6 +1042,46 @@ public String toString() { return this.getBoxType(); } + @Override + public int getCenterX() { + return 0; + } + + @Override + public void setCenterX(int centerX) { + + } + + @Override + public int getCenterY() { + return 0; + } + + @Override + public void setCenterY(int centerY) { + + } + + @Override + public void setCenter(int centerX, int centerY) { + + } + + @Override + public void addPropertyChangeListener(PropertyChangeListener l) { + + } + + @Override + public Node like(String name) { + return null; + } + + @Override + public int compareTo(Object o) { + return 0; + } + /** * True iff the next edge should not be added. (Included for GUI user * control.) Reset to true every time an edge is added; edge adds must be @@ -1132,7 +1199,7 @@ public boolean isConsistentModelClass(Class modelClass, List nodes, boole nodeClasses[i] = node.getModelClasses(); } - return isConsistentModelClass(modelClass, nodeClasses, exact); + return isConsistentModelClass(modelClass, nodeClasses, exact, null); } /** @@ -1257,7 +1324,7 @@ public Object[] assignParameters(Class[] parameterTypes, List objects) } } - public boolean assignClasses(Class[] constructorTypes, Class[] modelTypes, boolean exact) + public boolean assignClasses(Class[] constructorTypes, Class[] modelTypes, boolean exact, List existingNodes) throws RuntimeException { for (Class parameterType1 : constructorTypes) { if (parameterType1 == null) { @@ -1266,6 +1333,31 @@ public boolean assignClasses(Class[] constructorTypes, Class[] modelTypes, boole } } + // Is it the case that for this constructor, every argument type is a model class for + // one of the existing session nodes? (You can skip Parameters classes.) + if (existingNodes != null) { + existingNodes.remove(this); + + for (Class type : constructorTypes) { + if (type.equals(Parameters.class)) continue; + boolean foundNode = false; + + FOR: + for (SessionNode node : existingNodes) { + for (Class clazz : node.getModelClasses()) { + if (clazz.equals(type)) { + foundNode = true; + break FOR; + } + } + } + + if (!foundNode) { + return false; + } + } + } + if (exact) { if (modelTypes.length != constructorTypes.length) { return false; @@ -1297,6 +1389,9 @@ public boolean assignClasses(Class[] constructorTypes, Class[] modelTypes, boole if (!constructorType.isAssignableFrom(modelType)) { allAssigned = false; } +// else { +// System.out.println("aa " + constructorType + " assignable from " + modelType); +// } } if (allAssigned) { @@ -1536,7 +1631,8 @@ private void createModelUsingArguments(Class modelClass, List models) /** * New version 2015901. */ - private boolean isConsistentModelClass(Class modelClass, Class[][] parentClasses, boolean exact) { + private boolean isConsistentModelClass(Class modelClass, Class[][] parentClasses, boolean exact, + List existingNodes) { Constructor[] constructors = modelClass.getConstructors(); // If the constructor takes the special form of an array followed by Parameters, @@ -1621,7 +1717,7 @@ private boolean isConsistentModelClass(Class modelClass, Class[][] parentClasses modelTypes[comb.length] = Parameters.class; - if (assignClasses(constructorTypes, modelTypes, exact)) { + if (assignClasses(constructorTypes, modelTypes, exact, existingNodes)) { return true; } } else { @@ -1633,7 +1729,7 @@ private boolean isConsistentModelClass(Class modelClass, Class[][] parentClasses modelTypes[i] = summary.get(i).get(comb[i]); } - if (assignClasses(constructorTypes, modelTypes, exact)) { + if (assignClasses(constructorTypes, modelTypes, exact, existingNodes)) { return true; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Gdistance.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Gdistance.java index 9412cb7a50..e8ece5666c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Gdistance.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Gdistance.java @@ -2,9 +2,12 @@ import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.util.ForkJoinPoolInstance; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.concurrent.*; /** * Created by Erich on 7/3/2016. @@ -14,50 +17,175 @@ * the distance between two edges is calculated as the distance between their endpoints * the distance between edges calculated this way is a true distance * the distance between two graphs is not a true distance because it is not symmetric + * this version allows for non-cubic voxels, and parallelizes the most expensive loop */ public class Gdistance { - public static List distances(Graph graph1, Graph graph2, DataSet locationMap) { - //first, just impliment the brute force approach: - // compare every edge in graph1 to all in graph2 - // for each edge in graph1, record the shortest distance to any edge in graph2 - // return the list of shortest distances + private DataSet locationMap; + private double xDist; + private double yDist; + private double zDist; - double thisDistance = -1.0; - double leastDistance = -1.0; - int count = 1; - List leastList = new ArrayList<>(); + private List leastList; + + private int chunksize = 2; + + private int cores = ForkJoinPoolInstance.getInstance().getPool().getParallelism(); + + //With the parallel version, it is better to make a constructor for central data like locationMap + public Gdistance(DataSet locationMap, double xDist, double yDist, double zDist){ + this.locationMap = locationMap; + this.xDist=xDist; + this.yDist=yDist; + this.zDist=zDist; + this.leastList = new ArrayList<>(); + } + + public List distances(Graph graph1, Graph graph2) { + // needs to calculate distances for non-cubic voxels. + //dimensions along each dimension should be given as input: xdist, ydist, zdist + //this impliments a less brute force approach, where edge comparisons are restricted + //to edges that are in the "vicinity" of the original edge + + //List leastList = new ArrayList<>(); //System.out.println(locationMap); // Make *SURE* that the graph nodes are the same as the location nodes + System.out.println("Synchronizing variables between graph1, graph2, and the locationMap"); + long time1 = System.nanoTime(); graph1 = GraphUtils.replaceNodes(graph1,locationMap.getVariables()); graph2 = GraphUtils.replaceNodes(graph2,locationMap.getVariables()); + long time2 = System.nanoTime(); + System.out.println("Synchronizing time: " + (time2 - time1)/1000000000 + "s"); + + //constructing vicinity is costy, so do it just once, OUTSIDE any loops + //Using EK's Vicinity5 + System.out.println("Constructing vicinity object"); + long timevic1 = System.nanoTime(); + ArrayList graph2edges = new ArrayList<>(graph2.getEdges()); + final Vicinity vicinity = new Vicinity(graph2edges,locationMap,0,100,0,100,0,100,xDist,yDist,zDist); + long timevic2 = System.nanoTime(); + System.out.println("Done constructing vicinity object. Construction Time : " + (timevic2 - timevic1)/1000000000 + "s" ); + + //This for loop should be parallelized in the future. + //let the for loop do its thing, and create a new thread for each task inside of it. + //int edgetracker=1; + + //ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool(); + List> todo = new ArrayList>(); + ExecutorService executorService = Executors.newCachedThreadPool(); + + List taskEdges = new ArrayList<>(); + //can change the times 3.0 part if it seems better to do so + int taskSize = (int) Math.ceil(graph1.getNumEdges()/(5.0*cores)); + System.out.println(" edges1: " + graph1.getNumEdges() + " taskSize: " + taskSize); + + for (final Edge edge1 : graph1.getEdges()) { + // for each choice we will create a task that will run on a separate thread + //System.out.println("edge#"+edgetracker); + //edgetracker++; + + //Add edges to taskEdges until it reaches a certain size, then spin off a thread + taskEdges.add(edge1); + + if (taskEdges.size() >= taskSize) { + //add the taskEdges to a new task, and then empty it + final List runEdges = new ArrayList<>(taskEdges); + todo.add(new Callable(){ + public Void call() throws Exception { + + FindLeastDistanceTask FLDtask = new FindLeastDistanceTask(vicinity); + FLDtask.compute(runEdges); + return null; + } + }); + + taskEdges.clear(); + } + + } + //add any leftover edge to a final task + if (!taskEdges.isEmpty()){ + //add the taskEdges to a new task, and then empty it + final List runEdges = new ArrayList<>(taskEdges); + todo.add(new Callable(){ + public Void call() throws Exception { + + FindLeastDistanceTask FLDtask = new FindLeastDistanceTask(vicinity); + FLDtask.compute(runEdges); + return null; + } + }); + + taskEdges.clear(); + } + //invoke all the things! + try{ + System.out.println("number of parallel tasks being invoked: " + todo.size()); + executorService.invokeAll(todo); + executorService.shutdown(); + } catch (Exception e){ + + } + System.out.println(leastList.size()); + return leastList; + } + + //////+++++******* Method used in multithread task + class FindLeastDistanceTask { + Vicinity vicinity; + + private FindLeastDistanceTask(final Vicinity vicinity) { + this.vicinity = vicinity; + + } + + protected void compute(List edges) { + //System.out.println("running thread"); + /* + try{ + TimeUnit.SECONDS.sleep(3); + } catch (Exception e){ - //This first for loop should be parallelized. - for (Edge edge1 : graph1.getEdges()) { - //the variable "count" is used to initialize leastDistance to the first thisDistance - count = 1; - for (Edge edge2 : graph2.getEdges()) { - thisDistance = edgesDistance(edge1, edge2, locationMap); - //remember only the shortest distance seen - if (count ==1) { - leastDistance = thisDistance; - } else { - if (thisDistance < leastDistance) { + } + */ + + for (Edge edge1 : edges){ + //the variable "count" is used to initialize leastDistance to the first thisDistance + int count = 1; + double thisDistance; + double leastDistance = -1.0; + //the next for loop gets restricted to edges in the vicinity of edge1 + List vicEdges = vicinity.getVicinity(edge1,chunksize); + //System.out.println(vicEdges); + for (Edge edge2 : vicEdges) { + thisDistance = edgesDistance(edge1, edge2, locationMap,xDist,yDist,zDist); + //remember only the shortest distance seen + if (count ==1) { leastDistance = thisDistance; + } else { + if (thisDistance < leastDistance) { + leastDistance = thisDistance; + } } + count++; } - count++; + //add it to a list of the leastDistances + //System.out.println("does this happen?"); + add(leastDistance); + //System.out.println(leastList); } - //add it to a list of the leastDistances - leastList.add(leastDistance); + } - return leastList; + } + + private synchronized void add(Double value) { + leastList.add(value); } //////======***PRIVATE METHODS BELOW *****=====///// - private static double nodesDistance(Node node1, Node node2, DataSet locationMap) { + private static double nodesDistance(Node node1, Node node2, DataSet locationMap, double x, double y, double z) { //calculate distance between two nodes based on their locations //simple starter is simply the taxicab distance: //calc differences in X, Y, and Z axis, then sum them together. @@ -77,12 +205,13 @@ private static double nodesDistance(Node node1, Node node2, DataSet locationMap) //taxicab distance //double taxicab = Math.abs(value11 - value21) + Math.abs(value12 - value22) + Math.abs(value13 - value23); //euclidian distance instead of taxicab - double euclid = Math.sqrt((value11 - value21)*(value11 - value21)+(value12 - value22)*(value12 - value22)+(value13 - value23)*(value13 - value23) ); + double euclid = Math.sqrt((value11 - value21)*x*(value11 - value21)*x+(value12 - value22)*y* + (value12 - value22)*y+(value13 - value23)*z*(value13 - value23)*z ); return euclid; } - private static double edgesDistance(Edge edge1, Edge edge2, DataSet locationMap) { + private static double edgesDistance(Edge edge1, Edge edge2, DataSet locationMap, double xD, double yD, double zD) { //calculate distance between two edges based on distances of their endpoints //if both edges are directed, then: //compare edge1 head to edge2 head, tail to tail. @@ -95,8 +224,8 @@ private static double edgesDistance(Edge edge1, Edge edge2, DataSet locationMap) Node edge2h = Edges.getDirectedEdgeHead(edge2); Node edge2t = Edges.getDirectedEdgeTail(edge2); //compare tail to tail - double tDistance = nodesDistance(edge1t, edge2t, locationMap); - double hDistance = nodesDistance(edge1h, edge2h, locationMap); + double tDistance = nodesDistance(edge1t, edge2t, locationMap,xD,yD,zD); + double hDistance = nodesDistance(edge1h, edge2h, locationMap,xD,yD,zD); return tDistance + hDistance; } else { @@ -109,16 +238,20 @@ private static double edgesDistance(Edge edge1, Edge edge2, DataSet locationMap) Node node22 = edge2.getNode2(); //first compare node1 to node1 and node2 to node2 - double dist11 = nodesDistance(node11, node21, locationMap); - double dist22 = nodesDistance(node12, node22, locationMap); + double dist11 = nodesDistance(node11, node21, locationMap,xD,yD,zD); + double dist22 = nodesDistance(node12, node22, locationMap,xD,yD,zD); //then compare node1 to node2 and node2 to node1 - double dist12 = nodesDistance(node11, node22, locationMap); - double dist21 = nodesDistance(node12, node21, locationMap); + double dist12 = nodesDistance(node11, node22, locationMap,xD,yD,zD); + double dist21 = nodesDistance(node12, node21, locationMap,xD,yD,zD); //then return the minimum of the two ways of pairing nodes from each edge return Math.min(dist11 + dist22, dist12 + dist21); } } + + public void setChunksize(int chunk){ + chunksize = chunk; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceApply.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceApply.java index 4d544c05b2..c0235eeb81 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceApply.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceApply.java @@ -16,21 +16,39 @@ */ public class GdistanceApply { public static void main (String... args) { + double xdist = 2.4; + double ydist = 2.4; + double zdist = 2; long timestart = System.nanoTime(); System.out.println("Loading first graph"); - Graph graph1 = GraphUtils.loadGraphTxt(new File("images_graph_10sub_pd40_group1.txt")); + Graph graph1 = GraphUtils.loadGraphTxt(new File("Motion_Corrected_Graphs/singlesub_motion_graph_025_04.txt")); long timegraph1 = System.nanoTime(); //System.out.println(graph1); System.out.println("Done loading first graph. Elapsed time: " + (timegraph1 - timestart)/1000000000 + "s"); System.out.println("Loading second graph"); - Graph graph2 = GraphUtils.loadGraphTxt(new File("images_graph_10sub_pd40_group2.txt")); + Graph graph2 = GraphUtils.loadGraphTxt(new File("Motion_Corrected_Graphs/singlesub_motion_graph_027_04.txt")); long timegraph2 = System.nanoTime(); System.out.println("Done loading second graph. Elapsed time: " + (timegraph2 - timegraph1)/1000000000 + "s"); + //+++++++++ these steps are specifically for the motion corrected fMRI graphs ++++++++++++ + graph1.removeNode(graph1.getNode("Motion_1")); + graph1.removeNode(graph1.getNode("Motion_2")); + graph1.removeNode(graph1.getNode("Motion_3")); + graph1.removeNode(graph1.getNode("Motion_4")); + graph1.removeNode(graph1.getNode("Motion_5")); + graph1.removeNode(graph1.getNode("Motion_6")); + + graph2.removeNode(graph2.getNode("Motion_1")); + graph2.removeNode(graph2.getNode("Motion_2")); + graph2.removeNode(graph2.getNode("Motion_3")); + graph2.removeNode(graph2.getNode("Motion_4")); + graph2.removeNode(graph2.getNode("Motion_5")); + graph2.removeNode(graph2.getNode("Motion_6")); + //load the location map String workingDirectory = System.getProperty("user.dir"); System.out.println(workingDirectory); - Path mapPath = Paths.get("erich_coordinates.txt"); + Path mapPath = Paths.get("coords.txt"); System.out.println(mapPath); edu.cmu.tetrad.io.DataReader dataReaderMap = new TabularContinuousDataReader(mapPath, ','); try{ @@ -39,8 +57,9 @@ public static void main (String... args) { System.out.println("Done loading location map. Elapsed time: " + (timegraph3 - timegraph2)/1000000000 + "s"); System.out.println("Running Gdistance"); - //Make this either Gdistance or GdistanceVic - List distance = GdistanceVic.distances(graph1,graph2,locationMap); + + Gdistance gdist = new Gdistance(locationMap,xdist,ydist,zdist); + List distance = gdist.distances(graph1,graph2); System.out.println(distance); System.out.println("Done running Distance. Elapsed time: " + (System.nanoTime() - timegraph3)/1000000000 + "s"); System.out.println("Total elapsed time: " + (System.nanoTime() - timestart)/1000000000 + "s"); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceRandom.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceRandom.java index a1718aa419..5182d0bb12 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceRandom.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceRandom.java @@ -18,6 +18,10 @@ public class GdistanceRandom { private int numEdges2; private boolean verbose = false; + double xdist = 2.4; + double ydist = 2.4; + double zdist = 2; + //**************CONSTRUCTORS*********************// public GdistanceRandom(DataSet inMap) { setLocationMap(inMap); @@ -55,7 +59,8 @@ private List randomPairSimulation(){ //run Gdistance on these two graphs if (verbose) System.out.println("running Gdistance on the patterns"); - return Gdistance.distances(graph1, graph2, locationMap); + Gdistance gdist = new Gdistance(locationMap,xdist,ydist,zdist); + return gdist.distances(graph1, graph2); } //**********Methods for setting values of private variables**************// diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceTest.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceTest.java index 311d6b2d14..c275a673f0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceTest.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceTest.java @@ -7,6 +7,7 @@ import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.io.TabularContinuousDataReader; +import java.io.PrintWriter; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; @@ -18,8 +19,8 @@ public class GdistanceTest { public static void main (String... args) { //first generate a couple random graphs - int numVars = 5; - int numEdges = 4; + int numVars = 16; + int numEdges = 16; List vars = new ArrayList<>(); for (int i = 0; i < numVars; i++) { vars.add(new ContinuousVariable("X" + i)); @@ -37,10 +38,19 @@ public static void main (String... args) { edu.cmu.tetrad.io.DataReader dataReaderMap = new TabularContinuousDataReader(mapPath, ','); try{ DataSet locationMap = dataReaderMap.readInData(); + //System.out.println(locationMap); //then compare their distance - List output = Gdistance.distances(testdag1, testdag2, locationMap); + double xdist = 2.4; + double ydist = 2.4; + double zdist = 2; + Gdistance gdist = new Gdistance(locationMap,xdist,ydist,zdist); + List output = gdist.distances(testdag1, testdag2); System.out.println(output); + + PrintWriter writer = new PrintWriter("Gdistances.txt", "UTF-8"); + writer.println(output); + writer.close(); } catch(Exception IOException){ IOException.printStackTrace(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceVic.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceVic.java deleted file mode 100644 index 90915a33aa..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceVic.java +++ /dev/null @@ -1,133 +0,0 @@ -package edu.cmu.tetrad.simulation; - -import edu.cmu.tetrad.data.DataSet; -import edu.cmu.tetrad.graph.*; - -import java.util.ArrayList; -import java.util.List; - -/** - * Created by Erich on 7/3/2016. - * - * This class is used to compare the distance between two graphs learned from fmri data - * the distance is calculated as the mean of the distance of the edges between the graphs - * the distance between two edges is calculated as the distance between their endpoints - * the distance between edges calculated this way is a true distance - * the distance between two graphs is not a true distance because it is not symmetric - */ -public class GdistanceVic { - - public static List distances(Graph graph1, Graph graph2, DataSet locationMap) { - //this impliments a less brute force approach, where edge comparisons are restricted - //to edges that are in the "vicinity" of the original edge - - double thisDistance = -1.0; - double leastDistance = -1.0; - int count = 1; - List leastList = new ArrayList<>(); - //System.out.println(locationMap); - // Make *SURE* that the graph nodes are the same as the location nodes - graph1 = GraphUtils.replaceNodes(graph1,locationMap.getVariables()); - graph2 = GraphUtils.replaceNodes(graph2,locationMap.getVariables()); - - //constructing vicinity is costy, so do it just once, OUTSIDE the first for loop - //Using EK's adapted Vicinity2 instead, since it uses the normal type of edges - System.out.println("Constructing vicinity object"); - long timevic1 = System.nanoTime(); - ArrayList graph2edges = new ArrayList<>(graph2.getEdges()); - Vicinity2 vicinity = new Vicinity2(graph2edges,locationMap,0,100,0,100,0,100); - long timevic2 = System.nanoTime(); - System.out.println("Done constructing vicinity object. Construction Time : " + (timevic2 - timevic1)/1000000000 + "s" ); - - //This first for loop should be parallelized in the future. - for (Edge edge1 : graph1.getEdges()) { - //the variable "count" is used to initialize leastDistance to the first thisDistance - count = 1; - //the next for loop gets restricted to edges in the vicinity of edge1 - List vicEdges = vicinity.getVicinity(edge1,locationMap); - for (Edge edge2 : vicEdges) { - thisDistance = edgesDistance(edge1, edge2, locationMap); - //remember only the shortest distance seen - if (count ==1) { - leastDistance = thisDistance; - } else { - if (thisDistance < leastDistance) { - leastDistance = thisDistance; - } - } - count++; - } - //add it to a list of the leastDistances - leastList.add(leastDistance); - } - return leastList; - } - - - //////======***PRIVATE METHODS BELOW *****=====///// - - private static double nodesDistance(Node node1, Node node2, DataSet locationMap) { - //calculate distance between two nodes based on their locations - //simple starter is simply the taxicab distance: - //calc differences in X, Y, and Z axis, then sum them together. - int column1 = locationMap.getColumn(node1); - int column2 = locationMap.getColumn(node2); - - //System.out.println(column1); - - double value11 = locationMap.getDouble(0,column1); - double value12 = locationMap.getDouble(1,column1); - double value13 = locationMap.getDouble(2,column1); - - double value21 = locationMap.getDouble(0,column2); - double value22 = locationMap.getDouble(1,column2); - double value23 = locationMap.getDouble(2,column2); - - //taxicab distance - //double taxicab = Math.abs(value11 - value21) + Math.abs(value12 - value22) + Math.abs(value13 - value23); - //euclidian distance instead of taxicab - double euclid = Math.sqrt((value11 - value21)*(value11 - value21)+(value12 - value22)*(value12 - value22)+(value13 - value23)*(value13 - value23) ); - - return euclid; - } - - private static double edgesDistance(Edge edge1, Edge edge2, DataSet locationMap) { - //calculate distance between two edges based on distances of their endpoints - //if both edges are directed, then: - //compare edge1 head to edge2 head, tail to tail. - //sum head distance and tail ditance - if (edge1.isDirected() && edge2.isDirected()) { - //find head and tail of edge1 - Node edge1h = Edges.getDirectedEdgeHead(edge1); - Node edge1t = Edges.getDirectedEdgeTail(edge1); - //find head and tail of edge2 - Node edge2h = Edges.getDirectedEdgeHead(edge2); - Node edge2t = Edges.getDirectedEdgeTail(edge2); - //compare tail to tail - double tDistance = nodesDistance(edge1t, edge2t, locationMap); - double hDistance = nodesDistance(edge1h, edge2h, locationMap); - return tDistance + hDistance; - } - else { - //otherwise if either edge is not directed: - //for each of edge1's two endpoints, calc distance to both edge2 endpoints - //store the shorter distances, and sum them. - Node node11 = edge1.getNode1(); - Node node12 = edge1.getNode2(); - Node node21 = edge2.getNode1(); - Node node22 = edge2.getNode2(); - - //first compare node1 to node1 and node2 to node2 - double dist11 = nodesDistance(node11, node21, locationMap); - double dist22 = nodesDistance(node12, node22, locationMap); - - //then compare node1 to node2 and node2 to node1 - double dist12 = nodesDistance(node11, node22, locationMap); - double dist21 = nodesDistance(node12, node21, locationMap); - - //then return the minimum of the two ways of pairing nodes from each edge - return Math.min(dist11 + dist22, dist12 + dist21); - } - - } -} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimAutoRun.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimAutoRun.java index fc23367016..650556509b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimAutoRun.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimAutoRun.java @@ -4,7 +4,7 @@ import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.io.VerticalTabularDiscreteDataReader; import edu.cmu.tetrad.search.BDeuScore; -import edu.cmu.tetrad.search.Fgs; +import edu.cmu.tetrad.search.Fges; import edu.cmu.tetrad.search.PatternToDag; import java.io.FileWriter; @@ -83,12 +83,12 @@ public double[] run(int resimSize) { //ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(dataSet); double penaltyDiscount = 2.0; - Fgs fgs = new Fgs(score); - fgs.setVerbose(false); - fgs.setNumPatternsToStore(0); - fgs.setPenaltyDiscount(penaltyDiscount); + Fges fges = new Fges(score); + fges.setVerbose(false); + fges.setNumPatternsToStore(0); + fges.setPenaltyDiscount(penaltyDiscount); - Graph estGraph = fgs.search(); + Graph estGraph = fges.search(); //if (verbose) System.out.println(estGraph); Graph estPattern = new EdgeListGraphSingleConnections(estGraph); @@ -155,24 +155,24 @@ public double[] run(int resimSize) { DataWriter.writeRectangularData(newDataSet, fileWriter, delimiter); fileWriter.close(); } - //=======Run FGS on the output data, and compare it to the original learned graph + //=======Run FGES on the output data, and compare it to the original learned graph //Path dataFileOut = Paths.get(filenameOut); //edu.cmu.tetrad.io.DataReader dataReaderOut = new VerticalTabularDiscreteDataReader(dataFileOut, delimiter); //DataSet dataSetOut = dataReaderOut.readInData(eVars); BDeuScore newscore = new BDeuScore(newDataSet); - Fgs fgsOut = new Fgs(newscore); - fgsOut.setVerbose(false); - fgsOut.setNumPatternsToStore(0); - fgsOut.setPenaltyDiscount(2.0); - //fgsOut.setOut(out); - //fgsOut.setFaithfulnessAssumed(true); - // fgsOut.setMaxIndegree(1); - // fgsOut.setCycleBound(5); - - Graph estGraphOut = fgsOut.search(); - //if (verbose) System.out.println(" bugchecking: fgs estGraphOut: " + estGraphOut); + Fges fgesOut = new Fges(newscore); + fgesOut.setVerbose(false); + fgesOut.setNumPatternsToStore(0); + fgesOut.setPenaltyDiscount(2.0); + //fgesOut.setOut(out); + //fgesOut.setFaithfulnessAssumed(true); + // fgesOut.setMaxIndegree(1); + // fgesOut.setCycleBound(5); + + Graph estGraphOut = fgesOut.search(); + //if (verbose) System.out.println(" bugchecking: fges estGraphOut: " + estGraphOut); //doing the replaceNodes trick to fix some bugs estGraphOut = GraphUtils.replaceNodes(estGraphOut,estDAG.getNodes()); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java index f0942d1747..fe1e545b49 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java @@ -7,22 +7,20 @@ import edu.cmu.tetrad.bayes.*; import edu.cmu.tetrad.data.ContinuousVariable; import edu.cmu.tetrad.data.DataSet; -import edu.cmu.tetrad.data.DataWriter; import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.BDeuScore; -import edu.cmu.tetrad.search.Fgs; +import edu.cmu.tetrad.search.Fges; import edu.cmu.tetrad.search.PatternToDag; import edu.cmu.tetrad.util.RandomUtil; -import java.io.FileWriter; import java.util.ArrayList; import java.util.List; /** * generate data from random graph, generated from parameters. - * calculate errors from FGS output for the data and graph + * calculate errors from FGES output for the data and graph * create resimulated data and hybrid resimulated data with various parameters - * calculate errors of FGS on the resimulated and hsim data + * calculate errors of FGES on the resimulated and hsim data * compare errors across all data sets. which simulated data errors are closest to original? */ public class HsimRobustCompare { @@ -55,16 +53,16 @@ public static List run(int numVars,double edgesPerNode, int numCases,d //System.out.println(oData); //System.out.println(odag); - //then run FGS + //then run FGES BDeuScore oscore = new BDeuScore(oData); - Fgs ofgs = new Fgs(oscore); - ofgs.setVerbose(false); - ofgs.setNumPatternsToStore(0); - ofgs.setPenaltyDiscount(penaltyDiscount); - Graph oGraphOut = ofgs.search(); + Fges fges = new Fges(oscore); + fges.setVerbose(false); + fges.setNumPatternsToStore(0); + fges.setPenaltyDiscount(penaltyDiscount); + Graph oGraphOut = fges.search(); if (verbose) System.out.println(oGraphOut); - //calculate FGS errors + //calculate FGES errors oErrors = new double[5]; oErrors = HsimUtils.errorEval(oGraphOut, odag); if (verbose) System.out.println(oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + @@ -72,12 +70,12 @@ public static List run(int numVars,double edgesPerNode, int numCases,d //create various simulated data sets - ////let's do the full simulated data set first: a dag in the FGS pattern fit to the data set. + ////let's do the full simulated data set first: a dag in the FGES pattern fit to the data set. PatternToDag pickdag = new PatternToDag(oGraphOut); - Graph fgsDag = pickdag.patternToDagMeek(); + Graph fgesDag = pickdag.patternToDagMeek(); - Dag fgsdag2 = new Dag(fgsDag); - BayesPm simBayesPm = new BayesPm(fgsdag2, bayesPm); + Dag fgesdag2 = new Dag(fgesDag); + BayesPm simBayesPm = new BayesPm(fgesdag2, bayesPm); DirichletBayesIm simIM = DirichletBayesIm.symmetricDirichletIm(simBayesPm, 1.0); DirichletEstimator simEstimator = new DirichletEstimator(); DirichletBayesIm fittedIM = simEstimator.estimate(simIM, oData); @@ -90,13 +88,13 @@ public static List run(int numVars,double edgesPerNode, int numCases,d //calculate errors for all simulated output graphs ////full simulation errors first BDeuScore simscore = new BDeuScore(simData); - Fgs simfgs = new Fgs(simscore); - simfgs.setVerbose(false); - simfgs.setNumPatternsToStore(0); - simfgs.setPenaltyDiscount(penaltyDiscount); - Graph simGraphOut = simfgs.search(); + Fges simfges = new Fges(simscore); + simfges.setVerbose(false); + simfges.setNumPatternsToStore(0); + simfges.setPenaltyDiscount(penaltyDiscount); + Graph simGraphOut = simfges.search(); //simErrors = new double[5]; - simErrors = HsimUtils.errorEval(simGraphOut, fgsdag2); + simErrors = HsimUtils.errorEval(simGraphOut, fgesdag2); //System.out.println("Full resim errors are: " + simErrors[0] + " " + simErrors[1] + " " + simErrors[2] + " " + simErrors[3] + " " + simErrors[4]); //compare errors. perhaps report differences between original and simulated errors. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRun.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRun.java index 89430ba063..9942bab7f9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRun.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRun.java @@ -47,16 +47,16 @@ public static void run(String readfilename, String filenameOut, char delimiter, double penaltyDiscount = 2.0; SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet)); score.setPenaltyDiscount(penaltyDiscount); - Fgs fgs = new Fgs(score); - fgs.setVerbose(false); - fgs.setNumPatternsToStore(0); -// fgs.setAlpha(penaltyDiscount); - //fgs.setOut(out); - //fgs.setFaithfulnessAssumed(true); - //fgs.setMaxIndegree(1); - //fgs.setCycleBound(5); - - Graph estGraph = fgs.search(); + Fges fges = new Fges(score); + fges.setVerbose(false); + fges.setNumPatternsToStore(0); +// fges.setAlpha(penaltyDiscount); + //fges.setOut(out); + //fges.setFaithfulnessAssumed(true); + //fges.setMaxIndegree(1); + //fges.setCycleBound(5); + + Graph estGraph = fges.search(); System.out.println(estGraph); Graph estPattern = new EdgeListGraphSingleConnections(estGraph); @@ -95,7 +95,7 @@ public static void run(String readfilename, String filenameOut, char delimiter, //write output to a new file DataWriter.writeRectangularData(newDataSet, new FileWriter(filenameOut), delimiter); - //=======Run FGS on the output data, and compare it to the original learned graph + //=======Run FGES on the output data, and compare it to the original learned graph Path dataFileOut = Paths.get(filenameOut); edu.cmu.tetrad.io.DataReader dataReaderOut = new VerticalTabularDiscreteDataReader(dataFileOut, delimiter); @@ -103,16 +103,16 @@ public static void run(String readfilename, String filenameOut, char delimiter, SemBicScore _score = new SemBicScore(new CovarianceMatrix(dataSetOut)); _score.setPenaltyDiscount(2.0); - Fgs fgsOut = new Fgs(_score); - fgsOut.setVerbose(false); - fgsOut.setNumPatternsToStore(0); -// fgsOut.setAlpha(2.0); - //fgsOut.setOut(out); - //fgsOut.setFaithfulnessAssumed(true); - // fgsOut.setMaxIndegree(1); - // fgsOut.setCycleBound(5); - - Graph estGraphOut = fgsOut.search(); + Fges fgesOut = new Fges(_score); + fgesOut.setVerbose(false); + fgesOut.setNumPatternsToStore(0); +// fgesOut.setAlpha(2.0); + //fgesOut.setOut(out); + //fgesOut.setFaithfulnessAssumed(true); + // fgesOut.setMaxIndegree(1); + // fgesOut.setCycleBound(5); + + Graph estGraphOut = fgesOut.search(); System.out.println(estGraphOut); SearchGraphUtils.graphComparison(estGraphOut, estGraph, System.out); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimStudy.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimStudy.java index 257093b359..7d11e9321e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimStudy.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimStudy.java @@ -1,20 +1,5 @@ package edu.cmu.tetrad.simulation; -import edu.cmu.tetrad.data.*; -import edu.cmu.tetrad.graph.*; -import edu.cmu.tetrad.search.Fgs; -import edu.cmu.tetrad.search.PatternToDag; - -import edu.cmu.tetrad.io.*; -import edu.cmu.tetrad.search.SearchGraphUtils; - -import java.io.File; -import java.io.FileWriter; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.HashSet; -import java.util.Set; - /** * Created by Erich on 3/28/2016. */ diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimStudyAuto.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimStudyAuto.java index 9546648441..b9baf65d64 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimStudyAuto.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimStudyAuto.java @@ -1,20 +1,5 @@ package edu.cmu.tetrad.simulation; -import edu.cmu.tetrad.data.BigDataSetUtility; -import edu.cmu.tetrad.data.DataSet; -import edu.cmu.tetrad.data.DataWriter; -import edu.cmu.tetrad.graph.*; -import edu.cmu.tetrad.io.VerticalTabularDiscreteDataReader; -import edu.cmu.tetrad.search.Fgs; -import edu.cmu.tetrad.search.PatternToDag; -import edu.cmu.tetrad.search.SearchGraphUtils; - -import java.io.File; -import java.io.FileWriter; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.*; - /** * Created by Erich on 3/28/2016. */ diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimUtils.java index a1e0c2cb46..0840db44a0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimUtils.java @@ -1,18 +1,17 @@ package edu.cmu.tetrad.simulation; +import edu.cmu.tetrad.data.ContinuousVariable; import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.data.DiscreteVariable; import edu.cmu.tetrad.data.VerticalIntDataBox; -import edu.cmu.tetrad.graph.Edge; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphUtils; -import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.SearchGraphUtils; +import edu.cmu.tetrad.sem.GeneralizedSemPm; +import edu.cmu.tetrad.sem.TemplateExpander; import edu.cmu.tetrad.util.TextTable; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.text.ParseException; +import java.util.*; /** * Created by Erich on 4/22/2016. @@ -44,6 +43,7 @@ public static Graph evalEdges(Graph inputgraph, Set simnodes, Set re return subgraph; } + //this method returns the set of all parents of a provided set of parents, given a provided graph public static Set getAllParents(Graph inputgraph, Set inputnodes) { List parents = new ArrayList(); List pAdd = new ArrayList(); @@ -63,6 +63,7 @@ public static Set getAllParents(Graph inputgraph, Set inputnodes) { return output; } + //this method returns an array of doubles, which are standard error metrics for graph learning public static double[] errorEval(Graph estPattern, Graph truePattern) { GraphUtils.GraphComparison comparison = SearchGraphUtils.getGraphComparison2(estPattern, truePattern); @@ -146,6 +147,7 @@ public static double correctnessRatio(int[][] counts) { return (correctEdges / (double) estimatedEdges); } + //this method makes a VerticalIntDataBox from a regular data set public static VerticalIntDataBox makeVertIntBox(DataSet dataset) { //this is for turning regular data set into verticalintbox (not doublebox...) int[][] data = new int[dataset.getNumColumns()][dataset.getNumRows()]; @@ -156,5 +158,49 @@ public static VerticalIntDataBox makeVertIntBox(DataSet dataset) { } return new VerticalIntDataBox(data); } - + //returns a String formatted as a latex table, which can be pasted directly into a latex document + public static String makeLatexTable(String[][] tablevalues){ + String nl = System.lineSeparator(); + String output = "\\begin{table}[ht]"+nl; + output = output + "\\begin{center}" +nl; + int dim1 = tablevalues.length; + int dim2 = tablevalues[0].length; + //determines number of columns in the table + output=output + "\\begin{tabular}{|"; + for (int c=0;c1 && j!=dim2-1){ + output = output+" & "; + } + } + output=output+"\\\\ \\hline" + nl; + } + //caps off the environments used by the table + output=output+"\\end{tabular}"+nl+"\\end{center}"+nl+"\\end{table}"+nl; + return output; + } + //this turns an array of doubles into an array of strings, formatted for easier reading + //the input String should be formatted appropriately for the String.format method + public static String[] formatErrorsArray(double[] inputArray,String formatting){ + String[] output = new String[inputArray.length]; + for (int i=0;i varslist = new ArrayList<>(); + for (int i = 0; i < numVars; i++) { + varslist.add(new ContinuousVariable("X" + i)); + } + return GraphUtils.randomGraphRandomForwardEdges(varslist, 0, numEdges, 30, 15, 15, false, true); + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Vicinity.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Vicinity.java index b75585e799..5cd5b7f467 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Vicinity.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Vicinity.java @@ -1,23 +1,44 @@ package edu.cmu.tetrad.simulation; +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.NodeEqualityMode; + import java.util.*; /** + * This version of Vicinity finds nearby nodes by searching with an expanding cube + * Prior to Vicinity4, versions of Vicinity looked at the 3 axis independently instead of collectively. + * + * Vicinity5 improves on Vicinity4 by allowing for the nodes to not be distributed evenly throughout + * the location space. This is needed for fMRI data when the voxels are not perfect cubes. + * * @author jdramsey + * @author Erich Kummerfeld */ public class Vicinity { + //these are value ranges, used to constrain searches at the edges private int xLow; private int xHigh; private int yLow; private int yHigh; private int zLow; private int zHigh; - //EK: I think range is the distance along each axis that a node can be to be "close enough" for Vicinity - private int range = 6; - private Map> xCoords = new HashMap<>(); - private Map> yCoords = new HashMap<>(); - private Map> zCoords = new HashMap<>(); - public Vicinity(List edges, int xLow, int xHigh, int yLow, int yHigh, int zLow, int zHigh) { + //these are the dimensions of the voxels + private double xDist; + private double yDist; + private double zDist; + + + private DataSet locationMap; + + //Vicinity4 just uses two maps, each from array to a set of edges + private Map, Set> Coords1 = new HashMap<>(); + private Map, Set> Coords2 = new HashMap<>(); + + public Vicinity(List edges, DataSet locationMap, int xLow, int xHigh, int yLow, int yHigh, int zLow, int zHigh, + double xDist, double yDist, double zDist) { //EK: the xLow etc. ints are the bounds on the coordinates in the location space, I think this.xLow = xLow; this.xHigh = xHigh; @@ -26,107 +47,266 @@ public Vicinity(List edges, int xLow, int xHigh, int yLow, int yHigh, int this.zLow = zLow; this.zHigh = zHigh; + this.xDist = xDist; + this.yDist = yDist; + this.zDist = zDist; + + this.locationMap = locationMap; + + NodeEqualityMode.setEqualityMode(NodeEqualityMode.Type.OBJECT); + + //make the edge accessible via the map from either of its endpoints for (Edge edge : edges) { - add(xCoords, edge, edge.getP1().getX()); - add(yCoords, edge, edge.getP1().getY()); - add(zCoords, edge, edge.getP1().getZ()); + add(Coords1, edge, Arrays.asList(getX(edge.getNode1(), locationMap),getY(edge.getNode1(), locationMap), + getZ(edge.getNode1(), locationMap)) ); - add(xCoords, edge, edge.getP2().getX()); - add(yCoords, edge, edge.getP2().getY()); - add(zCoords, edge, edge.getP2().getZ()); + add(Coords2, edge, Arrays.asList(getX(edge.getNode2(), locationMap),getY(edge.getNode2(), locationMap), + getZ(edge.getNode2(), locationMap)) ); } } - public List getVicinity(Edge edge) { - Set edges = new HashSet<>(); - - //EK: I'm concerned that this won't actually remove many edges? - //EK: since they only need one endpoint to be sorta close on ANY dimension - //EK: this will carve out a thick 3D "+" shape and add every edge that touches it - //EK: wouldn't it be preferable to carve out a smaller shape, like a cube? - for (int x = edge.getP1().getX() - range; x <= edge.getP1().getX() + range; x++) { - if (x < xLow || x > xHigh) continue; - edges.addAll(xCoords.get(x)); + //chunk basically establishes how quickly the search grows for a nearest edge. It should be small for + //graphs that are dense in the location space, and large for graphs that are sparse in the location space + public List getVicinity(Edge edge, int chunk) { + //the strategy employed here is to start from the input edge nodes, and expand the search from there + //the rate of expansion is based on the chunk parameter + //we start the range at 0, and increase it by chunk until another edge is found + //we're looking for any edge that has one node close to node1 and the other node close to node2 + int baserange; + if (edge.isDirected()){ + baserange = findRangeD(edge,chunk); + } else { + baserange = findRangeU(edge,chunk); } - for (int y = edge.getP1().getY() - range; y <= edge.getP1().getY() + range; y++) { - if (y < yLow || y > yHigh) continue; - edges.addAll(yCoords.get(y)); - } + //System.out.println("baserange: " +baserange); + //since I'm searching in a cube but distance is usually measured euclidian, I increase range by sqrt(3) + int range = (int) Math.ceil(Math.sqrt((double) 3) * (double) baserange); + //System.out.println(findEdges(edge,range)); + return findEdges(edge,range); + } - for (int z = edge.getP1().getZ() - range; z <= edge.getP1().getZ() + range; z++) { - if (z < zLow || z > zHigh) continue; - edges.addAll(zCoords.get(z)); - } + //======%====%=======Private methods===========%==========%========= + //******************* This finds the range when edge is Undirected ********** + private int findRangeU(Edge edge, int chunksize){ + Set edges = new HashSet<>(); + //System.out.println("edges is empty?"+edges.isEmpty()); + //System.out.println("edges is null?"+edges == null); + NodeEqualityMode.setEqualityMode(NodeEqualityMode.Type.OBJECT); + int range = 0-chunksize; + while (edges.isEmpty()){ + //increment range by chunk + range += chunksize; + //initialize the edge sets + Set node1edges1 = new HashSet<>(); + Set node1edges2 = new HashSet<>(); - for (int x = edge.getP2().getX() - range; x <= edge.getP2().getX() + range; x++) { - if (x < xLow || x > xHigh) continue; - edges.addAll(xCoords.get(x)); - } + /*for Vicinity5 the first and second arguments of the for loop need to be modified so that + * they respect that a single increment of the x/y/z grid scales with that dimension of the voxel + */ + //create separate range values for x y and z, scaled by xdist ydist zdist + int xrange = (int) Math.ceil(range/xDist); + int yrange = (int) Math.ceil(range/yDist); + int zrange = (int) Math.ceil(range/zDist); - for (int y = edge.getP2().getY() - range; y <= edge.getP2().getY() + range; y++) { - if (y < yLow || y > yHigh) continue; - edges.addAll(yCoords.get(y)); - } + //list edges with either endpoint near node1 + for (int x = getX(edge.getNode1(), locationMap) - xrange; x <= getX(edge.getNode1(), locationMap) + xrange; x++) { + for (int y = getY(edge.getNode1(), locationMap) - yrange; y <= getY(edge.getNode1(), locationMap) + yrange; y++) { + for (int z = getZ(edge.getNode1(), locationMap) - zrange; z <= getZ(edge.getNode1(), locationMap) + zrange; z++) { + if (x < xLow || x > xHigh || y < yLow || y > yHigh || z < zLow || z > zHigh) continue; + if (Coords1.get(Arrays.asList(x,y,z)) != null) node1edges1.addAll(Coords1.get(Arrays.asList(x,y,z))); + if (Coords2.get(Arrays.asList(x,y,z)) != null) node1edges2.addAll(Coords2.get(Arrays.asList(x,y,z))); + } + } + } + //for bugchecking + //System.out.println("node1edges1 is empty? "+node1edges1.isEmpty()); + //System.out.println("node1edges2 is empty? "+node1edges2.isEmpty()); - for (int z = edge.getP2().getZ() - range; z <= edge.getP1().getZ() + range; z++) { - if (z < zLow || z > zHigh) continue; - edges.addAll(zCoords.get(z)); + int x2 = getX(edge.getNode2(), locationMap); + int y2 = getY(edge.getNode2(), locationMap); + int z2 = getZ(edge.getNode2(), locationMap); + //if one or both of the above lists is nonempty, find edges where the other endpoint is near node2! + if (!node1edges1.isEmpty()){ + for (Edge edge11 : node1edges1){ + int x = getX(edge11.getNode2(), locationMap); + int y = getY(edge11.getNode2(), locationMap); + int z = getZ(edge11.getNode2(), locationMap); + /*for Vicinity5 the first and second arguments of the for loop need to be modified so that + * they respect that a single increment of the x/y/z grid scales with that dimension of the voxel + */ + if (x >= x2 - xrange && x <= x2 + xrange && y >= y2 - yrange && y <= y2 + yrange && z >= z2 - zrange && z <= z2 + zrange){ + edges.add(edge11); + } + } + } + if (!node1edges2.isEmpty()){ + for (Edge edge12 : node1edges2){ + int x = getX(edge12.getNode1(), locationMap); + int y = getY(edge12.getNode1(), locationMap); + int z = getZ(edge12.getNode1(), locationMap); + /*for Vicinity5 the first and second arguments of the for loop need to be modified so that + * * they respect that a single increment of the x/y/z grid scales with that dimension of the voxel + * Any time that xyz indexes are compared using range, then some rescaling needs to be done to + * account for how much distance is covered by one increment of that index dimension + * */ + if (x >= x2 - xrange && x <= x2 + xrange && y >= y2 - yrange && y <= y2 + yrange && z >= z2 - zrange && z <= z2 + zrange){ + edges.add(edge12); + } + } + } + //System.out.println("edges is empty?"+edges.isEmpty()+" at range "+range); + //System.out.println(edges); } - return new ArrayList<>(edges); + return range; } + //**********====== This finds the range when Edge is Directed ============********************* + private int findRangeD(Edge edge, int chunksize){ + Set edges = new HashSet<>(); + //It matters whether Node1 is the tail or the head of the arrow + //Because of how the Edge class works, it looks like Node1 is ALWAYS the TAIL + NodeEqualityMode.setEqualityMode(NodeEqualityMode.Type.OBJECT); + int range = 0-chunksize; + while (edges.isEmpty()){ + //increment range by chunk + range += chunksize; - private void add(Map> xCoords, Edge edge, int x) { - Set edges = xCoords.get(x); + //create separate range values for x y and z, scaled by xdist ydist zdist + int xrange = (int) Math.ceil(range/xDist); + int yrange = (int) Math.ceil(range/yDist); + int zrange = (int) Math.ceil(range/zDist); - if (edges == null) { - edges = new HashSet<>(); - xCoords.put(x, edges); - } + //initialize the edge sets + Set node1edges1 = new HashSet<>(); + Set node1edges2 = new HashSet<>(); + //list edges with either endpoint near node1 + for (int x = getX(edge.getNode1(), locationMap) - xrange; x <= getX(edge.getNode1(), locationMap) + xrange; x++) { + for (int y = getY(edge.getNode1(), locationMap) - yrange; y <= getY(edge.getNode1(), locationMap) + yrange; y++) { + for (int z = getZ(edge.getNode1(), locationMap) - zrange; z <= getZ(edge.getNode1(), locationMap) + zrange; z++) { + if (x < xLow || x > xHigh || y < yLow || y > yHigh || z < zLow || z > zHigh) continue; + if (Coords1.get(Arrays.asList(x,y,z)) != null) node1edges1.addAll(Coords1.get(Arrays.asList(x,y,z))); + if (Coords2.get(Arrays.asList(x,y,z)) != null) node1edges2.addAll(Coords2.get(Arrays.asList(x,y,z))); + } + } + } - xCoords.get(x).add(edge); - } + //** Since edge is directed, node1edges2 is NOT allowed to contain directed edges + //it's okay if the edges in node1edges2 are unidrected, though + if (!node1edges2.isEmpty()){ + List edges12 = new ArrayList<>(node1edges2); + for (Edge thisedge : edges12){ + if (thisedge.isDirected()) node1edges2.remove(thisedge); + } - public class Point { - private int x; - private int y; - private int z; + } - public Point(int x, int y, int z) { - this.x = x; - this.y = y; - this.z = z; + int x2 = getX(edge.getNode2(), locationMap); + int y2 = getY(edge.getNode2(), locationMap); + int z2 = getZ(edge.getNode2(), locationMap); + //if one or both of the above lists is nonempty, find edges where the other endpoint is near node2! + if (!node1edges1.isEmpty()){ + for (Edge edge11 : node1edges1){ + int x = getX(edge11.getNode2(), locationMap); + int y = getY(edge11.getNode2(), locationMap); + int z = getZ(edge11.getNode2(), locationMap); + if (x >= x2 - xrange && x <= x2 + xrange && y >= y2 - yrange && y <= y2 + yrange && z >= z2 - zrange && z <= z2 + zrange){ + edges.add(edge11); + } + } + } + if (!node1edges2.isEmpty()){ + for (Edge edge12 : node1edges2){ + int x = getX(edge12.getNode1(), locationMap); + int y = getY(edge12.getNode1(), locationMap); + int z = getZ(edge12.getNode1(), locationMap); + if (x >= x2 - xrange && x <= x2 + xrange && y >= y2 - yrange && y <= y2 + yrange && z >= z2 - zrange && z <= z2 + zrange){ + edges.add(edge12); + } + } + } + //System.out.println("edges is empty?"+edges.isEmpty()+" at range "+range); + //System.out.println(edges); } - public int getX() { - return x; - } + return range; + } + //***********()(*&(*%^#$%^&*^&%^%^%****** + //this is like findRange, but it returns the edges within the range in one step, without iterating chunksize + private List findEdges(Edge edge, int range){ + Set edges = new HashSet<>(); + NodeEqualityMode.setEqualityMode(NodeEqualityMode.Type.OBJECT); + //create separate range values for x y and z, scaled by xdist ydist zdist + int xrange = (int) Math.ceil(range/xDist); + int yrange = (int) Math.ceil(range/yDist); + int zrange = (int) Math.ceil(range/zDist); - public int getY() { - return y; + //initialize the edge sets + Set node1edges1 = new HashSet<>(); + Set node1edges2 = new HashSet<>(); + //list edges with either endpoint near node1 + for (int x = getX(edge.getNode1(), locationMap) - xrange; x <= getX(edge.getNode1(), locationMap) + xrange; x++) { + for (int y = getY(edge.getNode1(), locationMap) - yrange; y <= getY(edge.getNode1(), locationMap) + yrange; y++) { + for (int z = getZ(edge.getNode1(), locationMap) - zrange; z <= getZ(edge.getNode1(), locationMap) + zrange; z++) { + if (x < xLow || x > xHigh || y < yLow || y > yHigh || z < zLow || z > zHigh) continue; + //if (Coords1.get(new Integer[] {x,y,z}) == null) continue; + if (Coords1.get(Arrays.asList(x,y,z)) != null) node1edges1.addAll(Coords1.get(Arrays.asList(x,y,z))); + if (Coords2.get(Arrays.asList(x,y,z)) != null) node1edges2.addAll(Coords2.get(Arrays.asList(x,y,z))); + } + } + } + int x2 = getX(edge.getNode2(), locationMap); + int y2 = getY(edge.getNode2(), locationMap); + int z2 = getZ(edge.getNode2(), locationMap); + //if one or both of the above lists is nonempty, find edges where the other endpoint is near node2! + if (!node1edges1.isEmpty()){ + for (Edge edge11 : node1edges1){ + int x = getX(edge11.getNode2(), locationMap); + int y = getY(edge11.getNode2(), locationMap); + int z = getZ(edge11.getNode2(), locationMap); + if (x >= x2 - xrange && x <= x2 + xrange && y >= y2 - yrange && y <= y2 + yrange && z >= z2 - zrange && z <= z2 + zrange){ + edges.add(edge11); + } + } + } + if (!node1edges2.isEmpty()){ + for (Edge edge12 : node1edges2){ + int x = getX(edge12.getNode1(), locationMap); + int y = getY(edge12.getNode1(), locationMap); + int z = getZ(edge12.getNode1(), locationMap); + if (x >= x2 - xrange && x <= x2 + xrange && y >= y2 - yrange && y <= y2 + yrange && z >= z2 - zrange && z <= z2 + zrange){ + edges.add(edge12); + } + } } - public int getZ() { - return z; + return new ArrayList<>(edges); + } + //this is just the private method for adding entries to a map + private void add(Map, Set> Coords, Edge edge, List x) { + Set edges = Coords.get(x); + if (edges == null) { + edges = new HashSet<>(); + Coords.put(x, edges); } + Coords.get(x).add(edge); } - public class Edge { - private Point p1; - private Point p2; - - public Edge(Point p1, Point p2) { - this.p1 = p1; - this.p2 = p2; - } + // want to use regular point and edge classes, so replace the below with private methods + //this is where the loaded locationMap should be doing the work + private int getX(Node node, DataSet locationMap) { + //double output = locationMap.getDouble(0,locationMap.getColumn(node)); + return (int) locationMap.getDouble(0, locationMap.getColumn(node)); + } - public Point getP1() { - return p1; - } + private int getY(Node node, DataSet locationMap) { + //double output = locationMap.getDouble(0,locationMap.getColumn(node)); + return (int) locationMap.getDouble(1, locationMap.getColumn(node)); + } - public Point getP2() { - return p2; - } + private int getZ(Node node, DataSet locationMap) { + //double output = locationMap.getDouble(0,locationMap.getColumn(node)); + return (int) locationMap.getDouble(2, locationMap.getColumn(node)); } -} + +} \ No newline at end of file diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Vicinity2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Vicinity2.java deleted file mode 100644 index 78ec41133d..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/Vicinity2.java +++ /dev/null @@ -1,124 +0,0 @@ -package edu.cmu.tetrad.simulation; -import edu.cmu.tetrad.data.DataSet; -import edu.cmu.tetrad.graph.Edge; -import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.graph.NodeEqualityMode; - -import java.util.*; - -/** - * @author jdramsey - */ -public class Vicinity2 { - private int xLow; - private int xHigh; - private int yLow; - private int yHigh; - private int zLow; - private int zHigh; - //EK: I think range is the distance along each axis that a node can be to be "close enough" for Vicinity - private int range = 6; - private Map> xCoords = new HashMap<>(); - private Map> yCoords = new HashMap<>(); - private Map> zCoords = new HashMap<>(); - - - public Vicinity2(List edges, DataSet locationMap, int xLow, int xHigh, int yLow, int yHigh, int zLow, int zHigh) { - //EK: the xLow etc. ints are the bounds on the coordinates in the location space, I think - this.xLow = xLow; - this.xHigh = xHigh; - this.yLow = yLow; - this.yHigh = yHigh; - this.zLow = zLow; - this.zHigh = zHigh; - - NodeEqualityMode.setEqualityMode(NodeEqualityMode.Type.OBJECT); - - for (Edge edge : edges) { - add(xCoords, edge, getX(edge.getNode1(), locationMap)); - add(yCoords, edge, getY(edge.getNode1(), locationMap)); - add(zCoords, edge, getZ(edge.getNode1(), locationMap)); - - add(xCoords, edge, getX(edge.getNode2(), locationMap)); - add(yCoords, edge, getY(edge.getNode2(), locationMap)); - add(zCoords, edge, getZ(edge.getNode2(), locationMap)); - } - } - - public List getVicinity(Edge edge, DataSet locationMap) { - Set edges = new HashSet<>(); - NodeEqualityMode.setEqualityMode(NodeEqualityMode.Type.OBJECT); - //EK: I'm concerned that this won't actually remove many edges? - //EK: since they only need one endpoint to be sorta close on ANY dimension - //EK: this will carve out a thick 3D "+" shape and add every edge that touches it - //EK: wouldn't it be preferable to carve out a smaller shape, like a cube? - for (int x = getX(edge.getNode1(), locationMap) - range; x <= getX(edge.getNode1(), locationMap) + range; x++) { - if (x < xLow || x > xHigh) continue; - if (xCoords.get(x) == null) continue; - edges.addAll(xCoords.get(x)); - } - - for (int y = getY(edge.getNode1(), locationMap) - range; y <= getY(edge.getNode1(), locationMap) + range; y++) { - if (y < yLow || y > yHigh) continue; - if (yCoords.get(y) == null) continue; - edges.addAll(yCoords.get(y)); - } - - for (int z = getZ(edge.getNode1(), locationMap) - range; z <= getZ(edge.getNode1(), locationMap) + range; z++) { - if (z < zLow || z > zHigh) continue; - if (zCoords.get(z) == null) continue; - edges.addAll(zCoords.get(z)); - } - - for (int x = getX(edge.getNode2(), locationMap) - range; x <= getX(edge.getNode2(), locationMap) + range; x++) { - if (x < xLow || x > xHigh) continue; - if (xCoords.get(x) == null) continue; - edges.addAll(xCoords.get(x)); - } - - for (int y = getY(edge.getNode2(), locationMap) - range; y <= getY(edge.getNode2(), locationMap) + range; y++) { - if (y < yLow || y > yHigh) continue; - if (yCoords.get(y) == null) continue; - edges.addAll(yCoords.get(y)); - } - - for (int z = getZ(edge.getNode2(), locationMap) - range; z <= getZ(edge.getNode2(), locationMap) + range; z++) { - if (z < zLow || z > zHigh) continue; - if (zCoords.get(z) == null) continue; - edges.addAll(zCoords.get(z)); - } - - return new ArrayList<>(edges); - } - - private void add(Map> xCoords, Edge edge, int x) { - Set edges = xCoords.get(x); - - if (edges == null) { - edges = new HashSet<>(); - xCoords.put(x, edges); - } - - xCoords.get(x).add(edge); - } - - //=================Private methods============================== - // want to use regular point and edge classes, so replace the below with private methods - //this is where the loaded locationMap should be doing the work - - private int getX(Node node, DataSet locationMap) { - //double output = locationMap.getDouble(0,locationMap.getColumn(node)); - return (int) locationMap.getDouble(0, locationMap.getColumn(node)); - } - - private int getY(Node node, DataSet locationMap) { - //double output = locationMap.getDouble(0,locationMap.getColumn(node)); - return (int) locationMap.getDouble(1, locationMap.getColumn(node)); - } - - private int getZ(Node node, DataSet locationMap) { - //double output = locationMap.getDouble(0,locationMap.getColumn(node)); - return (int) locationMap.getDouble(2, locationMap.getColumn(node)); - } - -} \ No newline at end of file diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/ParamDescriptions.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/ParamDescriptions.java index 1867133c6a..6bc0a9480a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/ParamDescriptions.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/ParamDescriptions.java @@ -18,7 +18,7 @@ public ParamDescriptions() { put("numMeasures", new ParamDescription("Number of measured variables", 10, 1, Integer.MAX_VALUE)); put("numLatents", new ParamDescription("Number of latent variables", 0, 0, Integer.MAX_VALUE)); put("avgDegree", new ParamDescription("Average degree of graph", 2, 1, Integer.MAX_VALUE)); - put("maxDegree", new ParamDescription("The maximum degree of the output graph. Use -1 for unbounded.", -1, -1, Integer.MAX_VALUE)); + put("maxDegree", new ParamDescription("The maximum degree of the graph.", 100, -1, Integer.MAX_VALUE)); put("maxIndegree", new ParamDescription("Maximum indegree of graph", 100, 1, Integer.MAX_VALUE)); put("maxOutdegree", new ParamDescription("Maximum outdegree of graph", 100, 1, Integer.MAX_VALUE)); put("connected", new ParamDescription("Yes if graph should be connected", false)); @@ -27,7 +27,7 @@ public ParamDescriptions() { put("differentGraphs", new ParamDescription("Yes if a different graph should be used for each run", false)); put("alpha", new ParamDescription("Cutoff for p values (alpha)", 0.01, 0.0, 1.0)); put("penaltyDiscount", new ParamDescription("Penalty discount", 4.0, 0.0, Double.MAX_VALUE)); - put("fgsDepth", new ParamDescription("Maximum number of new colliders", 1, 1, Integer.MAX_VALUE)); + put("fgesDepth", new ParamDescription("Maximum number of new colliders", 1, 1, Integer.MAX_VALUE)); put("standardize", new ParamDescription("Yes if the data should be standardized", false)); put("measurementVariance", new ParamDescription("Additive measurement noise variance", 0.0, 0, Double.MAX_VALUE)); put("depth", new ParamDescription("Maximum size of conditioning set", -1, -1, Integer.MAX_VALUE)); @@ -88,8 +88,8 @@ public ParamDescriptions() { put("pixelDigitalization", new ParamDescription("Pixel digitalization", 0.025, 0.0, Double.MAX_VALUE)); put("includeDishAndChipColumns", new ParamDescription("Yes if Dish and Chip columns should be included in output", true)); - put("numRandomSelections", new ParamDescription("The number random selections of data sets that should be taken", 5)); - put("randomSelectionSize", new ParamDescription("The number of datasets that should be taken in each random sample", 5)); + put("numRandomSelections", new ParamDescription("The number random selections of data sets that should be taken", 1)); + put("randomSelectionSize", new ParamDescription("The number of datasets that should be taken in each random sample", 1)); put("maxit", new ParamDescription("MAXIT parameter (GLASSO)", 10000, 1, Integer.MAX_VALUE)); put("ia", new ParamDescription("IA parameter (GLASSO)", false)); @@ -114,12 +114,42 @@ public ParamDescriptions() { put("measuredMeasuredImpureAssociations", new ParamDescription("Number of Measured <-> Measured impure edges", 0)); // put("useRuleC", new ParamDescription("Yes if rule C for CCD should be used", false)); - put("applyR1", new ParamDescription("Yes if the orient away from arrow should be applied", false)); + put("applyR1", new ParamDescription("Yes if the orient away from arrow rule should be applied", true)); put("probCycle", new ParamDescription("The probability of adding a cycle to the graph", 1.0, 0.0, 1.0)); put("intervalBetweenShocks", new ParamDescription("Interval beween shocks (R. A. Fisher simulation model)", 10, 1, Integer.MAX_VALUE)); + put("intervalBetweenRecordings", new ParamDescription( + "Interval between data recordings for the linear Fisher model", + 10, 1, Integer.MAX_VALUE)); + + put("skipNumRecords", new ParamDescription("Number of records that should be skipped between recordings", + 0, 0, Integer.MAX_VALUE)); put("fisherEpsilon", new ParamDescription("Epsilon where |xi.t - xi.t-1| < epsilon, criterion for convergence", .001, Double.MIN_VALUE, Double.MAX_VALUE)); + + put("useMaxPOrientationHeuristic", new ParamDescription( + "Yes if the heuristic for orienting unshielded colliders for max P should be used", + true)); + put("maxPOrientationMaxPathLength", new ParamDescription( + "Maximum path length for the unshielded collider heuristic for max P", + 3, 0, Integer.MAX_VALUE)); + put("orientTowardDConnections", new ParamDescription( + "Yes if Richardson's step C (orient toward d-connection) should be used", + true)); + put("orientVisibleFeedbackLoops", new ParamDescription( + "Yes if visible feedback loops should be oriented", + true)); + put("doColliderOrientation", new ParamDescription( + "Yes if unshielded collider orientation should be done", + true)); + + put("completeRuleSetUsed", new ParamDescription( + "Yes if the complete FCI rule set should be used", + false)); + + put("maxDistinctValuesDiscrete", new ParamDescription( + "The maximum number of distinct values in a column for discrete variables", + 0, 0, Integer.MAX_VALUE)); } public static ParamDescriptions instance() { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/StatUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/StatUtils.java index 64ee861ca4..b58dc99894 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/StatUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/StatUtils.java @@ -89,12 +89,16 @@ public static double mean(long array[], int N) { */ public static double mean(double array[], int N) { double sum = 0.0; + int count = 0; for (int i = 0; i < N; i++) { - sum += array[i]; + if (!Double.isNaN(array[i])) { + sum += array[i]; + count++; + } } - return sum / N; + return sum / (double) count; } /** @@ -104,12 +108,16 @@ public static double mean(double array[], int N) { */ public static double mean(TetradVector data, int N) { double sum = 0.0; + int count = 0; for (int i = 0; i < N; i++) { - sum += data.get(i); + if (!Double.isNaN( data.get(i))) { + sum += data.get(i); + count++; + } } - return sum / N; + return sum / (double) count; } /** @@ -1755,13 +1763,7 @@ public static synchronized double partialCorrelation(TetradMatrix submatrix) { try { TetradMatrix inverse = submatrix.inverse(); - - double a = -1.0 * inverse.get(0, 1); - double v0 = inverse.get(0, 0); - double v1 = inverse.get(1, 1); - double b = Math.sqrt(v0 * v1); - - return a / b; + return -(1.0 * inverse.get(0, 1)) / Math.sqrt(inverse.get(0, 0) * inverse.get(1, 1)); } catch (Exception e) { e.printStackTrace(); return Double.NaN; @@ -1989,7 +1991,7 @@ public static double getChiSquareCutoff(double alpha, int df) { } // Calculates the log of a list of terms, where the argument consists of the logs of the terms. - public static double logOfSum(List logs) { + public static double logsum(List logs) { Collections.sort(logs, new Comparator() { @Override diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/TetradMatrix.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/TetradMatrix.java index c9ffa45698..0e4a25f999 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/TetradMatrix.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/TetradMatrix.java @@ -466,7 +466,7 @@ public TetradMatrix plus(TetradMatrix mb) { } public TetradMatrix scalarMult(double scalar) { - return new TetradMatrix(apacheData.scalarMultiply(scalar), rows(), columns()); + return new TetradMatrix(apacheData.copy().scalarMultiply(scalar), rows(), columns()); } public int rank() { @@ -533,6 +533,40 @@ public void assign(TetradMatrix matrix) { } } } + + public TetradVector sum(int direction) { + if (direction == 1) { + TetradVector sums = new TetradVector(columns()); + + for (int j = 0; j < columns(); j++) { + double sum = 0.0; + + for (int i = 0; i < rows(); i++) { + sum += apacheData.getEntry(i, j); + } + + sums.set(j, sum); + } + + return sums; + } else if (direction == 2) { + TetradVector sums = new TetradVector(rows()); + + for (int i = 0; i < rows(); i++) { + double sum = 0.0; + + for (int j = 0; j < columns(); j++) { + sum += apacheData.getEntry(i, j); + } + + sums.set(i, sum); + } + + return sums; + } else { + throw new IllegalArgumentException("Expecting 1 (sum columns) or 2 (sum rows)."); + } + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/TetradVector.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/TetradVector.java index 5ce0bd9fd2..21c7c89fb7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/TetradVector.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/TetradVector.java @@ -98,7 +98,8 @@ public TetradVector plus(TetradVector mb) { } public TetradVector scalarMult(double scalar) { - return new TetradVector(data.mapDivideToSelf(scalar)); + + return new TetradVector(data.mapMultiplyToSelf(scalar)); } public TetradMatrix diag() { diff --git a/tetrad-lib/src/main/java/edu/pitt/csb/mgm/MGM.java b/tetrad-lib/src/main/java/edu/pitt/csb/mgm/MGM.java index c99a6c2a57..7717fe0129 100644 --- a/tetrad-lib/src/main/java/edu/pitt/csb/mgm/MGM.java +++ b/tetrad-lib/src/main/java/edu/pitt/csb/mgm/MGM.java @@ -1352,7 +1352,7 @@ private static void runTests2(){ System.out.println(im); int samps = 1000; - DataSet ds = im.simulateDataAvoidInfinity(samps, false); + DataSet ds = im.simulateDataFisher(samps); ds = MixedUtils.makeMixedData(ds, nd); //System.out.println(ds); diff --git a/tetrad-lib/src/main/java/edu/pitt/csb/mgm/MixedUtils.java b/tetrad-lib/src/main/java/edu/pitt/csb/mgm/MixedUtils.java index 1558f23672..611b03d7f2 100644 --- a/tetrad-lib/src/main/java/edu/pitt/csb/mgm/MixedUtils.java +++ b/tetrad-lib/src/main/java/edu/pitt/csb/mgm/MixedUtils.java @@ -1058,7 +1058,7 @@ public static void main(String[] args){ System.out.println(im); int samps = 15; - DataSet ds = im.simulateDataAvoidInfinity(samps, false); + DataSet ds = im.simulateDataFisher(samps); System.out.println(ds); System.out.println("num cats " + ((DiscreteVariable) g.getNode("X4")).getNumCategories()); diff --git a/tetrad-lib/src/main/java/edu/pitt/csb/stability/SearchWrappers.java b/tetrad-lib/src/main/java/edu/pitt/csb/stability/SearchWrappers.java index 104d8f8ac0..b0ce8267e5 100644 --- a/tetrad-lib/src/main/java/edu/pitt/csb/stability/SearchWrappers.java +++ b/tetrad-lib/src/main/java/edu/pitt/csb/stability/SearchWrappers.java @@ -61,17 +61,17 @@ public Graph search(DataSet ds) { } } - public static class FgsWrapper extends DataGraphSearch{ - public FgsWrapper(double...params){ + public static class FgesWrapper extends DataGraphSearch{ + public FgesWrapper(double...params){ super(params); } - public FgsWrapper copy() {return new FgsWrapper(searchParams);} + public FgesWrapper copy() {return new FgesWrapper(searchParams);} public Graph search(DataSet ds){ SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(MixedUtils.makeContinuousData(ds))); score.setPenaltyDiscount(searchParams[0]); - Fgs fg = new Fgs(score); + Fges fg = new Fges(score); return fg.search(); } }