Search in sources :

Example 1 with DataType

use of edu.cmu.tetrad.data.DataType in project tetrad by cmu-phil.

the class GeneralAlgorithmEditor method initComponents.

/**
 * Configures the description area, registers the listeners on the option
 * check boxes, the algorithm list, the forward button and the two combo
 * boxes, and finally installs the three cards into the main panel.
 */
private void initComponents() {
    // Read-only description area that wraps on word boundaries.
    algoDescTextArea.setEditable(false);
    algoDescTextArea.setLineWrap(true);
    algoDescTextArea.setWrapStyleWord(true);

    populateAlgoTypeOptions(algoTypeOpts);

    // Check boxes only trigger a refresh of the dependent lists.
    knowledgeChkBox.addActionListener(event -> refreshAlgorithmList());
    linearVarChkBox.addActionListener(event -> refreshTestAndScoreList());
    gaussianVarChkBox.addActionListener(event -> refreshTestAndScoreList());

    algorithmList.addListSelectionListener(event -> {
        // Ignore intermediate selection events and an empty selection.
        if (event.getValueIsAdjusting() || algorithmList.isSelectionEmpty()) {
            return;
        }
        setAlgorithmDescription();
        refreshTestAndScoreList();
        validateAlgorithmOption();
    });

    paramSetFwdBtn.addActionListener(event -> {
        AlgorithmModel selectedAlgo = algorithmList.getSelectedValue();
        IndependenceTestModel selectedTest = indTestComboBox.getItemAt(indTestComboBox.getSelectedIndex());
        ScoreModel selectedScore = scoreComboBox.getItemAt(scoreComboBox.getSelectedIndex());
        if (!isValid(selectedAlgo, selectedTest, selectedScore)) {
            return;
        }
        setParameterPanel(selectedAlgo, selectedTest, selectedScore);
        changeCard(PARAMETER_CARD);
    });

    // Remember the chosen test per algorithm and data type, unless the combo
    // box is currently being repopulated.
    indTestComboBox.addActionListener(event -> {
        if (updatingTestModels || indTestComboBox.getSelectedIndex() < 0) {
            return;
        }
        AlgorithmModel selectedAlgo = algorithmList.getSelectedValue();
        Map<DataType, IndependenceTestModel> testsByType =
                defaultIndTestModels.computeIfAbsent(selectedAlgo, key -> new EnumMap<>(DataType.class));
        testsByType.put(dataType, indTestComboBox.getItemAt(indTestComboBox.getSelectedIndex()));
    });

    // Same bookkeeping for the score combo box.
    scoreComboBox.addActionListener(event -> {
        if (updatingScoreModels || scoreComboBox.getSelectedIndex() < 0) {
            return;
        }
        AlgorithmModel selectedAlgo = algorithmList.getSelectedValue();
        Map<DataType, ScoreModel> scoresByType =
                defaultScoreModels.computeIfAbsent(selectedAlgo, key -> new EnumMap<>(DataType.class));
        scoresByType.put(dataType, scoreComboBox.getItemAt(scoreComboBox.getSelectedIndex()));
    });

    // Card container fills the whole editor.
    mainPanel.add(new AlgorithmCard(), ALGORITHM_CARD);
    mainPanel.add(new ParameterCard(), PARAMETER_CARD);
    mainPanel.add(new GraphCard(), GRAPH_CARD);
    mainPanel.setPreferredSize(new Dimension(940, 640));

    setLayout(new BorderLayout());
    add(mainPanel, BorderLayout.CENTER);
}
Also used : IndependenceTestModel(edu.cmu.tetradapp.ui.model.IndependenceTestModel) ActionEvent(java.awt.event.ActionEvent) Dimension(java.awt.Dimension) ScoreModel(edu.cmu.tetradapp.ui.model.ScoreModel) BorderLayout(java.awt.BorderLayout) DataType(edu.cmu.tetrad.data.DataType) AlgorithmModel(edu.cmu.tetradapp.ui.model.AlgorithmModel)

