Search in sources:

Example 1 with Simulation

use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil.

Method `main` of class `ExampleSave`.

/**
 * Writes simulated data sets to the "comparison" directory.
 *
 * <p>Settings: 10 runs, 100 measured variables, average degree 4, and sample sizes
 * 100, 500 and 1000. Graphs come from {@code RandomForward} and data from {@code SemSimulation}.
 *
 * @param args ignored
 */
public static void main(String... args) {
    final Parameters params = new Parameters();
    params.set("numRuns", 10);
    params.set("numMeasures", 100);
    params.set("avgDegree", 4);
    // Several values for one key: presumably one simulation condition per value.
    params.set("sampleSize", 100, 500, 1000);

    final Simulation semSim = new SemSimulation(new RandomForward());

    final Comparison saver = new Comparison();
    saver.setShowAlgorithmIndices(true);
    saver.saveToFiles("comparison", semSim, params);
}
Also used: Parameters(edu.cmu.tetrad.util.Parameters), SemSimulation(edu.cmu.tetrad.algcomparison.simulation.SemSimulation), Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation), Comparison(edu.cmu.tetrad.algcomparison.Comparison), RandomForward(edu.cmu.tetrad.algcomparison.graph.RandomForward)

Example 2 with Simulation

use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil.

Method `main` of class `Save`.

/**
 * Writes simulated mixed (continuous/discrete) data sets to the "comparison" directory.
 *
 * <p>Graphs come from {@code RandomForward} and data from {@code LeeHastieSimulation}.
 *
 * @param args ignored
 */
public static void main(String... args) {
    final Parameters params = new Parameters();

    // Size of the experiment.
    params.set("numRuns", 10);
    params.set("numMeasures", 50, 100);
    params.set("avgDegree", 4);
    params.set("sampleSize", 100, 500);

    // Mixed-data settings (names suggest: 3 categories per discrete variable, 50% discrete).
    params.set("numCategories", 3);
    params.set("percentDiscrete", 50);
    params.set("differentGraphs", true);

    final Simulation mixedSim = new LeeHastieSimulation(new RandomForward());

    final Comparison saver = new Comparison();
    saver.setShowAlgorithmIndices(true);
    saver.saveToFiles("comparison", mixedSim, params);
}
Also used: Parameters(edu.cmu.tetrad.util.Parameters), SemSimulation(edu.cmu.tetrad.algcomparison.simulation.SemSimulation), Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation), LeeHastieSimulation(edu.cmu.tetrad.algcomparison.simulation.LeeHastieSimulation), Comparison(edu.cmu.tetrad.algcomparison.Comparison), RandomForward(edu.cmu.tetrad.algcomparison.graph.RandomForward)

Example 3 with Simulation

use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil.

Method `main` of class `ExampleSave`.

/**
 * Writes simulated data sets for a grid of conditions to an output directory.
 *
 * <p>The grid varies number of measures (50, 100, 500), average degree (2, 4, 6) and
 * sample size (100, 500, 1000). Graphs come from {@code RandomForward} and data from
 * {@code SemSimulation}.
 *
 * <p>Side effect: changes the number format held by {@code NumberFormatUtil.getInstance()}
 * to six decimal places. NOTE(review): presumably this is process-wide; confirm.
 *
 * @param args optional; if present, {@code args[0]} is the output directory. Otherwise the
 *             historical default {@code /Users/user/comparison-data/condition_2} is used.
 */
public static void main(String... args) {
    // The default path is specific to one machine. Allow an override so the example can run elsewhere.
    String outputDir = (args != null && args.length > 0)
            ? args[0]
            : "/Users/user/comparison-data/condition_2";

    Parameters parameters = new Parameters();
    parameters.set("numRuns", 10);
    parameters.set("numMeasures", 50, 100, 500);
    parameters.set("avgDegree", 2, 4, 6);
    parameters.set("sampleSize", 100, 500, 1000);
    parameters.set("differentGraphs", true);
    parameters.set("maxDegree", 100);
    parameters.set("maxIndegree", 100);
    parameters.set("maxOutdegree", 100);
    parameters.set("connected", false);
    parameters.set("coefLow", 0.2);
    parameters.set("coefHigh", 0.9);
    parameters.set("coefSymmetric", true);
    parameters.set("varLow", 1);
    parameters.set("varHigh", 3);
    parameters.set("randomizeColumns", true);
    NumberFormatUtil.getInstance().setNumberFormat(new DecimalFormat("0.000000"));
    Simulation simulation = new SemSimulation(new RandomForward());
    Comparison comparison = new Comparison();
    comparison.saveToFiles(outputDir, simulation, parameters);
}
Also used: Parameters(edu.cmu.tetrad.util.Parameters), SemSimulation(edu.cmu.tetrad.algcomparison.simulation.SemSimulation), Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation), Comparison(edu.cmu.tetrad.algcomparison.Comparison), DecimalFormat(java.text.DecimalFormat), RandomForward(edu.cmu.tetrad.algcomparison.graph.RandomForward)

