Search in sources:

Example 26 with DataModel

Use of edu.cmu.tetrad.data.DataModel in project tetrad by cmu-phil.

In class GeneralBootstrapSearch, method search:

/**
 * Runs {@code numBootstrap} searches, each on a freshly drawn bootstrap sample, and returns
 * {@code PAGs}.
 *
 * <p>When {@code data} is non-null, every run resamples that single data set (drawing
 * {@code data.getNumRows()} rows) and searches it with {@code algorithm}. Otherwise every data
 * set in {@code dataSets} is resampled and the resulting list is searched with
 * {@code multiDataSetAlgorithm}. Each task receives {@code this}; presumably the tasks add
 * their result graphs to {@code PAGs} (NOTE(review): confirm in GeneralBootstrapSearchRunnable).
 *
 * <p>Runs execute one after another unless {@code runParallel} is set. In that case every task
 * is submitted to {@code pool`}, the pool is shut down, and this method blocks until the pool
 * terminates. NOTE(review): because the pool is shut down here, a second parallel call on the
 * same instance will be rejected unless {@code pool} is recreated elsewhere — confirm.
 *
 * @return the {@code PAGs} list, which is cleared at the start of this call
 */
public List<Graph> search() {
    PAGs.clear();
    parameters.set("bootstrapSampleSize", 0);
    long start, stop;
    if (!this.runParallel) {
        // Running in the sequential form
        if (verbose) {
            out.println("Running Bootstraps in Sequential Mode, numBootstrap = " + numBootstrap);
        }
        for (int i1 = 0; i1 < this.numBootstrap; i1++) {
            // Timing deliberately includes drawing the bootstrap sample.
            start = System.currentTimeMillis();
            GeneralBootstrapSearchRunnable task = newBootstrapTask();
            task.run();
            stop = System.currentTimeMillis();
            if (verbose) {
                out.println("processing time of bootstrap : " + (stop - start) / 1000.0 + " sec");
            }
        }
    } else {
        // Running in the parallel multiThread form
        if (verbose) {
            out.println("Running Bootstraps in Parallel Mode, numBootstrap = " + numBootstrap);
        }
        for (int i1 = 0; i1 < this.numBootstrap; i1++) {
            pool.submit(newBootstrapTask());
        }
        pool.shutdown();
        waitForPoolTermination();
    }
    parameters.set("bootstrapping", true);
    return PAGs;
}

/**
 * Builds one bootstrap task on a fresh bootstrap sample, configured identically for the
 * sequential and parallel modes (initial graph, when present, and knowledge).
 */
private GeneralBootstrapSearchRunnable newBootstrapTask() {
    GeneralBootstrapSearchRunnable task;
    if (data != null) {
        DataSet dataSet = DataUtils.getBootstrapSample(data, data.getNumRows());
        task = new GeneralBootstrapSearchRunnable(dataSet, algorithm, parameters, this, verbose);
    } else {
        List<DataModel> dataModels = new ArrayList<>();
        for (DataSet sourceData : dataSets) {
            dataModels.add(DataUtils.getBootstrapSample(sourceData, sourceData.getNumRows()));
        }
        task = new GeneralBootstrapSearchRunnable(dataModels, multiDataSetAlgorithm, parameters, this, verbose);
    }
    // Previously only the sequential path applied the initial graph; both paths now do.
    if (initialGraph != null) {
        task.setInitialGraph(initialGraph);
    }
    task.setKnowledge(knowledge);
    return task;
}

/**
 * Blocks until {@code pool} has terminated, polling once per second.
 *
 * <p>An interrupt does not abandon the wait, because callers expect every bootstrap to have
 * finished. The interrupt is still recorded and re-asserted on the current thread before
 * returning, so it is not lost.
 */
private void waitForPoolTermination() {
    boolean interrupted = false;
    while (!pool.isTerminated()) {
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            interrupted = true;
        }
    }
    if (interrupted) {
        Thread.currentThread().interrupt();
    }
}
Also used : GeneralBootstrapSearchRunnable(edu.pitt.dbmi.algo.bootstrap.task.GeneralBootstrapSearchRunnable) DataSet(edu.cmu.tetrad.data.DataSet) DataModel(edu.cmu.tetrad.data.DataModel) ArrayList(java.util.ArrayList)