Example 2 with DataType

use of edu.cmu.tetrad.data.DataType in project tetrad by cmu-phil.

the class GeneralAlgorithmEditor method refreshScoreList.

/**
 * Repopulates the score combo box for the currently selected algorithm and
 * then selects a score: the one remembered for this algorithm and data type,
 * otherwise the default model for the data type, otherwise the first item.
 * The combo box is disabled when it ends up empty.
 */
private void refreshScoreList() {
    updatingScoreModels = true;
    scoreComboBox.removeAllItems();
    AlgorithmModel algoModel = algorithmList.getSelectedValue();
    if (algoModel != null && algoModel.isRequiredScore()) {
        boolean linear = linearVarChkBox.isSelected();
        boolean gaussian = gaussianVarChkBox.isSelected();
        // With neither box checked every model is accepted; otherwise the
        // presence of each annotation must match its check box exactly.
        boolean acceptAll = !linear && !gaussian;
        List<ScoreModel> candidates = new LinkedList<>();
        for (ScoreModel model : ScoreModels.getInstance().getModels(dataType)) {
            if (acceptAll
                    || (model.getScore().getClazz().isAnnotationPresent(Linear.class) == linear
                    && model.getScore().getClazz().isAnnotationPresent(Gaussian.class) == gaussian)) {
                candidates.add(model);
            }
        }

        // TsImages only offers the SEM BIC score for continuous data
        // or BDeu score for discrete data.
        if (TsImages.class.equals(algoModel.getAlgorithm().getClazz())) {
            switch (dataType) {
                case Continuous:
                    for (ScoreModel model : candidates) {
                        if (model.getScore().getClazz().equals(SemBicScore.class)) {
                            scoreComboBox.addItem(model);
                        }
                    }
                    break;
                case Discrete:
                    for (ScoreModel model : candidates) {
                        if (model.getScore().getClazz().equals(BdeuScore.class)) {
                            scoreComboBox.addItem(model);
                        }
                    }
                    break;
                default:
                    // Other data types get no score for TsImages.
                    break;
            }
        } else {
            for (ScoreModel model : candidates) {
                scoreComboBox.addItem(model);
            }
        }
    }
    updatingScoreModels = false;

    if (scoreComboBox.getItemCount() == 0) {
        scoreComboBox.setEnabled(false);
        return;
    }

    scoreComboBox.setEnabled(true);
    Map<DataType, ScoreModel> scoresByType =
            defaultScoreModels.computeIfAbsent(algoModel, key -> new EnumMap<>(DataType.class));
    ScoreModel preferred = scoresByType.get(dataType);
    if (preferred == null) {
        preferred = ScoreModels.getInstance().getDefaultModel(dataType);
    }
    if (preferred == null) {
        preferred = scoreComboBox.getItemAt(0);
    }
    scoreComboBox.setSelectedItem(preferred);
}
Also used : Enumeration(java.util.Enumeration) JDialog(javax.swing.JDialog) TetradDesktop(edu.cmu.tetradapp.app.TetradDesktop) LoggerFactory(org.slf4j.LoggerFactory) Parameters(edu.cmu.tetrad.util.Parameters) Node(edu.cmu.tetrad.graph.Node) PaddingPanel(edu.cmu.tetradapp.ui.PaddingPanel) StringUtils(org.apache.commons.lang3.StringUtils) Linear(edu.cmu.tetrad.annotation.Linear) AlgorithmModel(edu.cmu.tetradapp.ui.model.AlgorithmModel) Map(java.util.Map) ScoreModel(edu.cmu.tetradapp.ui.model.ScoreModel) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) HpcJobActivityAction(edu.cmu.tetradapp.app.hpc.action.HpcJobActivityAction) BorderLayout(java.awt.BorderLayout) ButtonModel(javax.swing.ButtonModel) JComboBox(javax.swing.JComboBox) GraphSelectionWrapper(edu.cmu.tetradapp.model.GraphSelectionWrapper) Method(java.lang.reflect.Method) Path(java.nio.file.Path) GeneralAlgorithmRunner(edu.cmu.tetradapp.model.GeneralAlgorithmRunner) BdeuScore(edu.cmu.tetrad.algcomparison.score.BdeuScore) HpcJobManager(edu.cmu.tetradapp.app.hpc.manager.HpcJobManager) Frame(java.awt.Frame) HpcAccountManager(edu.cmu.tetradapp.app.hpc.manager.HpcAccountManager) WatchedProcess(edu.cmu.tetradapp.util.WatchedProcess) EnumMap(java.util.EnumMap) AlgType(edu.cmu.tetrad.annotation.AlgType) HpcParameter(edu.pitt.dbmi.tetrad.db.entity.HpcParameter) Set(java.util.Set) BorderFactory(javax.swing.BorderFactory) ComboBoxModel(javax.swing.ComboBoxModel) Gaussian(edu.cmu.tetrad.annotation.Gaussian) Nonexecutable(edu.cmu.tetrad.annotation.Nonexecutable) JsonWebToken(edu.pitt.dbmi.ccd.rest.client.dto.user.JsonWebToken) HpcJobInfo(edu.pitt.dbmi.tetrad.db.entity.HpcJobInfo) JRadioButton(javax.swing.JRadioButton) DataModel(edu.cmu.tetrad.data.DataModel) InvocationTargetException(java.lang.reflect.InvocationTargetException) Box(javax.swing.Box) Dimension(java.awt.Dimension) List(java.util.List) DataSet(edu.cmu.tetrad.data.DataSet) JCheckBox(javax.swing.JCheckBox) Optional(java.util.Optional) 
MessageDigestHash(edu.pitt.dbmi.ccd.commons.file.MessageDigestHash) JPanel(javax.swing.JPanel) Toolkit(java.awt.Toolkit) CardLayout(java.awt.CardLayout) JOptionUtils(edu.cmu.tetrad.util.JOptionUtils) HashMap(java.util.HashMap) AlgorithmModels(edu.cmu.tetradapp.ui.model.AlgorithmModels) FinalizingEditor(edu.cmu.tetradapp.util.FinalizingEditor) AlgorithmParamRequest(edu.pitt.dbmi.tetrad.db.entity.AlgorithmParamRequest) SwingConstants(javax.swing.SwingConstants) ArrayList(java.util.ArrayList) SingleGraphAlg(edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.SingleGraphAlg) HashSet(java.util.HashSet) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore) DataType(edu.cmu.tetrad.data.DataType) JsonUtils(edu.cmu.tetrad.util.JsonUtils) IndependenceTestModel(edu.cmu.tetradapp.ui.model.IndependenceTestModel) IndependenceTestModels(edu.cmu.tetradapp.ui.model.IndependenceTestModels) TsImages(edu.cmu.tetrad.algcomparison.algorithm.oracle.pag.TsImages) LinkedList(java.util.LinkedList) AlgorithmFactory(edu.cmu.tetrad.algcomparison.algorithm.AlgorithmFactory) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) JButton(javax.swing.JButton) Logger(org.slf4j.Logger) AlgorithmParameter(edu.pitt.dbmi.tetrad.db.entity.AlgorithmParameter) Files(java.nio.file.Files) ButtonGroup(javax.swing.ButtonGroup) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Window(java.awt.Window) JList(javax.swing.JList) DesktopController(edu.cmu.tetradapp.util.DesktopController) Graph(edu.cmu.tetrad.graph.Graph) HpcAccountUtils(edu.cmu.tetradapp.app.hpc.util.HpcAccountUtils) JvmOptions(edu.pitt.dbmi.tetrad.db.entity.JvmOptions) IOException(java.io.IOException) JOptionPane(javax.swing.JOptionPane) ActionEvent(java.awt.event.ActionEvent) JScrollPane(javax.swing.JScrollPane) LayoutStyle(javax.swing.LayoutStyle) DataModelList(edu.cmu.tetrad.data.DataModelList) ScoreModels(edu.cmu.tetradapp.ui.model.ScoreModels) 
DefaultListModel(javax.swing.DefaultListModel) JLabel(javax.swing.JLabel) GroupLayout(javax.swing.GroupLayout) HpcAccount(edu.pitt.dbmi.tetrad.db.entity.HpcAccount) JTextArea(javax.swing.JTextArea) Knowledge2(edu.cmu.tetrad.data.Knowledge2) DataType(edu.cmu.tetrad.data.DataType) AlgorithmModel(edu.cmu.tetradapp.ui.model.AlgorithmModel) ScoreModel(edu.cmu.tetradapp.ui.model.ScoreModel) LinkedList(java.util.LinkedList) Linear(edu.cmu.tetrad.annotation.Linear)