Example 4 with Simulation

use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil.

Method `printStats` of class `TimeoutComparison`.

/**
 * Prints one statistics table per used graph type to {@code out}.
 *
 * <p>Each row is one algorithm/simulation pairing, taken in the order given by {@code newOrder}.
 * Optional leading columns hold 1-based simulation and algorithm indices. Then there is one
 * column per statistic, plus an optional utility column ("U").
 *
 * <p>Cell rendering:
 * <ul>
 *   <li>0 prints as "-".</li>
 *   <li>+Infinity prints as "Yes" and -Infinity as "No".</li>
 *   <li>NaN prints as "*".</li>
 *   <li>Tiny non-zero magnitudes use scientific notation.</li>
 *   <li>If the statistic's abbreviation names a single string-valued parameter, that string is printed instead.</li>
 * </ul>
 *
 * <p>Side effect: parameter values of algorithms/simulations implementing
 * {@code HasParameterValues} are merged into {@code parameters}, which is the caller's instance.
 *
 * @throws IllegalStateException if {@code mode} is not Average, StandardDeviation or WorstCase
 */
private void printStats(double[][][] statTables, Statistics statistics, Mode mode, int[] newOrder, List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, double[] utilities, Parameters parameters) {
    final String title;
    if (mode == Mode.Average) {
        title = "AVERAGE STATISTICS";
    } else if (mode == Mode.StandardDeviation) {
        title = "STANDARD DEVIATIONS";
    } else if (mode == Mode.WorstCase) {
        title = "WORST CASE";
    } else {
        throw new IllegalStateException();
    }
    out.println(title);

    final int tableCount = statTables.length;
    final int statCount = statistics.size();
    final NumberFormat nf = new DecimalFormat("0.00");
    final NumberFormat smallNf = new DecimalFormat("0.00E0");
    // Non-zero magnitudes below this cutoff are rendered in scientific notation.
    final double smallCutoff = Math.pow(10, -smallNf.getMaximumFractionDigits());
    out.println();

    for (int u = 0; u < tableCount; u++) {
        if (!graphTypeUsed[u]) {
            continue;
        }

        final int bodyRows = algorithmSimulationWrappers.size();
        final boolean showSims = isShowSimulationIndices();
        final boolean showAlgs = isShowAlgorithmIndices();
        final boolean showUtils = isShowUtilities();
        final int columnCount = (showSims ? 1 : 0) + (showAlgs ? 1 : 0) + statCount + (showUtils ? 1 : 0);

        TextTable table = new TextTable(bodyRows + 1, columnCount);
        table.setTabDelimited(isTabDelimitedTables());

        // Leading index columns; entries are 1-based positions in the original lists.
        int firstStatColumn = 0;
        if (showSims) {
            table.setToken(0, firstStatColumn, "Sim");
            for (int t = 0; t < bodyRows; t++) {
                SimulationWrapper sim = algorithmSimulationWrappers.get(newOrder[t]).getSimulationWrapper();
                table.setToken(t + 1, firstStatColumn, String.valueOf(simulationWrappers.indexOf(sim) + 1));
            }
            firstStatColumn++;
        }
        if (showAlgs) {
            table.setToken(0, firstStatColumn, "Alg");
            for (int t = 0; t < bodyRows; t++) {
                AlgorithmWrapper alg = algorithmSimulationWrappers.get(newOrder[t]).getAlgorithmWrapper();
                table.setToken(t + 1, firstStatColumn, String.valueOf(algorithmWrappers.indexOf(alg) + 1));
            }
            firstStatColumn++;
        }

        // Header row: statistic abbreviations, then the optional utility column.
        for (int s = 0; s < statCount; s++) {
            table.setToken(0, firstStatColumn + s, statistics.getStatistics().get(s).getAbbreviation());
        }
        if (showUtils) {
            table.setToken(0, firstStatColumn + statCount, "U");
        }

        for (int t = 0; t < bodyRows; t++) {
            final int row = t + 1;
            final int source = newOrder[t];
            final AlgorithmSimulationWrapper pairing = algorithmSimulationWrappers.get(source);

            for (int s = 0; s < statCount; s++) {
                final int col = firstStatColumn + s;
                Statistic statistic = statistics.getStatistics().get(s);
                Algorithm algorithm = pairing.getAlgorithmWrapper().getAlgorithm();
                Simulation simulation = pairing.getSimulationWrapper().getSimulation();

                // Merge per-run parameter values into the shared Parameters object. This mutates
                // the caller's instance; it is presumably done so parameter-valued columns resolve below.
                if (algorithm instanceof HasParameterValues) {
                    parameters.putAll(((HasParameterValues) algorithm).getParameterValues());
                }
                if (simulation instanceof HasParameterValues) {
                    parameters.putAll(((HasParameterValues) simulation).getParameterValues());
                }

                Object[] values = parameters.getValues(statistic.getAbbreviation());
                if (values.length == 1 && values[0] instanceof String) {
                    table.setToken(row, col, (String) values[0]);
                    continue;
                }

                double stat = statTables[u][source][s];
                String cell;
                if (stat == 0.0) {
                    cell = "-";
                } else if (stat == Double.POSITIVE_INFINITY) {
                    cell = "Yes";
                } else if (stat == Double.NEGATIVE_INFINITY) {
                    cell = "No";
                } else if (Double.isNaN(stat)) {
                    cell = "*";
                } else {
                    // stat is known to be non-zero here (both 0.0 and -0.0 were handled above).
                    cell = Math.abs(stat) < smallCutoff ? smallNf.format(stat) : nf.format(stat);
                }
                table.setToken(row, col, cell);
            }

            if (showUtils) {
                table.setToken(row, firstStatColumn + statCount, nf.format(utilities[source]));
            }
        }

        out.println(getHeader(u));
        out.println();
        out.println(table);
    }
}
Also used : HasParameterValues(edu.cmu.tetrad.algcomparison.utils.HasParameterValues) DecimalFormat(java.text.DecimalFormat) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) TextTable(edu.cmu.tetrad.util.TextTable) NumberFormat(java.text.NumberFormat)

