Search in sources:

Example 1 with Algorithm

Usage of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in the project tetrad by cmu-phil.

From class GeneralAlgorithmEditor, method restorePreviousState.

/**
 * Re-applies a previously saved editor state to the UI controls.
 *
 * <p>Boolean entries restore the three filter check boxes and a String entry re-selects the
 * matching algorithm-type radio button. The algorithm list and the test/score lists are then
 * rebuilt, and the saved algorithm, independence test and score are matched by their
 * {@code toString()} against the rebuilt entries. On a match, the live entry replaces the saved
 * object in {@code models} and becomes the current selection.
 *
 * @param models saved state keyed by the editor's *_PARAM constants; updated in place
 */
private void restorePreviousState(Map<String, Object> models) {
    // instanceof is false for null, so a separate null check is unnecessary.
    Object savedLinear = models.get(LINEAR_PARAM);
    if (savedLinear instanceof Boolean) {
        linearVarChkBox.setSelected((Boolean) savedLinear);
    }
    Object savedGaussian = models.get(GAUSSIAN_PARAM);
    if (savedGaussian instanceof Boolean) {
        gaussianVarChkBox.setSelected((Boolean) savedGaussian);
    }
    Object savedKnowledge = models.get(KNOWLEDGE_PARAM);
    if (savedKnowledge instanceof Boolean) {
        knowledgeChkBox.setSelected((Boolean) savedKnowledge);
    }

    // Select the algorithm-type radio button whose action command equals the saved string.
    Object savedAlgoType = models.get(ALGO_TYPE_PARAM);
    if (savedAlgoType instanceof String) {
        String command = (String) savedAlgoType;
        algoTypeOpts.stream()
                .filter(button -> button.getActionCommand().equals(command))
                .findFirst()
                .ifPresent(button -> button.setSelected(true));
    }

    // Rebuild the lists (presumably filtered by the options restored above) before
    // trying to restore the selections within them.
    refreshAlgorithmList();
    refreshTestAndScoreList();

    Object savedAlgo = models.get(ALGO_PARAM);
    if (savedAlgo instanceof AlgorithmModel) {
        String label = savedAlgo.toString();
        for (Enumeration<AlgorithmModel> it = algoModels.elements(); it.hasMoreElements(); ) {
            AlgorithmModel candidate = it.nextElement();
            if (!candidate.toString().equals(label)) {
                continue;
            }
            models.put(ALGO_PARAM, candidate);
            algorithmList.setSelectedValue(candidate, true);
            algorithmGraphTitle.setText(String.format("Algorithm: %s", candidate.getAlgorithm().getAnnotation().name()));
            break;
        }
    }

    Object savedTest = models.get(IND_TEST_PARAM);
    if (savedTest instanceof IndependenceTestModel) {
        ComboBoxModel<IndependenceTestModel> testChoices = indTestComboBox.getModel();
        IndependenceTestModel match = findEntryByLabel(testChoices, savedTest.toString());
        if (match != null) {
            models.put(IND_TEST_PARAM, match);
            indTestComboBox.getModel().setSelectedItem(match);
        }
    }

    Object savedScore = models.get(SCORE_PARAM);
    if (savedScore instanceof ScoreModel) {
        ComboBoxModel<ScoreModel> scoreChoices = scoreComboBox.getModel();
        ScoreModel match = findEntryByLabel(scoreChoices, savedScore.toString());
        if (match != null) {
            models.put(SCORE_PARAM, match);
            scoreComboBox.getModel().setSelectedItem(match);
        }
    }
}

/**
 * Returns the first element of {@code choices} whose {@code toString()} equals {@code label},
 * or {@code null} if there is none.
 */