Example 3 with DataType

use of edu.cmu.tetrad.data.DataType in project tetrad by cmu-phil.

the class IndependenceTestModels method initDefaultModelMap.

/**
 * Fills {@code defaultModelMap} with one default independence-test model per
 * data type that has at least one model. The first model of the type is the
 * fallback; it is replaced by the model whose test class name equals the
 * configured property value, when such a property, value and model exist.
 */
private void initDefaultModelMap() {
    for (DataType dataType : DataType.values()) {
        List<IndependenceTestModel> candidates = getModels(dataType);
        if (candidates.isEmpty()) {
            continue;
        }

        String property = getProperty(dataType);
        String configuredClass = (property == null)
                ? null
                : TetradProperties.getInstance().getValue(property);

        IndependenceTestModel chosen = candidates.get(0);
        if (configuredClass != null) {
            chosen = candidates.stream()
                    .filter(model -> model.getIndependenceTest().getClazz().getName().equals(configuredClass))
                    .findFirst()
                    .orElse(chosen);
        }
        defaultModelMap.put(dataType, chosen);
    }
}
Also used : TestOfIndependenceAnnotations(edu.cmu.tetrad.annotation.TestOfIndependenceAnnotations) List(java.util.List) DataType(edu.cmu.tetrad.data.DataType) Stream(java.util.stream.Stream) EnumMap(java.util.EnumMap) Map(java.util.Map) TetradProperties(edu.cmu.tetrad.util.TetradProperties) Optional(java.util.Optional) LinkedList(java.util.LinkedList) Collections(java.util.Collections) Collectors(java.util.stream.Collectors) DataType(edu.cmu.tetrad.data.DataType)