Example 5 with Simulation

use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil.

Method `compareFromSimulations` of class `Comparison`.

/**
 * Compares algorithms on data generated by the given simulations and writes a report.
 *
 * <p>The report is written to {@code outputFileName} inside the directory {@code resultsPath}.
 * The directory is created if needed. For every algorithm, each of its parameters that has more
 * than one value in {@code parameters} is expanded: one algorithm variant is created per
 * combination of values. Every variant is then paired with every simulation.
 *
 * <p>Statistics are printed as average, standard-deviation and worst-case tables. The report
 * stream is closed on every exit path, including when an exception is thrown.
 *
 * @param resultsPath    Path to the directory in which the report file is created.
 * @param simulations    The simulations used to generate graphs and data for the comparison.
 * @param outputFileName Name of the report file, relative to {@code resultsPath}.
 * @param algorithms     The list of algorithms to be compared.
 * @param statistics     The list of statistics on which to compare the algorithms, and their utility weights.
 * @param parameters     Parameter values. {@code numRuns} is read from here, and multi-valued
 *                       algorithm parameters are expanded as described above.
 * @throws RuntimeException wrapping the cause if the report file cannot be created.
 */
public void compareFromSimulations(String resultsPath, Simulations simulations, String outputFileName, Algorithms algorithms, Statistics statistics, Parameters parameters) {
    this.resultsPath = resultsPath;
    // Create output file.
    try {
        File dir = new File(resultsPath);
        dir.mkdirs();
        File file = new File(dir, outputFileName);
        this.out = new PrintStream(new FileOutputStream(file));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    // Previously the report was closed only on the happy path. An exception while simulating,
    // running algorithms or printing leaked the open file handle.
    try {
        out.println(new Date());
        // Set up simulations--create data and graphs, read in parameters. The parameters
        // are set in the parameters object.
        List<SimulationWrapper> simulationWrappers = new ArrayList<>();
        int numRuns = parameters.getInt("numRuns");
        for (Simulation simulation : simulations.getSimulations()) {
            List<SimulationWrapper> wrappers = getSimulationWrappers(simulation, parameters);
            for (SimulationWrapper wrapper : wrappers) {
                wrapper.createData(wrapper.getSimulationSpecificParameters());
                simulationWrappers.add(wrapper);
            }
        }
        // Set up the algorithms: one wrapper per combination of multi-valued algorithm parameters.
        List<AlgorithmWrapper> algorithmWrappers = new ArrayList<>();
        for (Algorithm algorithm : algorithms.getAlgorithms()) {
            List<Integer> _dims = new ArrayList<>();
            List<String> varyingParameters = new ArrayList<>();
            final List<String> parameters1 = algorithm.getParameters();
            for (String name : parameters1) {
                if (parameters.getNumValues(name) > 1) {
                    _dims.add(parameters.getNumValues(name));
                    varyingParameters.add(name);
                }
            }
            if (varyingParameters.isEmpty()) {
                algorithmWrappers.add(new AlgorithmWrapper(algorithm, parameters));
            } else {
                int[] dims = new int[_dims.size()];
                for (int i = 0; i < _dims.size(); i++) {
                    dims[i] = _dims.get(i);
                }
                CombinationGenerator gen = new CombinationGenerator(dims);
                int[] choice;
                while ((choice = gen.next()) != null) {
                    AlgorithmWrapper wrapper = new AlgorithmWrapper(algorithm, parameters);
                    for (int h = 0; h < dims.length; h++) {
                        String parameter = varyingParameters.get(h);
                        Object[] values = parameters.getValues(parameter);
                        Object value = values[choice[h]];
                        wrapper.setValue(parameter, value);
                    }
                    algorithmWrappers.add(wrapper);
                }
            }
        }
        // Create the algorithm-simulation wrappers for every combination of algorithm and
        // simulation (simulation-major order).
        List<AlgorithmSimulationWrapper> algorithmSimulationWrappers = new ArrayList<>();
        for (SimulationWrapper simulationWrapper : simulationWrappers) {
            for (AlgorithmWrapper algorithmWrapper : algorithmWrappers) {
                DataType algDataType = algorithmWrapper.getDataType();
                DataType simDataType = simulationWrapper.getDataType();
                if (!(algDataType == DataType.Mixed || (algDataType == simDataType))) {
                    System.out.println("Type mismatch: " + algorithmWrapper.getDescription() + " / " + simulationWrapper.getDescription());
                }
                if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
                    ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
                    // NOTE(review): external.setSimulation(...) is no longer called here, so
                    // getSimulation() returns whatever was set earlier. Looking a Simulation up in
                    // a list of SimulationWrappers may also yield -1. indexOf(simulationWrapper) may
                    // be what was intended; confirm before changing. The same ExternalAlgorithm
                    // instance is shared across simulations, so the last assignment wins.
                    external.setSimIndex(simulationWrappers.indexOf(external.getSimulation()));
                }
                algorithmSimulationWrappers.add(new AlgorithmSimulationWrapper(algorithmWrapper, simulationWrapper));
            }
        }
        // Run all of the algorithms and compile statistics.
        double[][][][] allStats = calcStats(algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, statistics, numRuns);
        // Print out the preliminary information for statistics types, etc.
        if (allStats != null) {
            out.println();
            out.println("Statistics:");
            out.println();
            for (Statistic stat : statistics.getStatistics()) {
                out.println(stat.getAbbreviation() + " = " + stat.getDescription());
            }
        }
        out.println();
        if (allStats != null) {
            int numTables = allStats.length;
            int numStats = allStats[0][0].length - 1;
            double[][][] statTables = calcStatTables(allStats, Mode.Average, numTables, algorithmSimulationWrappers, numStats, statistics);
            double[] utilities = calcUtilities(statistics, algorithmSimulationWrappers, statTables[0]);
            // Add utilities to table as the last column. (An identical second pass over the
            // unchanged average tables was removed as redundant.)
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }
            int[] newOrder;
            if (isSortByUtility()) {
                newOrder = sort(algorithmSimulationWrappers, utilities);
            } else {
                newOrder = new int[algorithmSimulationWrappers.size()];
                for (int q = 0; q < algorithmSimulationWrappers.size(); q++) {
                    newOrder[q] = q;
                }
            }
            out.println("Simulations:");
            out.println();
            int i = 0;
            for (SimulationWrapper simulation : simulationWrappers) {
                out.print("Simulation " + (++i) + ": ");
                out.println(simulation.getDescription());
                out.println();
                printParameters(simulation.getParameters(), simulation.getSimulationSpecificParameters(), out);
                out.println();
            }
            out.println("Algorithms:");
            out.println();
            // Pairings are simulation-major, so those belonging to the first simulation list each
            // algorithm variant exactly once, and t + 1 is its 1-based index.
            for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                AlgorithmSimulationWrapper wrapper = algorithmSimulationWrappers.get(t);
                if (wrapper.getSimulationWrapper() == simulationWrappers.get(0)) {
                    out.println((t + 1) + ". " + wrapper.getAlgorithmWrapper().getDescription());
                }
            }
            if (isSortByUtility()) {
                out.println();
                out.println("Sorting by utility, high to low.");
            }
            if (isShowUtilities()) {
                out.println();
                out.println("Weighting of statistics:");
                out.println();
                out.println("U = ");
                for (Statistic stat : statistics.getStatistics()) {
                    String statName = stat.getAbbreviation();
                    double weight = statistics.getWeight(stat);
                    if (weight != 0.0) {
                        out.println("    " + weight + " * f(" + statName + ")");
                    }
                }
                out.println();
                out.println("...normed to range between 0 and 1.");
                out.println();
                out.println("Note that f for each statistic is a function that maps the statistic to the ");
                out.println("interval [0, 1], with higher being better.");
            }
            out.println();
            out.println("Graphs are being compared to the " + comparisonGraph.toString().replace("_", " ") + ".");
            out.println();
            // Print all of the tables.
            printStats(statTables, statistics, Mode.Average, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
            statTables = calcStatTables(allStats, Mode.StandardDeviation, numTables, algorithmSimulationWrappers, numStats, statistics);
            printStats(statTables, statistics, Mode.StandardDeviation, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
            statTables = calcStatTables(allStats, Mode.WorstCase, numTables, algorithmSimulationWrappers, numStats, statistics);
            // Add utilities to table as the last column.
            for (int u = 0; u < numTables; u++) {
                for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                    statTables[u][t][numStats] = utilities[t];
                }
            }
            printStats(statTables, statistics, Mode.WorstCase, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
        }
    } finally {
        out.close();
    }
}
Also used: ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm), Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic), Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm), MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm), Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation)