private static <T> T findEntryByLabel(ComboBoxModel<T> choices, String label) {
    int count = choices.getSize();
    for (int i = 0; i < count; i++) {
        T entry = choices.getElementAt(i);
        if (entry.toString().equals(label)) {
            return entry;
        }
    }
    return null;
}
Also used : Enumeration(java.util.Enumeration) JDialog(javax.swing.JDialog) TetradDesktop(edu.cmu.tetradapp.app.TetradDesktop) LoggerFactory(org.slf4j.LoggerFactory) Parameters(edu.cmu.tetrad.util.Parameters) Node(edu.cmu.tetrad.graph.Node) PaddingPanel(edu.cmu.tetradapp.ui.PaddingPanel) StringUtils(org.apache.commons.lang3.StringUtils) Linear(edu.cmu.tetrad.annotation.Linear) AlgorithmModel(edu.cmu.tetradapp.ui.model.AlgorithmModel) Map(java.util.Map) ScoreModel(edu.cmu.tetradapp.ui.model.ScoreModel) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) HpcJobActivityAction(edu.cmu.tetradapp.app.hpc.action.HpcJobActivityAction) BorderLayout(java.awt.BorderLayout) ButtonModel(javax.swing.ButtonModel) JComboBox(javax.swing.JComboBox) GraphSelectionWrapper(edu.cmu.tetradapp.model.GraphSelectionWrapper) Method(java.lang.reflect.Method) Path(java.nio.file.Path) GeneralAlgorithmRunner(edu.cmu.tetradapp.model.GeneralAlgorithmRunner) BdeuScore(edu.cmu.tetrad.algcomparison.score.BdeuScore) HpcJobManager(edu.cmu.tetradapp.app.hpc.manager.HpcJobManager) Frame(java.awt.Frame) HpcAccountManager(edu.cmu.tetradapp.app.hpc.manager.HpcAccountManager) WatchedProcess(edu.cmu.tetradapp.util.WatchedProcess) EnumMap(java.util.EnumMap) AlgType(edu.cmu.tetrad.annotation.AlgType) HpcParameter(edu.pitt.dbmi.tetrad.db.entity.HpcParameter) Set(java.util.Set) BorderFactory(javax.swing.BorderFactory) ComboBoxModel(javax.swing.ComboBoxModel) Gaussian(edu.cmu.tetrad.annotation.Gaussian) Nonexecutable(edu.cmu.tetrad.annotation.Nonexecutable) JsonWebToken(edu.pitt.dbmi.ccd.rest.client.dto.user.JsonWebToken) HpcJobInfo(edu.pitt.dbmi.tetrad.db.entity.HpcJobInfo) JRadioButton(javax.swing.JRadioButton) DataModel(edu.cmu.tetrad.data.DataModel) InvocationTargetException(java.lang.reflect.InvocationTargetException) Box(javax.swing.Box) Dimension(java.awt.Dimension) List(java.util.List) DataSet(edu.cmu.tetrad.data.DataSet) JCheckBox(javax.swing.JCheckBox) Optional(java.util.Optional) 
MessageDigestHash(edu.pitt.dbmi.ccd.commons.file.MessageDigestHash) JPanel(javax.swing.JPanel) Toolkit(java.awt.Toolkit) CardLayout(java.awt.CardLayout) JOptionUtils(edu.cmu.tetrad.util.JOptionUtils) HashMap(java.util.HashMap) AlgorithmModels(edu.cmu.tetradapp.ui.model.AlgorithmModels) FinalizingEditor(edu.cmu.tetradapp.util.FinalizingEditor) AlgorithmParamRequest(edu.pitt.dbmi.tetrad.db.entity.AlgorithmParamRequest) SwingConstants(javax.swing.SwingConstants) ArrayList(java.util.ArrayList) SingleGraphAlg(edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.SingleGraphAlg) HashSet(java.util.HashSet) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore) DataType(edu.cmu.tetrad.data.DataType) JsonUtils(edu.cmu.tetrad.util.JsonUtils) IndependenceTestModel(edu.cmu.tetradapp.ui.model.IndependenceTestModel) IndependenceTestModels(edu.cmu.tetradapp.ui.model.IndependenceTestModels) TsImages(edu.cmu.tetrad.algcomparison.algorithm.oracle.pag.TsImages) LinkedList(java.util.LinkedList) AlgorithmFactory(edu.cmu.tetrad.algcomparison.algorithm.AlgorithmFactory) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) JButton(javax.swing.JButton) Logger(org.slf4j.Logger) AlgorithmParameter(edu.pitt.dbmi.tetrad.db.entity.AlgorithmParameter) Files(java.nio.file.Files) ButtonGroup(javax.swing.ButtonGroup) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Window(java.awt.Window) JList(javax.swing.JList) DesktopController(edu.cmu.tetradapp.util.DesktopController) Graph(edu.cmu.tetrad.graph.Graph) HpcAccountUtils(edu.cmu.tetradapp.app.hpc.util.HpcAccountUtils) JvmOptions(edu.pitt.dbmi.tetrad.db.entity.JvmOptions) IOException(java.io.IOException) JOptionPane(javax.swing.JOptionPane) ActionEvent(java.awt.event.ActionEvent) JScrollPane(javax.swing.JScrollPane) LayoutStyle(javax.swing.LayoutStyle) DataModelList(edu.cmu.tetrad.data.DataModelList) ScoreModels(edu.cmu.tetradapp.ui.model.ScoreModels) 
DefaultListModel(javax.swing.DefaultListModel) JLabel(javax.swing.JLabel) GroupLayout(javax.swing.GroupLayout) HpcAccount(edu.pitt.dbmi.tetrad.db.entity.HpcAccount) JTextArea(javax.swing.JTextArea) Knowledge2(edu.cmu.tetrad.data.Knowledge2) JRadioButton(javax.swing.JRadioButton) IndependenceTestModel(edu.cmu.tetradapp.ui.model.IndependenceTestModel) ScoreModel(edu.cmu.tetradapp.ui.model.ScoreModel) AlgorithmModel(edu.cmu.tetradapp.ui.model.AlgorithmModel)