Example 27 with DataModel

Use of edu.cmu.tetrad.data.DataModel in project tetrad by cmu-phil.

In class GeneralAlgorithmRunner, method execute:

// ============================PUBLIC METHODS==========================//
/**
 * Runs the selected algorithm on the parent data (or, when there is no data, on the source
 * graph) and stores the resulting graphs in {@code this.graphList}.
 *
 * <ul>
 *   <li>No data models: a source graph is required, and {@code search(null, parameters)} is
 *       called once.</li>
 *   <li>{@link MultiDataSetAlgorithm}: repeated {@code numRuns} times, each time on a random
 *       subset of {@code randomSelectionSize} data sets.</li>
 *   <li>{@link ClusterAlgorithm}: repeated {@code numRuns} times over every covariance matrix
 *       and continuous data set.</li>
 *   <li>Otherwise: one search per data model, provided the model's data type is compatible with
 *       the algorithm's declared {@link DataType}. Knowledge attached to a data model (when it
 *       mentions any variables) replaces {@code this.knowledge} before that search.</li>
 * </ul>
 *
 * <p>Afterwards the graphs are arranged by knowledge tiers when some variable is assigned to a
 * tier, and laid out in a circle otherwise.
 *
 * @throws IllegalArgumentException if there is neither data nor a source graph; if a
 *     multi-data-set algorithm receives a non-tabular data model or fewer data sets than
 *     {@code randomSelectionSize}; if a cluster algorithm receives a non-continuous data set;
 *     or if a data model's type does not match the algorithm
 */