Aggregations

Simulation (edu.cmu.tetrad.algcomparison.simulation.Simulation): 12, Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm): 9, ExternalAlgorithm (edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm): 8, MultiDataSetAlgorithm (edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm): 8, Statistic (edu.cmu.tetrad.algcomparison.statistic.Statistic): 8, Parameters (edu.cmu.tetrad.util.Parameters): 5, RandomForward (edu.cmu.tetrad.algcomparison.graph.RandomForward): 4, SemSimulation (edu.cmu.tetrad.algcomparison.simulation.SemSimulation): 4, HasParameters (edu.cmu.tetrad.algcomparison.utils.HasParameters): 4, Comparison (edu.cmu.tetrad.algcomparison.Comparison): 3, IndependenceWrapper (edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper): 3, ScoreWrapper (edu.cmu.tetrad.algcomparison.score.ScoreWrapper): 3, FileNotFoundException (java.io.FileNotFoundException): 3, IOException (java.io.IOException): 3, DecimalFormat (java.text.DecimalFormat): 3, ArrayList (java.util.ArrayList): 3, ExecutionException (java.util.concurrent.ExecutionException): 3, TimeoutException (java.util.concurrent.TimeoutException): 3, ElapsedTime (edu.cmu.tetrad.algcomparison.statistic.ElapsedTime): 2, ParameterColumn (edu.cmu.tetrad.algcomparison.statistic.ParameterColumn): 2