Example 4 with DataType

use of edu.cmu.tetrad.data.DataType in project tetrad by cmu-phil.

the class GeneralAlgorithmRunner method execute.

// ============================PUBLIC METHODS==========================//
/**
 * Runs the configured algorithm and stores the resulting graphs in
 * {@code this.graphList}, after laying each graph out.
 *
 * <ul>
 * <li>With no data models, the algorithm is run once with a {@code null}
 * data argument, provided a source graph is available.</li>
 * <li>A {@link MultiDataSetAlgorithm} is run "numRuns" times, each time on a
 * freshly shuffled subset of "randomSelectionSize" data sets.</li>
 * <li>A {@link ClusterAlgorithm} is run "numRuns" times over every
 * covariance matrix or continuous data set.</li>
 * <li>Any other algorithm is run once per data model whose kind is
 * compatible with the algorithm's declared data type.</li>
 * </ul>
 *
 * @throws IllegalArgumentException if there is neither data nor a source
 * graph, if the random selection size exceeds the number of data sets, if
 * a cluster algorithm is given a non-continuous data set, or if a data
 * model's kind does not match the algorithm's data type
 */
@Override
public void execute() {
    List<Graph> graphList = new ArrayList<>();
    if (getDataModelList().isEmpty()) {
        if (getSourceGraph() == null) {
            throw new IllegalArgumentException("The parent boxes did not include any datasets or graphs. Try opening\n" + "the editors for those boxes and loading or simulating them.");
        }
        Algorithm algo = getAlgorithm();
        if (algo instanceof HasKnowledge) {
            ((HasKnowledge) algo).setKnowledge(getKnowledge());
        }
        graphList.add(algo.search(null, parameters));
    } else if (getAlgorithm() instanceof MultiDataSetAlgorithm) {
        for (int k = 0; k < parameters.getInt("numRuns"); k++) {
            List<DataSet> dataSets = getDataModelList().stream().map(e -> (DataSet) e).collect(Collectors.toCollection(ArrayList::new));
            if (dataSets.size() < parameters.getInt("randomSelectionSize")) {
                throw new IllegalArgumentException("Sorry, the 'random selection size' is greater than " + "the number of data sets.");
            }
            // Random subset: shuffle, then take the leading elements.
            Collections.shuffle(dataSets);
            List<DataModel> sub = new ArrayList<>();
            for (int j = 0; j < parameters.getInt("randomSelectionSize"); j++) {
                sub.add(dataSets.get(j));
            }
            Algorithm algo = getAlgorithm();
            if (algo instanceof HasKnowledge) {
                ((HasKnowledge) algo).setKnowledge(getKnowledge());
            }
            graphList.add(((MultiDataSetAlgorithm) algo).search(sub, parameters));
        }
    } else if (getAlgorithm() instanceof ClusterAlgorithm) {
        for (int k = 0; k < parameters.getInt("numRuns"); k++) {
            getDataModelList().forEach(dataModel -> {
                if (dataModel instanceof ICovarianceMatrix) {
                    ICovarianceMatrix dataSet = (ICovarianceMatrix) dataModel;
                    graphList.add(algorithm.search(dataSet, parameters));
                } else if (dataModel instanceof DataSet) {
                    DataSet dataSet = (DataSet) dataModel;
                    if (!dataSet.isContinuous()) {
                        throw new IllegalArgumentException("Sorry, you need a continuous dataset for a cluster algorithm.");
                    }
                    graphList.add(algorithm.search(dataSet, parameters));
                }
            });
        }
    } else {
        getDataModelList().forEach(data -> {
            // Non-empty knowledge attached to the data replaces the runner's own.
            IKnowledge knowledgeFromData = data.getKnowledge();
            if (!(knowledgeFromData == null || knowledgeFromData.getVariables().isEmpty())) {
                this.knowledge = knowledgeFromData;
            }
            Algorithm algo = getAlgorithm();
            if (algo instanceof HasKnowledge) {
                ((HasKnowledge) algo).setKnowledge(getKnowledge());
            }
            // The data kind must be one the algorithm declares it handles;
            // an algorithm declared as Mixed accepts every kind.
            DataType algDataType = algo.getDataType();
            boolean compatible =
                    (data.isContinuous() && (algDataType == DataType.Continuous || algDataType == DataType.Mixed))
                    || (data.isDiscrete() && (algDataType == DataType.Discrete || algDataType == DataType.Mixed))
                    || (data.isMixed() && algDataType == DataType.Mixed);
            if (!compatible) {
                throw new IllegalArgumentException("The type of data changed; try opening up the search editor and " + "running the algorithm there.");
            }
            graphList.add(algo.search(data, parameters));
        });
    }
    // Lay out by knowledge tiers when at least one variable is in a tier,
    // otherwise fall back to a circle layout.
    if (getKnowledge().getVariablesNotInTiers().size() < getKnowledge().getVariables().size()) {
        for (Graph graph : graphList) {
            SearchGraphUtils.arrangeByKnowledgeTiers(graph, getKnowledge());
        }
    } else {
        for (Graph graph : graphList) {
            GraphUtils.circleLayout(graph, 225, 200, 150);
        }
    }
    this.graphList = graphList;
}
Also used : GraphUtils(edu.cmu.tetrad.graph.GraphUtils) ObjectInputStream(java.io.ObjectInputStream) Parameters(edu.cmu.tetrad.util.Parameters) HashMap(java.util.HashMap) Triple(edu.cmu.tetrad.graph.Triple) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) DataType(edu.cmu.tetrad.data.DataType) KnowledgeBoxInput(edu.cmu.tetrad.data.KnowledgeBoxInput) Map(java.util.Map) ClusterAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) Fges(edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges) BdeuScore(edu.cmu.tetrad.algcomparison.score.BdeuScore) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) IKnowledge(edu.cmu.tetrad.data.IKnowledge) Graph(edu.cmu.tetrad.graph.Graph) IOException(java.io.IOException) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) Collectors(java.util.stream.Collectors) DataModel(edu.cmu.tetrad.data.DataModel) DataModelList(edu.cmu.tetrad.data.DataModelList) List(java.util.List) ParamsResettable(edu.cmu.tetrad.session.ParamsResettable) DataSet(edu.cmu.tetrad.data.DataSet) ImpliedOrientation(edu.cmu.tetrad.search.ImpliedOrientation) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) SessionModel(edu.cmu.tetrad.session.SessionModel) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) SearchGraphUtils(edu.cmu.tetrad.search.SearchGraphUtils) Knowledge2(edu.cmu.tetrad.data.Knowledge2) Unmarshallable(edu.cmu.tetrad.util.Unmarshallable) Collections(java.util.Collections) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) ArrayList(java.util.ArrayList) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) ClusterAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) 
MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) IKnowledge(edu.cmu.tetrad.data.IKnowledge) ClusterAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm) Graph(edu.cmu.tetrad.graph.Graph) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) DataModel(edu.cmu.tetrad.data.DataModel) DataType(edu.cmu.tetrad.data.DataType)