@Override
public void execute() {
    List<Graph> graphList = new ArrayList<>();

    if (getDataModelList().isEmpty()) {
        if (getSourceGraph() == null) {
            throw new IllegalArgumentException("The parent boxes did not include any datasets or graphs. Try opening\n" + "the editors for those boxes and loading or simulating them.");
        }
        Algorithm algo = getAlgorithm();
        if (algo instanceof HasKnowledge) {
            ((HasKnowledge) algo).setKnowledge(getKnowledge());
        }
        graphList.add(algo.search(null, parameters));
    } else if (getAlgorithm() instanceof MultiDataSetAlgorithm) {
        for (int k = 0; k < parameters.getInt("numRuns"); k++) {
            // Fail with a clear message instead of a ClassCastException on e.g. covariance matrices.
            List<DataSet> dataSets = new ArrayList<>();
            for (DataModel dataModel : getDataModelList()) {
                if (!(dataModel instanceof DataSet)) {
                    throw new IllegalArgumentException("Sorry, a multiple data set algorithm requires every data model to be a tabular data set.");
                }
                dataSets.add((DataSet) dataModel);
            }
            int randomSelectionSize = parameters.getInt("randomSelectionSize");
            if (dataSets.size() < randomSelectionSize) {
                throw new IllegalArgumentException("Sorry, the 'random selection size' is greater than " + "the number of data sets.");
            }
            // Random subset: shuffle, then take the first randomSelectionSize data sets.
            Collections.shuffle(dataSets);
            List<DataModel> sub = new ArrayList<>();
            for (int j = 0; j < randomSelectionSize; j++) {
                sub.add(dataSets.get(j));
            }
            Algorithm algo = getAlgorithm();
            if (algo instanceof HasKnowledge) {
                ((HasKnowledge) algo).setKnowledge(getKnowledge());
            }
            graphList.add(((MultiDataSetAlgorithm) algo).search(sub, parameters));
        }
    } else if (getAlgorithm() instanceof ClusterAlgorithm) {
        // Search with the same instance whose type was just checked, not the raw field.
        Algorithm clusterAlgorithm = getAlgorithm();
        for (int k = 0; k < parameters.getInt("numRuns"); k++) {
            for (DataModel dataModel : getDataModelList()) {
                if (dataModel instanceof ICovarianceMatrix) {
                    ICovarianceMatrix covariances = (ICovarianceMatrix) dataModel;
                    graphList.add(clusterAlgorithm.search(covariances, parameters));
                } else if (dataModel instanceof DataSet) {
                    DataSet dataSet = (DataSet) dataModel;
                    if (!dataSet.isContinuous()) {
                        throw new IllegalArgumentException("Sorry, you need a continuous dataset for a cluster algorithm.");
                    }
                    graphList.add(clusterAlgorithm.search(dataSet, parameters));
                }
                // NOTE(review): data models of any other type are silently skipped.
            }
        }
    } else {
        for (DataModel dataModel : getDataModelList()) {
            // Knowledge stored with the data overrides the box's knowledge when non-empty.
            IKnowledge knowledgeFromData = dataModel.getKnowledge();
            if (knowledgeFromData != null && !knowledgeFromData.getVariables().isEmpty()) {
                this.knowledge = knowledgeFromData;
            }
            Algorithm algo = getAlgorithm();
            if (algo instanceof HasKnowledge) {
                ((HasKnowledge) algo).setKnowledge(getKnowledge());
            }
            DataType algDataType = algo.getDataType();
            boolean compatible =
                    (dataModel.isContinuous() && (algDataType == DataType.Continuous || algDataType == DataType.Mixed))
                    || (dataModel.isDiscrete() && (algDataType == DataType.Discrete || algDataType == DataType.Mixed))
                    || (dataModel.isMixed() && algDataType == DataType.Mixed);
            if (!compatible) {
                throw new IllegalArgumentException("The type of data changed; try opening up the search editor and " + "running the algorithm there.");
            }
            graphList.add(algo.search(dataModel, parameters));
        }
    }

    // Tiered layout only when at least one variable is assigned to a knowledge tier.
    if (getKnowledge().getVariablesNotInTiers().size() < getKnowledge().getVariables().size()) {
        for (Graph graph : graphList) {
            SearchGraphUtils.arrangeByKnowledgeTiers(graph, getKnowledge());
        }
    } else {
        for (Graph graph : graphList) {
            GraphUtils.circleLayout(graph, 225, 200, 150);
        }
    }
    this.graphList = graphList;
}
Also used : GraphUtils(edu.cmu.tetrad.graph.GraphUtils) ObjectInputStream(java.io.ObjectInputStream) Parameters(edu.cmu.tetrad.util.Parameters) HashMap(java.util.HashMap) Triple(edu.cmu.tetrad.graph.Triple) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) DataType(edu.cmu.tetrad.data.DataType) KnowledgeBoxInput(edu.cmu.tetrad.data.KnowledgeBoxInput) Map(java.util.Map) ClusterAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) Fges(edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges) BdeuScore(edu.cmu.tetrad.algcomparison.score.BdeuScore) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) IKnowledge(edu.cmu.tetrad.data.IKnowledge) Graph(edu.cmu.tetrad.graph.Graph) IOException(java.io.IOException) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) Collectors(java.util.stream.Collectors) DataModel(edu.cmu.tetrad.data.DataModel) DataModelList(edu.cmu.tetrad.data.DataModelList) List(java.util.List) ParamsResettable(edu.cmu.tetrad.session.ParamsResettable) DataSet(edu.cmu.tetrad.data.DataSet) ImpliedOrientation(edu.cmu.tetrad.search.ImpliedOrientation) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) SessionModel(edu.cmu.tetrad.session.SessionModel) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) SearchGraphUtils(edu.cmu.tetrad.search.SearchGraphUtils) Knowledge2(edu.cmu.tetrad.data.Knowledge2) Unmarshallable(edu.cmu.tetrad.util.Unmarshallable) Collections(java.util.Collections) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) ArrayList(java.util.ArrayList) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) ClusterAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) 
MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) IKnowledge(edu.cmu.tetrad.data.IKnowledge) ClusterAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm) Graph(edu.cmu.tetrad.graph.Graph) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) DataModel(edu.cmu.tetrad.data.DataModel) DataType(edu.cmu.tetrad.data.DataType)