Example 2 with Algorithm

Usage of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in the project tetrad by cmu-phil.

From class GeneralAlgorithmEditor, method validateAlgorithmOption.

/**
 * Checks whether the algorithm currently selected in {@code algorithmList} can be configured and
 * enables or disables {@code paramSetFwdBtn} accordingly.
 *
 * <p>For an algorithm annotated {@link Nonexecutable}, the forward button is disabled and the
 * algorithm's {@code getDescription()} text is shown to the user. For a {@link TakesInitialGraph}
 * algorithm, when the runner has no source graph or no data, {@code setInitialGraph(null)} is
 * invoked reflectively. If that call rejects the null argument, the button is disabled and the
 * rejection message is shown.
 */
private void validateAlgorithmOption() {
    paramSetFwdBtn.setEnabled(true);

    AlgorithmModel algoModel = algorithmList.getSelectedValue();
    Class<?> algoClass = algoModel.getAlgorithm().getClazz();

    if (algoClass.isAnnotationPresent(Nonexecutable.class)) {
        String msg;
        try {
            Object algo = algoClass.newInstance();
            Method m = algoClass.getDeclaredMethod("getDescription");
            m.setAccessible(true);
            try {
                msg = String.valueOf(m.invoke(algo));
            } catch (InvocationTargetException exception) {
                // getDescription() itself threw: still show the note, but keep the failure visible.
                LOGGER.error("", exception);
                msg = "";
            }
        } catch (IllegalAccessException | InstantiationException | NoSuchMethodException exception) {
            LOGGER.error("", exception);
            msg = "";
        }
        paramSetFwdBtn.setEnabled(false);
        JOptionPane.showMessageDialog(desktop, msg, "Please Note", JOptionPane.INFORMATION_MESSAGE);
    } else if (TakesInitialGraph.class.isAssignableFrom(algoClass)) {
        // Check if an initial graph is provided for the pairwise algorithms that need one.
        if (runner.getSourceGraph() == null || runner.getDataModelList().isEmpty()) {
            try {
                Object algo = algoClass.newInstance();
                Method m = algoClass.getDeclaredMethod("setInitialGraph", Algorithm.class);
                m.setAccessible(true);
                try {
                    Algorithm noInitialGraph = null;
                    m.invoke(algo, noInitialGraph);
                } catch (InvocationTargetException | IllegalArgumentException exception) {
                    // InvocationTargetException wraps the exception thrown by setInitialGraph, but an
                    // IllegalArgumentException raised by invoke() itself usually has no cause;
                    // calling getCause().getMessage() unconditionally would throw an NPE here.
                    Throwable reason = exception.getCause() != null ? exception.getCause() : exception;
                    paramSetFwdBtn.setEnabled(false);
                    JOptionPane.showMessageDialog(desktop, reason.getMessage(), "Please Note", JOptionPane.INFORMATION_MESSAGE);
                }
            } catch (IllegalAccessException | InstantiationException | NoSuchMethodException exception) {
                LOGGER.error("", exception);
            }
        }
    }
    // TODO: check the dataset's data type for algorithms that take mixed data?
}
Also used : Method(java.lang.reflect.Method) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) InvocationTargetException(java.lang.reflect.InvocationTargetException) AlgorithmModel(edu.cmu.tetradapp.ui.model.AlgorithmModel)