Example 5 with DataType

use of edu.cmu.tetrad.data.DataType in project tetrad by cmu-phil.

the class TimeoutComparison method compareFromSimulations.

/**
 * Compares algorithms on simulated data and prints the comparison tables
 * (average, standard deviation and worst case) to an output file.
 *
 * @param resultsPath Path of the directory in which the output file is
 * created; the directory is created if necessary.
 * @param simulations The list of simulationWrapper that is used to generate
 * graphs and data for the comparison.
 * @param outputFileName Name of the output file inside {@code resultsPath}.
 * @param algorithms The list of algorithms to be compared.
 * @param statistics The list of statistics on which to compare the
 * algorithm, and their utility weights.
 * @param parameters The parameters for simulations and algorithms; a
 * separate algorithm wrapper is created for every combination of values of
 * the algorithm parameters that have more than one value.
 * @param timeout Passed through to {@code calcStats} together with
 * {@code unit} (presumably the time limit per run; confirm there).
 * @param unit The time unit of {@code timeout}.
 * @throws RuntimeException if the output file cannot be created
 */
public void compareFromSimulations(String resultsPath, Simulations simulations, String outputFileName, Algorithms algorithms, Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
    this.resultsPath = resultsPath;
    // Create output file.
    try {
        File dir = new File(resultsPath);
        dir.mkdirs();
        File file = new File(dir, outputFileName);
        this.out = new PrintStream(new FileOutputStream(file));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    // Everything below writes to the stream; the finally block guarantees it
    // is closed even when a simulation, algorithm or statistic throws.
    try {
        out.println(new Date());
        // Set up simulations--create data and graphs, read in parameters. The parameters
        // are set in the parameters object.
        List<SimulationWrapper> simulationWrappers = new ArrayList<>();
        int numRuns = parameters.getInt("numRuns");
        for (Simulation simulation : simulations.getSimulations()) {
            List<SimulationWrapper> wrappers = getSimulationWrappers(simulation, parameters);
            for (SimulationWrapper wrapper : wrappers) {
                wrapper.createData(wrapper.getSimulationSpecificParameters());
                simulationWrappers.add(wrapper);
            }
        }
        // Set up the algorithms.
        List<AlgorithmWrapper> algorithmWrappers = new ArrayList<>();
        for (Algorithm algorithm : algorithms.getAlgorithms()) {
            List<Integer> _dims = new ArrayList<>();
            List<String> varyingParameters = new ArrayList<>();
            final List<String> parameters1 = algorithm.getParameters();
            for (String name : parameters1) {
                if (parameters.getNumValues(name) > 1) {
                    _dims.add(parameters.getNumValues(name));
                    varyingParameters.add(name);
                }
            }
            if (varyingParameters.isEmpty()) {
                algorithmWrappers.add(new AlgorithmWrapper(algorithm, parameters));
            } else {
                int[] dims = new int[_dims.size()];
                for (int i = 0; i < _dims.size(); i++) {
                    dims[i] = _dims.get(i);
                }
                // One wrapper per combination of values of the varying parameters.
                CombinationGenerator gen = new CombinationGenerator(dims);
                int[] choice;
                while ((choice = gen.next()) != null) {
                    AlgorithmWrapper wrapper = new AlgorithmWrapper(algorithm, parameters);
                    for (int h = 0; h < dims.length; h++) {
                        String parameter = varyingParameters.get(h);
                        Object[] values = parameters.getValues(parameter);
                        Object value = values[choice[h]];
                        wrapper.setValue(parameter, value);
                    }
                    algorithmWrappers.add(wrapper);
                }
            }
        }
        // Create the algorithm-simulation wrappers for every combination of algorithm and
        // simulation.
        List<AlgorithmSimulationWrapper> algorithmSimulationWrappers = new ArrayList<>();
        for (SimulationWrapper simulationWrapper : simulationWrappers) {
            for (AlgorithmWrapper algorithmWrapper : algorithmWrappers) {
                DataType algDataType = algorithmWrapper.getDataType();
                DataType simDataType = simulationWrapper.getDataType();
                // A mismatch is only reported; the pair is still added below.
                if (!(algDataType == DataType.Mixed || (algDataType == simDataType))) {
                    System.out.println("Type mismatch: " + algorithmWrapper.getDescription() + " / " + simulationWrapper.getDescription());
                }
                if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
                    ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
                    external.setSimIndex(simulationWrappers.indexOf(external.getSimulation()));
                }
                algorithmSimulationWrappers.add(new AlgorithmSimulationWrapper(algorithmWrapper, simulationWrapper));
            }
        }
        // Run all of the algorithms and compile statistics.
        double[][][][] allStats = calcStats(algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, statistics, numRuns, timeout, unit);
        // Print out the preliminary information for statistics types, etc.
        if (allStats != null) {
            out.println();
            out.println("Statistics:");
            out.println();
            for (Statistic stat : statistics.getStatistics()) {
                out.println(stat.getAbbreviation() + " = " + stat.getDescription());
            }
        }
        out.println();
        if (allStats != null) {
            int numTables = allStats.length;
            int numStats = allStats[0][0].length - 1;
            double[][][] statTables = calcStatTables(allStats, Mode.Average, numTables, algorithmSimulationWrappers, numStats, statistics);
            double[] utilities = calcUtilities(statistics, algorithmSimulationWrappers, statTables[0]);
            // Add utilities to table as the last column.
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }
            int[] newOrder;
            if (isSortByUtility()) {
                newOrder = sort(algorithmSimulationWrappers, utilities);
            } else {
                // Identity order.
                newOrder = new int[algorithmSimulationWrappers.size()];
                for (int q = 0; q < algorithmSimulationWrappers.size(); q++) {
                    newOrder[q] = q;
                }
            }
            out.println("Simulations:");
            out.println();
            int i = 0;
            for (SimulationWrapper simulation : simulationWrappers) {
                out.print("Simulation " + (++i) + ": ");
                out.println(simulation.getDescription());
                out.println();
                printParameters(simulation.getParameters(), simulation.getSimulationSpecificParameters(), out);
                out.println();
            }
            out.println("Algorithms:");
            out.println();
            // Each algorithm is listed once, via its pairing with the first simulation.
            for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                AlgorithmSimulationWrapper wrapper = algorithmSimulationWrappers.get(t);
                if (wrapper.getSimulationWrapper() == simulationWrappers.get(0)) {
                    out.println((t + 1) + ". " + wrapper.getAlgorithmWrapper().getDescription());
                }
            }
            if (isSortByUtility()) {
                out.println();
                out.println("Sorting by utility, high to low.");
            }
            if (isShowUtilities()) {
                out.println();
                out.println("Weighting of statistics:");
                out.println();
                out.println("U = ");
                for (Statistic stat : statistics.getStatistics()) {
                    String statName = stat.getAbbreviation();
                    double weight = statistics.getWeight(stat);
                    if (weight != 0.0) {
                        out.println("    " + weight + " * f(" + statName + ")");
                    }
                }
                out.println();
                out.println("...normed to range between 0 and 1.");
                out.println();
                out.println("Note that f for each statistic is a function that maps the statistic to the ");
                out.println("interval [0, 1], with higher being better.");
            }
            out.println();
            out.println("Graphs are being compared to the " + comparisonGraph.toString().replace("_", " ") + ".");
            out.println();
            // Add utilities to table as the last column.
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }
            // Print all of the tables.
            printStats(statTables, statistics, Mode.Average, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
            statTables = calcStatTables(allStats, Mode.StandardDeviation, numTables, algorithmSimulationWrappers, numStats, statistics);
            printStats(statTables, statistics, Mode.StandardDeviation, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
            statTables = calcStatTables(allStats, Mode.WorstCase, numTables, algorithmSimulationWrappers, numStats, statistics);
            // Add utilities to table as the last column.
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }
            printStats(statTables, statistics, Mode.WorstCase, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
        }
    } finally {
        out.close();
    }
}
Also used : ArrayList(java.util.ArrayList) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) DataType(edu.cmu.tetrad.data.DataType) PrintStream(java.io.PrintStream) CombinationGenerator(edu.cmu.tetrad.util.CombinationGenerator) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) TimeoutException(java.util.concurrent.TimeoutException) FileNotFoundException(java.io.FileNotFoundException) IOException(java.io.IOException) ExecutionException(java.util.concurrent.ExecutionException) Date(java.util.Date) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation) FileOutputStream(java.io.FileOutputStream) File(java.io.File)

Aggregations

DataType (edu.cmu.tetrad.data.DataType)9 List (java.util.List)5 Map (java.util.Map)5 Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm)4 EnumMap (java.util.EnumMap)4 LinkedList (java.util.LinkedList)4 Optional (java.util.Optional)4 BdeuScore (edu.cmu.tetrad.algcomparison.score.BdeuScore)3 HasKnowledge (edu.cmu.tetrad.algcomparison.utils.HasKnowledge)3 DataModel (edu.cmu.tetrad.data.DataModel)3 DataModelList (edu.cmu.tetrad.data.DataModelList)3 DataSet (edu.cmu.tetrad.data.DataSet)3 ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix)3 Knowledge2 (edu.cmu.tetrad.data.Knowledge2)3 Graph (edu.cmu.tetrad.graph.Graph)3 Node (edu.cmu.tetrad.graph.Node)3 Parameters (edu.cmu.tetrad.util.Parameters)3 IOException (java.io.IOException)3 ArrayList (java.util.ArrayList)3 AlgorithmFactory (edu.cmu.tetrad.algcomparison.algorithm.AlgorithmFactory)2