Example 28 with DataModel

Use of edu.cmu.tetrad.data.DataModel in project tetrad by cmu-phil.

In class SubsetDiscreteVariablesAction, method actionPerformed:

/**
 * Removes, in place, every {@link DiscreteVariable} column from the data set currently selected
 * in the data editor, then resets the editor to show that data set in a single tab.
 *
 * <p>NOTE(review): the class name suggests this action should keep only the discrete
 * variables, but the loop removes them and keeps the rest. Confirm the intended predicate.
 *
 * @param e the triggering action event (not used)
 */
public void actionPerformed(ActionEvent e) {
    DataModel selectedDataModel = getDataEditor().getSelectedDataModel();
    if (!(selectedDataModel instanceof DataSet)) {
        JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), "Requires a tabular data set.");
        return;
    }
    DataSet dataSet = (DataSet) selectedDataModel;
    // Walk from the last valid index (numColumns - 1) down to 0 so removing a column never
    // shifts the index of a column that has not been visited yet.
    for (int i = dataSet.getNumColumns() - 1; i >= 0; i--) {
        if (dataSet.getVariable(i) instanceof DiscreteVariable) {
            dataSet.removeColumn(i);
        }
    }
    DataModelList list = new DataModelList();
    list.add(dataSet);
    getDataEditor().reset(list);
    getDataEditor().selectFirstTab();
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) DataSet(edu.cmu.tetrad.data.DataSet) DataModelList(edu.cmu.tetrad.data.DataModelList) DataModel(edu.cmu.tetrad.data.DataModel)

Example 29 with DataModel

Use of edu.cmu.tetrad.data.DataModel in project tetrad by cmu-phil.

In class ShiftDataParamsEditor, method setup:

/**
 * Builds the panel, which has two tabs:
 * <ul>
 *   <li>a "Shift" tab for entering a shift per variable;</li>
 *   <li>a "Search" tab with the shift-search controls and an output area.</li>
 * </ul>
 *
 * <p>The data is taken from the last {@link DataWrapper} found among {@code parentModels}.
 * Every model in it must be a {@link DataSet}. Otherwise a message dialog is shown and nothing
 * is added to the panel. The spinner values are persisted in user {@link Preferences}. The
 * search direction is stored in {@code params} under {@code "forwardSearch"}.
 *
 * @throws NullPointerException if no parent model supplies a data model list
 */