Example 3 with Algorithm

Usage of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in the project tetrad by cmu-phil.

From class ExampleFirstInflection, method main.

/**
 * Compares FGES (SEM BIC score) wrapped in {@code FirstInflection} over "penaltyDiscount"
 * (constructor arguments 0.7, 5, 1 — presumably low, high and increment; confirm against
 * FirstInflection). The data come from linear Fisher-model simulations over random forward
 * graphs. "first.inflection" is passed to {@code compareFromSimulations} as the results
 * location.
 */
public static void main(String... args) {
    // Each key is set exactly once, so that changing a value cannot be silently undone by a
    // later duplicate assignment.
    Parameters parameters = new Parameters();
    // Multi-valued settings (numMeasures, sampleSize) presumably expand into one run per value.
    parameters.set("numMeasures", 40, 100);
    parameters.set("avgDegree", 2);
    parameters.set("sampleSize", 400, 800);
    parameters.set("numRuns", 10);
    parameters.set("differentGraphs", true);
    parameters.set("numLatents", 0);
    parameters.set("maxDegree", 100);
    parameters.set("maxIndegree", 100);
    parameters.set("maxOutdegree", 100);
    parameters.set("connected", false);
    parameters.set("coefLow", 0.2);
    parameters.set("coefHigh", 0.9);
    parameters.set("varLow", 1);
    parameters.set("varHigh", 3);
    parameters.set("verbose", false);
    parameters.set("coefSymmetric", true);
    parameters.set("percentDiscrete", 0);
    parameters.set("numCategories", 3);
    parameters.set("intervalBetweenShocks", 10);
    parameters.set("intervalBetweenRecordings", 10);
    parameters.set("fisherEpsilon", 0.001);
    parameters.set("randomizeColumns", true);
    parameters.set("alpha", 1e-8);
    parameters.set("depth", -1);
    parameters.set("penaltyDiscount", 4);
    parameters.set("useMaxPOrientationHeuristic", false);
    parameters.set("maxPOrientationMaxPathLength", 3);
    parameters.set("scaleFreeAlpha", 0.00001);
    parameters.set("scaleFreeBeta", 0.4);
    parameters.set("scaleFreeDeltaIn", .1);
    parameters.set("scaleFreeDeltaOut", 3);
    parameters.set("symmetricFirstStep", false);
    parameters.set("faithfulnessAssumed", true);
    // parameters.set("logScale", true);

    // Report adjacency/arrowhead precision and recall with equal utility weights.
    Statistics statistics = new Statistics();
    statistics.add(new ParameterColumn("numMeasures"));
    statistics.add(new ParameterColumn("avgDegree"));
    statistics.add(new ParameterColumn("sampleSize"));
    statistics.add(new AdjacencyPrecision());
    statistics.add(new AdjacencyRecall());
    statistics.add(new ArrowheadPrecision());
    statistics.add(new ArrowheadRecall());
    statistics.add(new ElapsedTime());
    statistics.setWeight("AP", 0.25);
    statistics.setWeight("AR", 0.25);
    statistics.setWeight("AHP", 0.25);
    statistics.setWeight("AHR", 0.25);

    Algorithms algorithms = new Algorithms();
    Algorithm fges = new Fges(new SemBicScore());
    // Alternative: algorithms.add(new FirstInflection(fges, "alpha", -7, -2, -.5));
    algorithms.add(new FirstInflection(fges, "penaltyDiscount", 0.7, 5, 1));

    Simulations simulations = new Simulations();
    simulations.add(new LinearFisherModel(new RandomForward()));

    Comparison comparison = new Comparison();
    comparison.setShowAlgorithmIndices(true);
    comparison.setShowSimulationIndices(true);
    comparison.setSortByUtility(false);
    comparison.setShowUtilities(false);
    comparison.setParallelized(true);
    comparison.setComparisonGraph(Comparison.ComparisonGraph.Pattern_of_the_true_DAG);
    comparison.compareFromSimulations("first.inflection", simulations, algorithms, statistics, parameters);
}
Also used : Simulations(edu.cmu.tetrad.algcomparison.simulation.Simulations) Parameters(edu.cmu.tetrad.util.Parameters) LinearFisherModel(edu.cmu.tetrad.algcomparison.simulation.LinearFisherModel) RandomForward(edu.cmu.tetrad.algcomparison.graph.RandomForward) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) Fges(edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges) Algorithms(edu.cmu.tetrad.algcomparison.algorithm.Algorithms) Comparison(edu.cmu.tetrad.algcomparison.Comparison) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore) FirstInflection(edu.cmu.tetrad.algcomparison.algorithm.FirstInflection)