public void setup() {
    DataModelList dataModelList = null;
    for (Object parentModel : parentModels) {
        if (parentModel instanceof DataWrapper) {
            dataModelList = ((DataWrapper) parentModel).getDataModelList();
        }
    }
    if (dataModelList == null) {
        throw new NullPointerException("Null data model list.");
    }

    // Validate and copy in a single pass; nothing is built unless every model is a data set.
    final List<DataModel> dataSets = new ArrayList<>();
    for (DataModel model : dataModelList) {
        if (!(model instanceof DataSet)) {
            JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), "For the shift search, all of the data in the data box must be in the form of data sets.");
            return;
        }
        dataSets.add(model);
    }

    SpinnerModel maxVarsModel = new SpinnerNumberModel(Preferences.userRoot().getInt("shiftSearchMaxNumShifts", 3), 1, 50, 1);
    JSpinner maxVarsSpinner = new JSpinner(maxVarsModel);
    maxVarsSpinner.setMaximumSize(maxVarsSpinner.getPreferredSize());
    maxVarsSpinner.addChangeListener(new ChangeListener() {

        public void stateChanged(ChangeEvent e) {
            JSpinner spinner = (JSpinner) e.getSource();
            SpinnerNumberModel model = (SpinnerNumberModel) spinner.getModel();
            int value = (Integer) model.getValue();
            Preferences.userRoot().putInt("shiftSearchMaxNumShifts", value);
        }
    });

    SpinnerModel maxShiftModel = new SpinnerNumberModel(Preferences.userRoot().getInt("shiftSearchMaxShift", 2), 1, 50, 1);
    JSpinner maxShiftSpinner = new JSpinner(maxShiftModel);
    maxShiftSpinner.setMaximumSize(maxShiftSpinner.getPreferredSize());
    maxShiftSpinner.addChangeListener(new ChangeListener() {

        public void stateChanged(ChangeEvent e) {
            JSpinner spinner = (JSpinner) e.getSource();
            SpinnerNumberModel model = (SpinnerNumberModel) spinner.getModel();
            int value = (Integer) model.getValue();
            Preferences.userRoot().putInt("shiftSearchMaxShift", value);
        }
    });

    JButton searchButton = new JButton("Search");
    final JButton stopButton = new JButton("Stop");
    final JTextArea textArea = new JTextArea();
    JScrollPane textScroll = new JScrollPane(textArea);
    textScroll.setPreferredSize(new Dimension(500, 200));

    searchButton.addActionListener(new ActionListener() {

        public void actionPerformed(ActionEvent actionEvent) {
            // The search runs on its own thread so the UI stays responsive.
            // NOTE(review): textArea is updated from this non-EDT thread; Swing expects UI
            // updates on the event dispatch thread.
            final Thread thread = new Thread() {

                public void run() {
                    textArea.setText("");
                    doShiftSearch(dataSets, textArea);
                }
            };
            thread.start();
        }
    });

    stopButton.addActionListener(new ActionListener() {

        public void actionPerformed(ActionEvent actionEvent) {
            if (search != null) {
                search.stop();
                TaskManager.getInstance().setCanceled(true);
            }
        }
    });

    JComboBox<String> directionBox = new JComboBox<>(new String[] { "forward", "backward" });
    directionBox.setSelectedItem(params.getBoolean("forwardSearch", true) ? "forward" : "backward");
    directionBox.setMaximumSize(directionBox.getPreferredSize());
    directionBox.addActionListener(new ActionListener() {

        public void actionPerformed(ActionEvent actionEvent) {
            JComboBox<?> source = (JComboBox<?>) actionEvent.getSource();
            String selected = (String) source.getSelectedItem();
            params.set("forwardSearch", "forward".equals(selected));
        }
    });

    // "Search" tab contents.
    Box b1 = Box.createVerticalBox();
    Box b2 = Box.createHorizontalBox();
    b2.add(new JLabel("Maximum number of variables in shift set is: "));
    b2.add(maxVarsSpinner);
    b2.add(Box.createHorizontalGlue());
    b1.add(b2);
    Box b3 = Box.createHorizontalBox();
    b3.add(new JLabel("Maximum "));
    b3.add(directionBox);
    b3.add(new JLabel(" shift: "));
    b3.add(maxShiftSpinner);
    b3.add(Box.createHorizontalGlue());
    b1.add(b3);
    Box b4 = Box.createHorizontalBox();
    b4.add(new JLabel("Output:"));
    b4.add(Box.createHorizontalGlue());
    b1.add(b4);
    Box b5 = Box.createHorizontalBox();
    b5.add(textScroll);
    b1.add(b5);
    Box b6 = Box.createHorizontalBox();
    b6.add(searchButton);
    b6.add(stopButton);
    b1.add(b6);

    // "Shift" tab contents.
    final Box a1 = Box.createVerticalBox();
    Box a2 = Box.createHorizontalBox();
    a2.add(new JLabel("Specify the shift (positive or negative) for each variable:"));
    a2.add(Box.createHorizontalGlue());
    a1.add(a2);
    a1.add(Box.createVerticalStrut(20));
    setUpA1(dataSets, a1);

    JTabbedPane tabbedPane = new JTabbedPane();
    tabbedPane.addTab("Shift", new JScrollPane(a1));
    tabbedPane.addTab("Search", new JScrollPane(b1));
    add(tabbedPane, BorderLayout.CENTER);
    tabbedPane.addChangeListener(new ChangeListener() {

        public void stateChanged(ChangeEvent changeEvent) {
            // Rebuild the shift tab whenever the selected tab changes. Components were removed
            // and re-added, so the layout must be revalidated and the box repainted.
            a1.removeAll();
            setUpA1(dataSets, a1);
            a1.revalidate();
            a1.repaint();
        }
    });
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) ActionEvent(java.awt.event.ActionEvent) ArrayList(java.util.ArrayList) DataWrapper(edu.cmu.tetradapp.model.DataWrapper) DataModelList(edu.cmu.tetrad.data.DataModelList) ChangeListener(javax.swing.event.ChangeListener) ChangeEvent(javax.swing.event.ChangeEvent) ActionListener(java.awt.event.ActionListener) DataModel(edu.cmu.tetrad.data.DataModel)