Example 4 with Algorithm

Usage of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in the project tetrad by cmu-phil.

From class TimeoutComparison, method printStats.

/**
 * Prints one set of statistic tables to {@code out}, one table per graph type marked as used in
 * {@code graphTypeUsed}, preceded by a heading for {@code mode}.
 *
 * @param statTables                  values indexed as [table][wrapper index][statistic index]
 * @param statistics                  statistics to print; their abbreviations are the column headers
 * @param mode                        selects the heading (Average, StandardDeviation or WorstCase)
 * @param newOrder                    output row {@code t} shows wrapper {@code newOrder[t]}
 * @param algorithmSimulationWrappers one entry per algorithm/simulation combination
 * @param algorithmWrappers           used to number the optional "Alg" column
 * @param simulationWrappers          used to number the optional "Sim" column
 * @param utilities                   per-wrapper utility, printed in the optional "U" column
 * @param parameters                  if a statistic's abbreviation maps to a single String value here,
 *                                    that string is printed instead of the number. Algorithm- and
 *                                    simulation-specific parameter values are merged into this object
 *                                    as a side effect.
 * @throws IllegalStateException if {@code mode} is not one of the three handled modes
 */
private void printStats(double[][][] statTables, Statistics statistics, Mode mode, int[] newOrder, List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, double[] utilities, Parameters parameters) {
    if (mode == Mode.Average) {
        out.println("AVERAGE STATISTICS");
    } else if (mode == Mode.StandardDeviation) {
        out.println("STANDARD DEVIATIONS");
    } else if (mode == Mode.WorstCase) {
        out.println("WORST CASE");
    } else {
        throw new IllegalStateException();
    }

    int numTables = statTables.length;
    int numStats = statistics.size();
    int numRows = algorithmSimulationWrappers.size();
    NumberFormat nf = new DecimalFormat("0.00");
    NumberFormat smallNf = new DecimalFormat("0.00E0");
    // Loop-invariant: non-zero magnitudes below this are printed in scientific notation.
    double smallThreshold = Math.pow(10, -smallNf.getMaximumFractionDigits());
    out.println();

    for (int u = 0; u < numTables; u++) {
        if (!graphTypeUsed[u]) {
            continue;
        }

        int rows = numRows + 1;
        int cols = (isShowSimulationIndices() ? 1 : 0) + (isShowAlgorithmIndices() ? 1 : 0) + numStats + (isShowUtilities() ? 1 : 0);
        TextTable table = new TextTable(rows, cols);
        table.setTabDelimited(isTabDelimitedTables());
        int initialColumn = 0;

        if (isShowSimulationIndices()) {
            table.setToken(0, initialColumn, "Sim");
            for (int t = 0; t < numRows; t++) {
                Simulation simulation = algorithmSimulationWrappers.get(newOrder[t]).getSimulationWrapper();
                table.setToken(t + 1, initialColumn, "" + (simulationWrappers.indexOf(simulation) + 1));
            }
            initialColumn++;
        }

        if (isShowAlgorithmIndices()) {
            table.setToken(0, initialColumn, "Alg");
            for (int t = 0; t < numRows; t++) {
                AlgorithmWrapper algorithm = algorithmSimulationWrappers.get(newOrder[t]).getAlgorithmWrapper();
                table.setToken(t + 1, initialColumn, "" + (algorithmWrappers.indexOf(algorithm) + 1));
            }
            initialColumn++;
        }

        for (int statIndex = 0; statIndex < numStats; statIndex++) {
            String statLabel = statistics.getStatistics().get(statIndex).getAbbreviation();
            table.setToken(0, initialColumn + statIndex, statLabel);
        }

        if (isShowUtilities()) {
            table.setToken(0, initialColumn + numStats, "U");
        }

        for (int t = 0; t < numRows; t++) {
            AlgorithmSimulationWrapper row = algorithmSimulationWrappers.get(newOrder[t]);
            Algorithm algorithm = row.getAlgorithmWrapper().getAlgorithm();
            Simulation simulation = row.getSimulationWrapper().getSimulation();

            // The merged values are identical for every column of this row, so merge once per
            // row rather than once per statistic.
            if (algorithm instanceof HasParameterValues) {
                parameters.putAll(((HasParameterValues) algorithm).getParameterValues());
            }
            if (simulation instanceof HasParameterValues) {
                parameters.putAll(((HasParameterValues) simulation).getParameterValues());
            }

            for (int statIndex = 0; statIndex < numStats; statIndex++) {
                String abbreviation = statistics.getStatistics().get(statIndex).getAbbreviation();
                Object[] o = parameters.getValues(abbreviation);
                String token;
                if (o.length == 1 && o[0] instanceof String) {
                    token = (String) o[0];
                } else {
                    token = formatStatCell(statTables[u][newOrder[t]][statIndex], nf, smallNf, smallThreshold);
                }
                table.setToken(t + 1, initialColumn + statIndex, token);
            }

            if (isShowUtilities()) {
                table.setToken(t + 1, initialColumn + numStats, nf.format(utilities[newOrder[t]]));
            }
        }

        out.println(getHeader(u));
        out.println();
        out.println(table);
    }
}