Example 30 with DataModel

Use of edu.cmu.tetrad.data.DataModel in project tetrad by cmu-phil.

In class SplitCasesAction, method actionPerformed:

/**
 * Validates the data set currently selected in the data editor and collects its selected
 * variables (or all variables, when none are selected), in preparation for splitting the data
 * set by cases.
 *
 * <p>A message dialog is shown, and the action stops, when the selection is not a tabular data
 * set or when the data set has no rows.
 *
 * <p>TODO(review): the dialog that actually performs the split is commented out below, so
 * beyond the validation messages this action currently has no visible effect.
 *
 * @param e the triggering action event (not used)
 */
public void actionPerformed(ActionEvent e) {
    DataModel selectedDataModel = getDataEditor().getSelectedDataModel();
    if (!(selectedDataModel instanceof DataSet)) {
        JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), "Requires a tabular data set.");
        // Without this return, the cast below would fail on a non-DataSet (or null) selection.
        return;
    }
    DataSet dataSet = (DataSet) selectedDataModel;
    if (dataSet.getNumRows() == 0) {
        JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), "Data set is empty.");
        return;
    }

    List<Node> selectedVariables = new LinkedList<>();
    int numColumns = dataSet.getNumColumns();
    for (int i = 0; i < numColumns; i++) {
        Node variable = dataSet.getVariable(i);
        if (dataSet.isSelected(variable)) {
            selectedVariables.add(variable);
        }
    }
    if (selectedVariables.isEmpty()) {
        selectedVariables.addAll(dataSet.getVariables());
    }

    // Disabled split implementation, kept for reference:
    //
    // ParamsEditor editor = new ParamsEditor(dataSet, 3);
    //
    // int ret = JOptionPane.showOptionDialog(JOptionUtils.centeringComp(),
    // editor, "Split Data by Cases", JOptionPane.OK_CANCEL_OPTION,
    // JOptionPane.PLAIN_MESSAGE, null, null, null);
    //
    // if (ret == JOptionPane.CANCEL_OPTION) {
    // return;
    // }
    //
    // DataModelList list = editor.getSplits();
    // getDataEditor().reset(list);
    // getDataEditor().selectLastTab();
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) DataModel(edu.cmu.tetrad.data.DataModel) Node(edu.cmu.tetrad.graph.Node) LinkedList(java.util.LinkedList)

Aggregations

DataModel (edu.cmu.tetrad.data.DataModel)39 DataSet (edu.cmu.tetrad.data.DataSet)22 ArrayList (java.util.ArrayList)15 DataWrapper (edu.cmu.tetradapp.model.DataWrapper)13 Graph (edu.cmu.tetrad.graph.Graph)9 ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix)8 Parameters (edu.cmu.tetrad.util.Parameters)8 DataModelList (edu.cmu.tetrad.data.DataModelList)7 Node (edu.cmu.tetrad.graph.Node)7 ActionEvent (java.awt.event.ActionEvent)7 ActionListener (java.awt.event.ActionListener)7 List (java.util.List)5 LayoutMenu (edu.cmu.tetradapp.workbench.LayoutMenu)4 DoubleTextField (edu.cmu.tetradapp.util.DoubleTextField)3 WatchedProcess (edu.cmu.tetradapp.util.WatchedProcess)3 GraphWorkbench (edu.cmu.tetradapp.workbench.GraphWorkbench)3 IOException (java.io.IOException)3 Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm)2 MultiDataSetAlgorithm (edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm)2 BdeuScore (edu.cmu.tetrad.algcomparison.score.BdeuScore)2