/**
 * Renders one numeric statistic as a table cell.
 *
 * <p>The special values map as follows: 0 to "-", +infinity to "Yes", -infinity to "No" and
 * NaN to "*". A non-zero magnitude below {@code smallThreshold} uses {@code smallNf}; anything
 * else uses {@code nf}.
 */
private static String formatStatCell(double stat, NumberFormat nf, NumberFormat smallNf, double smallThreshold) {
    if (stat == 0.0) {
        return "-";
    }
    if (stat == Double.POSITIVE_INFINITY) {
        return "Yes";
    }
    if (stat == Double.NEGATIVE_INFINITY) {
        return "No";
    }
    if (Double.isNaN(stat)) {
        return "*";
    }
    // stat is non-zero here, so no extra zero check is needed.
    return Math.abs(stat) < smallThreshold ? smallNf.format(stat) : nf.format(stat);
}
Also used : HasParameterValues(edu.cmu.tetrad.algcomparison.utils.HasParameterValues) DecimalFormat(java.text.DecimalFormat) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) TextTable(edu.cmu.tetrad.util.TextTable) NumberFormat(java.text.NumberFormat)

Example 5 with Algorithm

Usage of edu.cmu.tetrad.algcomparison.algorithm.Algorithm in the project tetrad by cmu-phil.

From class Comparison, method compareFromSimulations.

/**
 * Compares algorithms on data generated by the given simulations and writes a report.
 *
 * <p>The report goes to the file {@code outputFileName} inside the directory {@code resultsPath};
 * the directory is created if necessary. It lists the date, the statistics legend, simulation and
 * algorithm descriptions, and tables of average, standard-deviation and worst-case statistics.
 * The output stream is always closed before this method returns, including when an exception
 * is thrown.
 *
 * @param resultsPath    Directory in which the output file is created.
 * @param simulations    The simulations used to generate graphs and data for the comparison.
 * @param outputFileName Name of the output file inside {@code resultsPath}.
 * @param algorithms     The list of algorithms to be compared.
 * @param statistics     The list of statistics on which to compare the algorithms, and their utility weights.
 * @param parameters     Parameter values; an algorithm parameter with several values is expanded into one
 *                       algorithm configuration per combination. Also supplies "numRuns".
 * @throws RuntimeException if the output file cannot be created.
 */
public void compareFromSimulations(String resultsPath, Simulations simulations, String outputFileName, Algorithms algorithms, Statistics statistics, Parameters parameters) {
    this.resultsPath = resultsPath;

    // Create output file.
    try {
        File dir = new File(resultsPath);
        dir.mkdirs();
        File file = new File(dir, outputFileName);
        this.out = new PrintStream(new FileOutputStream(file));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    // Everything after this point may throw (simulation, search, statistics); close the report
    // stream regardless so the file handle is not leaked.
    try {
        out.println(new Date());

        // Set up simulations--create data and graphs, read in parameters. The parameters
        // are set in the parameters object.
        List<SimulationWrapper> simulationWrappers = new ArrayList<>();
        int numRuns = parameters.getInt("numRuns");
        for (Simulation simulation : simulations.getSimulations()) {
            List<SimulationWrapper> wrappers = getSimulationWrappers(simulation, parameters);
            for (SimulationWrapper wrapper : wrappers) {
                wrapper.createData(wrapper.getSimulationSpecificParameters());
                simulationWrappers.add(wrapper);
            }
        }

        // Set up the algorithms: one wrapper per combination of multi-valued algorithm parameters.
        List<AlgorithmWrapper> algorithmWrappers = new ArrayList<>();
        for (Algorithm algorithm : algorithms.getAlgorithms()) {
            List<Integer> _dims = new ArrayList<>();
            List<String> varyingParameters = new ArrayList<>();
            final List<String> parameters1 = algorithm.getParameters();
            for (String name : parameters1) {
                if (parameters.getNumValues(name) > 1) {
                    _dims.add(parameters.getNumValues(name));
                    varyingParameters.add(name);
                }
            }
            if (varyingParameters.isEmpty()) {
                algorithmWrappers.add(new AlgorithmWrapper(algorithm, parameters));
            } else {
                int[] dims = new int[_dims.size()];
                for (int i = 0; i < _dims.size(); i++) {
                    dims[i] = _dims.get(i);
                }
                CombinationGenerator gen = new CombinationGenerator(dims);
                int[] choice;
                while ((choice = gen.next()) != null) {
                    AlgorithmWrapper wrapper = new AlgorithmWrapper(algorithm, parameters);
                    for (int h = 0; h < dims.length; h++) {
                        String parameter = varyingParameters.get(h);
                        Object[] values = parameters.getValues(parameter);
                        Object value = values[choice[h]];
                        wrapper.setValue(parameter, value);
                    }
                    algorithmWrappers.add(wrapper);
                }
            }
        }

        // Create the algorithm-simulation wrappers for every combination of algorithm and
        // simulation.
        List<AlgorithmSimulationWrapper> algorithmSimulationWrappers = new ArrayList<>();
        for (SimulationWrapper simulationWrapper : simulationWrappers) {
            for (AlgorithmWrapper algorithmWrapper : algorithmWrappers) {
                DataType algDataType = algorithmWrapper.getDataType();
                DataType simDataType = simulationWrapper.getDataType();
                if (!(algDataType == DataType.Mixed || (algDataType == simDataType))) {
                    System.out.println("Type mismatch: " + algorithmWrapper.getDescription() + " / " + simulationWrapper.getDescription());
                }
                if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
                    ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
                    // NOTE(review): this looks up external.getSimulation() in a list of SimulationWrapper
                    // objects. Confirm the lookup can succeed rather than always yielding -1; the
                    // current simulationWrapper may be what was intended.
                    external.setSimIndex(simulationWrappers.indexOf(external.getSimulation()));
                }
                algorithmSimulationWrappers.add(new AlgorithmSimulationWrapper(algorithmWrapper, simulationWrapper));
            }
        }

        // Run all of the algorithms and compile statistics.
        double[][][][] allStats = calcStats(algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, statistics, numRuns);

        // Print out the preliminary information for statistics types, etc.
        if (allStats != null) {
            out.println();
            out.println("Statistics:");
            out.println();
            for (Statistic stat : statistics.getStatistics()) {
                out.println(stat.getAbbreviation() + " = " + stat.getDescription());
            }
        }
        out.println();

        if (allStats != null) {
            int numTables = allStats.length;
            int numRows = algorithmSimulationWrappers.size();
            // The final slot of each row is filled with the utility below.
            int numStats = allStats[0][0].length - 1;
            double[][][] statTables = calcStatTables(allStats, Mode.Average, numTables, algorithmSimulationWrappers, numStats, statistics);
            double[] utilities = calcUtilities(statistics, algorithmSimulationWrappers, statTables[0]);
            addUtilityColumn(statTables, numTables, numRows, numStats, utilities);

            int[] newOrder;
            if (isSortByUtility()) {
                newOrder = sort(algorithmSimulationWrappers, utilities);
            } else {
                newOrder = new int[numRows];
                for (int q = 0; q < numRows; q++) {
                    newOrder[q] = q;
                }
            }

            out.println("Simulations:");
            out.println();
            int i = 0;
            for (SimulationWrapper simulation : simulationWrappers) {
                out.print("Simulation " + (++i) + ": ");
                out.println(simulation.getDescription());
                out.println();
                printParameters(simulation.getParameters(), simulation.getSimulationSpecificParameters(), out);
                out.println();
            }

            out.println("Algorithms:");
            out.println();
            for (int t = 0; t < numRows; t++) {
                AlgorithmSimulationWrapper wrapper = algorithmSimulationWrappers.get(t);
                if (wrapper.getSimulationWrapper() == simulationWrappers.get(0)) {
                    out.println((t + 1) + ". " + wrapper.getAlgorithmWrapper().getDescription());
                }
            }

            if (isSortByUtility()) {
                out.println();
                out.println("Sorting by utility, high to low.");
            }

            if (isShowUtilities()) {
                out.println();
                out.println("Weighting of statistics:");
                out.println();
                out.println("U = ");
                for (Statistic stat : statistics.getStatistics()) {
                    String statName = stat.getAbbreviation();
                    double weight = statistics.getWeight(stat);
                    if (weight != 0.0) {
                        out.println("    " + weight + " * f(" + statName + ")");
                    }
                }
                out.println();
                out.println("...normed to range between 0 and 1.");
                out.println();
                out.println("Note that f for each statistic is a function that maps the statistic to the ");
                out.println("interval [0, 1], with higher being better.");
            }

            out.println();
            out.println("Graphs are being compared to the " + comparisonGraph.toString().replace("_", " ") + ".");
            out.println();

            // Re-applied before printing. This is kept in case sort(...) modifies utilities;
            // TODO confirm whether it is redundant.
            addUtilityColumn(statTables, numTables, numRows, numStats, utilities);

            // Print all of the tables.
            printStats(statTables, statistics, Mode.Average, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);

            // NOTE(review): unlike the average and worst-case tables, the standard-deviation
            // table does not get the utility column filled in; confirm this is intentional.
            statTables = calcStatTables(allStats, Mode.StandardDeviation, numTables, algorithmSimulationWrappers, numStats, statistics);
            printStats(statTables, statistics, Mode.StandardDeviation, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);

            statTables = calcStatTables(allStats, Mode.WorstCase, numTables, algorithmSimulationWrappers, numStats, statistics);
            addUtilityColumn(statTables, numTables, numRows, numStats, utilities);
            printStats(statTables, statistics, Mode.WorstCase, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
        }
    } finally {
        out.close();
    }
}

/**
 * Writes {@code utilities[t]} into column {@code numStats} of rows {@code 0..numRows-1} of
 * each of the first {@code numTables} tables.
 */
private static void addUtilityColumn(double[][][] statTables, int numTables, int numRows, int numStats, double[] utilities) {
    for (int u = 0; u < numTables; u++) {
        for (int t = 0; t < numRows; t++) {
            statTables[u][t][numStats] = utilities[t];
        }
    }
}
Also used: ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation)

Aggregations

- Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm): 25
- MultiDataSetAlgorithm (edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm): 13
- ExternalAlgorithm (edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm): 12
- Parameters (edu.cmu.tetrad.util.Parameters): 10
- Simulation (edu.cmu.tetrad.algcomparison.simulation.Simulation): 9
- ScoreWrapper (edu.cmu.tetrad.algcomparison.score.ScoreWrapper): 8
- Statistic (edu.cmu.tetrad.algcomparison.statistic.Statistic): 8
- Graph (edu.cmu.tetrad.graph.Graph): 8
- IndependenceWrapper (edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper): 7
- DataSet (edu.cmu.tetrad.data.DataSet): 7
- Test (org.junit.Test): 6
- Simulations (edu.cmu.tetrad.algcomparison.simulation.Simulations): 5
- DagToPag (edu.cmu.tetrad.search.DagToPag): 5
- GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest): 5
- Fges (edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges): 4
- RandomForward (edu.cmu.tetrad.algcomparison.graph.RandomForward): 4
- BdeuScore (edu.cmu.tetrad.algcomparison.score.BdeuScore): 4
- SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore): 4
- LinearFisherModel (edu.cmu.tetrad.algcomparison.simulation.LinearFisherModel): 4
- LoadDataAndGraphs (edu.cmu.tetrad.algcomparison.simulation.LoadDataAndGraphs): 4