Search in sources:

Example 6 with Simulation

Use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil.

In class TimeoutComparison, method compareFromSimulations.

/**
 * Compares algorithms on graphs and data generated by the given simulations and
 * writes a report to {@code outputFileName} inside {@code resultsPath}.
 *
 * <p>Each simulation is expanded into one or more simulation wrappers, and each
 * algorithm into one wrapper per combination of values of its parameters that
 * have more than one value. Every (simulation wrapper, algorithm wrapper) pair
 * is run, and average, standard-deviation and worst-case statistic tables are
 * printed. The report stream is closed when the method returns, including when
 * it exits with an exception after the file has been created.
 *
 * @param resultsPath    Path to the directory where the output should be printed;
 *                       it is created if missing.
 * @param simulations    The simulations used to generate graphs and data for the
 *                       comparison.
 * @param outputFileName Name of the report file created inside {@code resultsPath}.
 * @param algorithms     The list of algorithms to be compared.
 * @param statistics     The list of statistics on which to compare the
 *                       algorithms, and their utility weights.
 * @param parameters     Parameter values; {@code numRuns} is read from here, and
 *                       parameters with more than one value are varied
 *                       combinatorially.
 * @param timeout        Timeout passed through to {@code calcStats}, expressed in
 *                       {@code unit}.
 * @param unit           Time unit of {@code timeout}.
 * @throws RuntimeException if the output file cannot be created.
 */
public void compareFromSimulations(String resultsPath, Simulations simulations, String outputFileName, Algorithms algorithms, Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
    this.resultsPath = resultsPath;
    // Create output file.
    try {
        File dir = new File(resultsPath);
        dir.mkdirs();
        File file = new File(dir, outputFileName);
        this.out = new PrintStream(new FileOutputStream(file));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    // Close the report in a finally clause so the stream is not leaked if a
    // simulation, algorithm or statistic throws part-way through.
    try {
        out.println(new Date());
        // Set up simulations--create data and graphs, read in parameters. The parameters
        // are set in the parameters object.
        List<SimulationWrapper> simulationWrappers = new ArrayList<>();
        int numRuns = parameters.getInt("numRuns");
        for (Simulation simulation : simulations.getSimulations()) {
            List<SimulationWrapper> wrappers = getSimulationWrappers(simulation, parameters);
            for (SimulationWrapper wrapper : wrappers) {
                wrapper.createData(wrapper.getSimulationSpecificParameters());
                simulationWrappers.add(wrapper);
            }
        }
        // Set up the algorithms: an algorithm whose parameters have several values gets
        // one wrapper per combination of those values.
        List<AlgorithmWrapper> algorithmWrappers = new ArrayList<>();
        for (Algorithm algorithm : algorithms.getAlgorithms()) {
            List<Integer> _dims = new ArrayList<>();
            List<String> varyingParameters = new ArrayList<>();
            final List<String> parameters1 = algorithm.getParameters();
            for (String name : parameters1) {
                if (parameters.getNumValues(name) > 1) {
                    _dims.add(parameters.getNumValues(name));
                    varyingParameters.add(name);
                }
            }
            if (varyingParameters.isEmpty()) {
                algorithmWrappers.add(new AlgorithmWrapper(algorithm, parameters));
            } else {
                int[] dims = new int[_dims.size()];
                for (int i = 0; i < _dims.size(); i++) {
                    dims[i] = _dims.get(i);
                }
                CombinationGenerator gen = new CombinationGenerator(dims);
                int[] choice;
                while ((choice = gen.next()) != null) {
                    AlgorithmWrapper wrapper = new AlgorithmWrapper(algorithm, parameters);
                    for (int h = 0; h < dims.length; h++) {
                        String parameter = varyingParameters.get(h);
                        Object[] values = parameters.getValues(parameter);
                        wrapper.setValue(parameter, values[choice[h]]);
                    }
                    algorithmWrappers.add(wrapper);
                }
            }
        }
        // Create the algorithm-simulation wrappers for every combination of algorithm and
        // simulation, simulation-major (all algorithms for simulation 1 first, etc.).
        List<AlgorithmSimulationWrapper> algorithmSimulationWrappers = new ArrayList<>();
        for (SimulationWrapper simulationWrapper : simulationWrappers) {
            for (AlgorithmWrapper algorithmWrapper : algorithmWrappers) {
                DataType algDataType = algorithmWrapper.getDataType();
                DataType simDataType = simulationWrapper.getDataType();
                if (!(algDataType == DataType.Mixed || (algDataType == simDataType))) {
                    System.out.println("Type mismatch: " + algorithmWrapper.getDescription() + " / " + simulationWrapper.getDescription());
                }
                if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
                    ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
                    // Index of the current simulation wrapper in the list, the same value
                    // doRun assigns before running the algorithm.
                    external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
                }
                algorithmSimulationWrappers.add(new AlgorithmSimulationWrapper(algorithmWrapper, simulationWrapper));
            }
        }
        // Run all of the algorithms and compile statistics.
        double[][][][] allStats = calcStats(algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, statistics, numRuns, timeout, unit);
        if (allStats == null) {
            // No statistics were compiled; the report holds only the date.
            out.println();
            return;
        }
        // Print out the preliminary information for statistics types, etc.
        out.println();
        out.println("Statistics:");
        out.println();
        for (Statistic stat : statistics.getStatistics()) {
            out.println(stat.getAbbreviation() + " = " + stat.getDescription());
        }
        out.println();
        int numTables = allStats.length;
        // The last slot of each row is reserved for the utility column.
        int numStats = allStats[0][0].length - 1;
        double[][][] statTables = calcStatTables(allStats, Mode.Average, numTables, algorithmSimulationWrappers, numStats, statistics);
        double[] utilities = calcUtilities(statistics, algorithmSimulationWrappers, statTables[0]);
        int[] newOrder;
        if (isSortByUtility()) {
            newOrder = sort(algorithmSimulationWrappers, utilities);
        } else {
            newOrder = new int[algorithmSimulationWrappers.size()];
            for (int q = 0; q < algorithmSimulationWrappers.size(); q++) {
                newOrder[q] = q;
            }
        }
        out.println("Simulations:");
        out.println();
        int i = 0;
        for (SimulationWrapper simulation : simulationWrappers) {
            out.print("Simulation " + (++i) + ": ");
            out.println(simulation.getDescription());
            out.println();
            printParameters(simulation.getParameters(), simulation.getSimulationSpecificParameters(), out);
            out.println();
        }
        out.println("Algorithms:");
        out.println();
        // Wrappers are simulation-major, so those for the first simulation list each
        // algorithm exactly once.
        for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
            AlgorithmSimulationWrapper wrapper = algorithmSimulationWrappers.get(t);
            if (wrapper.getSimulationWrapper() == simulationWrappers.get(0)) {
                out.println((t + 1) + ". " + wrapper.getAlgorithmWrapper().getDescription());
            }
        }
        if (isSortByUtility()) {
            out.println();
            out.println("Sorting by utility, high to low.");
        }
        if (isShowUtilities()) {
            out.println();
            out.println("Weighting of statistics:");
            out.println();
            out.println("U = ");
            for (Statistic stat : statistics.getStatistics()) {
                String statName = stat.getAbbreviation();
                double weight = statistics.getWeight(stat);
                if (weight != 0.0) {
                    out.println("    " + weight + " * f(" + statName + ")");
                }
            }
            out.println();
            out.println("...normed to range between 0 and 1.");
            out.println();
            out.println("Note that f for each statistic is a function that maps the statistic to the ");
            out.println("interval [0, 1], with higher being better.");
        }
        out.println();
        out.println("Graphs are being compared to the " + comparisonGraph.toString().replace("_", " ") + ".");
        out.println();
        // Add utilities to the average table as the last column.
        for (int u = 0; u < numTables; u++) {
            for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                statTables[u][t][numStats] = utilities[t];
            }
        }
        // Print all of the tables.
        printStats(statTables, statistics, Mode.Average, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
        statTables = calcStatTables(allStats, Mode.StandardDeviation, numTables, algorithmSimulationWrappers, numStats, statistics);
        printStats(statTables, statistics, Mode.StandardDeviation, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
        statTables = calcStatTables(allStats, Mode.WorstCase, numTables, algorithmSimulationWrappers, numStats, statistics);
        // Add utilities to the worst-case table as the last column.
        for (int u = 0; u < numTables; u++) {
            for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
                statTables[u][t][numStats] = utilities[t];
            }
        }
        printStats(statTables, statistics, Mode.WorstCase, newOrder, algorithmSimulationWrappers, algorithmWrappers, simulationWrappers, utilities, parameters);
    } finally {
        out.close();
    }
}
Also used: ArrayList(java.util.ArrayList) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) DataType(edu.cmu.tetrad.data.DataType) PrintStream(java.io.PrintStream) CombinationGenerator(edu.cmu.tetrad.util.CombinationGenerator) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) TimeoutException(java.util.concurrent.TimeoutException) FileNotFoundException(java.io.FileNotFoundException) IOException(java.io.IOException) ExecutionException(java.util.concurrent.ExecutionException) Date(java.util.Date) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation) FileOutputStream(java.io.FileOutputStream) File(java.io.File)

Example 7 with Simulation

Use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil.

In class TimeoutComparison, method doRun.

/**
 * Executes one run of one algorithm/simulation combination and records the
 * resulting statistics in {@code allStats}.
 *
 * <p>The estimated graph is saved with {@code saveGraph}. If the search throws,
 * the failure is printed and nothing is recorded for this run. If the simulation
 * supplies no true graph for this run, there is nothing to compare against, so no
 * statistics are recorded either.
 *
 * @param algorithmSimulationWrappers all algorithm/simulation combinations; the one at
 *                                    {@code run.getAlgSimIndex()} is executed
 * @param algorithmWrappers           all algorithm wrappers; used for the 1-based
 *                                    algorithm index passed to {@code saveGraph}
 * @param simulationWrappers          all simulation wrappers; used for the simulation
 *                                    index given to external algorithms and {@code saveGraph}
 * @param statistics                  the statistics to compute
 * @param numGraphTypes               size of the first dimension of {@code allStats};
 *                                    must be at least 4 when the data are mixed
 * @param allStats                    results, written as
 *                                    {@code allStats[graphType][algSimIndex][statIndex][runIndex]}
 * @param run                         the algorithm/simulation index and run index to execute
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();
    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());
    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());
    long start = System.currentTimeMillis();
    Graph out;
    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();
        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }
        if (algorithm instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithm;
            external.setSimulation(simulation);
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }
        Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Search over a random subset of at most randomSelectionSize of the
            // simulation's data models.
            int numDataModels = simulation.getNumDataModels();
            List<Integer> indices = new ArrayList<>();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);
            int selectionSize = Math.min(numDataModels, _params.getInt("randomSelectionSize"));
            List<DataModel> dataModels = new ArrayList<>();
            for (int i = 0; i < selectionSize; i++) {
                dataModels.add(simulation.getDataModel(indices.get(i)));
            }
            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }
    // Stop the clock as soon as the search returns so that bookkeeping is not timed.
    long elapsed = System.currentTimeMillis() - start;
    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;
    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);
    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        // The external algorithm's reported time replaces the wall-clock measurement.
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }
    // The graph the estimate is scored against, as selected by the comparisonGraph field.
    Graph referenceGraph;
    if (trueGraph == null) {
        // No true graph for this run: nothing to compare against, so no statistics are recorded.
        referenceGraph = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        referenceGraph = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        referenceGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
        referenceGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    } else {
        throw new IllegalArgumentException("Unrecognized graph type.");
    }
    // Index 0 is the whole graph; for mixed data, indices 1-3 are the subgraphs
    // produced by getSubgraph for the three flag combinations used below.
    Graph[] est = new Graph[numGraphTypes];
    est[0] = out;
    graphTypeUsed[0] = true;
    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);
        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }
    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = referenceGraph;
    if (data.isMixed() && referenceGraph != null) {
        truth[1] = getSubgraph(referenceGraph, true, true, data);
        truth[2] = getSubgraph(referenceGraph, true, false, data);
        truth[3] = getSubgraph(referenceGraph, false, false, data);
    }
    if (referenceGraph == null) {
        return;
    }
    for (int u = 0; u < numGraphTypes; u++) {
        // graphTypeUsed is not reset by this method, so a subgraph type may already be
        // flagged by an earlier mixed-data run even though this run built no such graph.
        if (!graphTypeUsed[u] || est[u] == null || truth[u] == null) {
            continue;
        }
        int statIndex = -1;
        for (Statistic _stat : statistics.getStatistics()) {
            statIndex++;
            if (_stat instanceof ParameterColumn) {
                continue;
            }
            double stat;
            if (_stat instanceof ElapsedTime) {
                stat = elapsed / 1000.0;
            } else {
                stat = _stat.getValue(truth[u], est[u]);
            }
            allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
        }
    }
}
Also used: ArrayList(java.util.ArrayList) ElapsedTime(edu.cmu.tetrad.algcomparison.statistic.ElapsedTime) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) ParameterColumn(edu.cmu.tetrad.algcomparison.statistic.ParameterColumn) Parameters(edu.cmu.tetrad.util.Parameters) HasParameters(edu.cmu.tetrad.algcomparison.utils.HasParameters) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) TimeoutException(java.util.concurrent.TimeoutException) FileNotFoundException(java.io.FileNotFoundException) IOException(java.io.IOException) ExecutionException(java.util.concurrent.ExecutionException) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Graph(edu.cmu.tetrad.graph.Graph) DagToPag(edu.cmu.tetrad.search.DagToPag) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation) DataModel(edu.cmu.tetrad.data.DataModel)

Example 8 with Simulation

Use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil.

In class TimeoutComparison, method configuration.

/**
 * Writes a catalog of the available components to {@code Configuration.txt} in
 * {@code path}. The catalog covers algorithms (grouped by whether they take an
 * independence test, a score, or no constructor argument), statistics,
 * independence tests, scores and simulations. Each entry shows its description
 * and, for classes implementing {@code HasParameters}, its parameters.
 *
 * <p>Classes implementing {@code Experimental} are skipped. Algorithms whose
 * constructor takes a single independence test are instantiated with an example
 * {@code FisherZ}, and those taking a single score with an example
 * {@code BdeuScore}. Everything else is instantiated through its public
 * no-argument constructor. Any exception is printed to standard error and ends
 * the listing early; the output file is closed in every case.
 *
 * @param path directory in which to write {@code Configuration.txt}; created if
 *             it does not exist
 */
public void configuration(String path) {
    try {
        new File(path).mkdirs();
        // try-with-resources closes the file even if instantiating a component below throws.
        try (PrintStream out = new PrintStream(new FileOutputStream(new File(path, "Configuration.txt")))) {
            Parameters allParams = new Parameters();
            List<Class> algorithms = new ArrayList<>();
            List<Class> statistics = new ArrayList<>();
            List<Class> independenceWrappers = new ArrayList<>();
            List<Class> scoreWrappers = new ArrayList<>();
            List<Class> simulations = new ArrayList<>();
            algorithms.addAll(getClasses(Algorithm.class));
            statistics.addAll(getClasses(Statistic.class));
            independenceWrappers.addAll(getClasses(IndependenceWrapper.class));
            scoreWrappers.addAll(getClasses(ScoreWrapper.class));
            simulations.addAll(getClasses(Simulation.class));
            out.println("Available Algorithms:");
            out.println();
            out.println("Algorithms that take an independence test (using an example independence test):");
            out.println();
            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 1 && constructor.getParameterTypes()[0] == IndependenceWrapper.class) {
                        // FisherZ is only an example test, needed to build the algorithm.
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new FisherZ());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                        if (TakesInitialGraph.class.isAssignableFrom(clazz)) {
                            out.println("\t" + clazz.getSimpleName() + " can take an initial graph from some other algorithm as input");
                        }
                    }
                }
            }
            out.println();
            out.println("Algorithms that take a score (using an example score):");
            out.println();
            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 1 && constructor.getParameterTypes()[0] == ScoreWrapper.class) {
                        // BdeuScore is only an example score, needed to build the algorithm.
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new BdeuScore());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }
            out.println();
            out.println("Algorithms with blank constructor:");
            out.println();
            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Algorithm algorithm = (Algorithm) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }
            out.println();
            out.println("Available Statistics:");
            out.println();
            for (Class clazz : statistics) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Statistic statistic = (Statistic) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + statistic.getDescription());
                    }
                }
            }
            out.println();
            out.println("Available Independence Tests:");
            out.println();
            for (Class clazz : independenceWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        IndependenceWrapper independence = (IndependenceWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + independence.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(independence.getParameters(), allParams, out);
                        }
                    }
                }
            }
            out.println();
            out.println("Available Scores:");
            out.println();
            for (Class clazz : scoreWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        ScoreWrapper score = (ScoreWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + score.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(score.getParameters(), allParams, out);
                        }
                    }
                }
            }
            out.println();
            out.println("Available Simulations:");
            out.println();
            for (Class clazz : simulations) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Simulation simulation = (Simulation) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + simulation.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(simulation.getParameters(), allParams, out);
                        }
                    }
                }
            }
            out.println();
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}
Also used: PrintStream(java.io.PrintStream) Parameters(edu.cmu.tetrad.util.Parameters) HasParameters(edu.cmu.tetrad.algcomparison.utils.HasParameters) ScoreWrapper(edu.cmu.tetrad.algcomparison.score.ScoreWrapper) Constructor(java.lang.reflect.Constructor) ArrayList(java.util.ArrayList) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) TimeoutException(java.util.concurrent.TimeoutException) FileNotFoundException(java.io.FileNotFoundException) IOException(java.io.IOException) ExecutionException(java.util.concurrent.ExecutionException) IndependenceWrapper(edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation) FileOutputStream(java.io.FileOutputStream) File(java.io.File)

Example 9 with Simulation

Use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil.

In class Comparison, method configuration.

/**
 * Writes a catalog of the available components to {@code Configuration.txt} in
 * {@code path}. The catalog covers algorithms (grouped by whether they take an
 * independence test, a score, or no constructor argument), statistics,
 * independence tests, scores and simulations. Each entry shows its description
 * and, for classes implementing {@code HasParameters}, its parameters.
 *
 * <p>Classes implementing {@code Experimental} are skipped. Algorithms whose
 * constructor takes a single independence test are instantiated with an example
 * {@code FisherZ}, and those taking a single score with an example
 * {@code BdeuScore}. Everything else is instantiated through its public
 * no-argument constructor. Any exception is printed to standard error and ends
 * the listing early; the output file is closed in every case.
 *
 * @param path directory in which to write {@code Configuration.txt}; created if
 *             it does not exist
 */
public void configuration(String path) {
    try {
        new File(path).mkdirs();
        // try-with-resources closes the file even if instantiating a component below throws.
        try (PrintStream out = new PrintStream(new FileOutputStream(new File(path, "Configuration.txt")))) {
            Parameters allParams = new Parameters();
            List<Class> algorithms = new ArrayList<>();
            List<Class> statistics = new ArrayList<>();
            List<Class> independenceWrappers = new ArrayList<>();
            List<Class> scoreWrappers = new ArrayList<>();
            List<Class> simulations = new ArrayList<>();
            algorithms.addAll(getClasses(Algorithm.class));
            statistics.addAll(getClasses(Statistic.class));
            independenceWrappers.addAll(getClasses(IndependenceWrapper.class));
            scoreWrappers.addAll(getClasses(ScoreWrapper.class));
            simulations.addAll(getClasses(Simulation.class));
            out.println("Available Algorithms:");
            out.println();
            out.println("Algorithms that take an independence test (using an example independence test):");
            out.println();
            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 1 && constructor.getParameterTypes()[0] == IndependenceWrapper.class) {
                        // FisherZ is only an example test, needed to build the algorithm.
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new FisherZ());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                        if (TakesInitialGraph.class.isAssignableFrom(clazz)) {
                            out.println("\t" + clazz.getSimpleName() + " can take an initial graph from some other algorithm as input");
                        }
                    }
                }
            }
            out.println();
            out.println("Algorithms that take a score (using an example score):");
            out.println();
            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 1 && constructor.getParameterTypes()[0] == ScoreWrapper.class) {
                        // BdeuScore is only an example score, needed to build the algorithm.
                        Algorithm algorithm = (Algorithm) constructor.newInstance(new BdeuScore());
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }
            out.println();
            out.println("Algorithms with blank constructor:");
            out.println();
            for (Class clazz : algorithms) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Algorithm algorithm = (Algorithm) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + algorithm.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(algorithm.getParameters(), allParams, out);
                        }
                    }
                }
            }
            out.println();
            out.println("Available Statistics:");
            out.println();
            for (Class clazz : statistics) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Statistic statistic = (Statistic) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + statistic.getDescription());
                    }
                }
            }
            out.println();
            out.println("Available Independence Tests:");
            out.println();
            for (Class clazz : independenceWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        IndependenceWrapper independence = (IndependenceWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + independence.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(independence.getParameters(), allParams, out);
                        }
                    }
                }
            }
            out.println();
            out.println("Available Scores:");
            out.println();
            for (Class clazz : scoreWrappers) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        ScoreWrapper score = (ScoreWrapper) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + score.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(score.getParameters(), allParams, out);
                        }
                    }
                }
            }
            out.println();
            out.println("Available Simulations:");
            out.println();
            for (Class clazz : simulations) {
                if (Experimental.class.isAssignableFrom(clazz)) {
                    continue;
                }
                Constructor[] constructors = clazz.getConstructors();
                for (Constructor constructor : constructors) {
                    if (constructor.getParameterTypes().length == 0) {
                        Simulation simulation = (Simulation) constructor.newInstance();
                        out.println(clazz.getSimpleName() + ": " + simulation.getDescription());
                        if (HasParameters.class.isAssignableFrom(clazz)) {
                            printParameters(simulation.getParameters(), allParams, out);
                        }
                    }
                }
            }
            out.println();
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}
Also used: HasParameters(edu.cmu.tetrad.algcomparison.utils.HasParameters) ScoreWrapper(edu.cmu.tetrad.algcomparison.score.ScoreWrapper) Constructor(java.lang.reflect.Constructor) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) IndependenceWrapper(edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation)

Example 10 with Simulation

use of edu.cmu.tetrad.algcomparison.simulation.Simulation in project tetrad by cmu-phil.

the class Comparison method doRun.

/**
 * Runs one algorithm on one simulated dataset and records the resulting statistics.
 *
 * <p>{@code run.getAlgSimIndex()} selects the algorithm/simulation pairing from
 * {@code algorithmSimulationWrappers}, and {@code run.getRunIndex()} selects which dataset
 * and true graph of that simulation are used. The estimated graph is always passed to
 * {@code saveGraph}. If a true graph is available, every statistic is evaluated against the
 * configured comparison graph. The result is stored in
 * {@code allStats[graphType][algSimIndex][statIndex][runIndex]}.
 *
 * <p>If the search throws, the error is printed and the method returns without saving a
 * graph or recording statistics for this run. If the simulation supplies no true graph,
 * the estimated graph is still saved, but no statistics are recorded.
 *
 * @param algorithmSimulationWrappers all algorithm/simulation pairings of this comparison.
 * @param algorithmWrappers the algorithm wrappers; used to derive the 1-based algorithm
 *     index passed to {@code saveGraph}.
 * @param simulationWrappers the simulation wrappers; used to derive the simulation index
 *     passed to {@code saveGraph} and to external algorithms.
 * @param statistics the statistics to compute for this run.
 * @param numGraphTypes the number of graph types, i.e. the length of the estimated/true
 *     graph arrays built below.
 * @param allStats the output array that this method fills in for this run.
 * @param run identifies which algorithm/simulation pairing and which run index to execute.
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();
    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    // May be null when the simulation provides no ground truth; handled below.
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());
    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());
    long start = System.currentTimeMillis();
    Graph out;
    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();
        // Pass the simulation's background knowledge through when both sides support it.
        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }
        if (algorithm instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithm;
            external.setSimulation(simulation);
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }
        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Give the algorithm a random subset of at most "randomSelectionSize" data models.
            List<Integer> indices = new ArrayList<>();
            int numDataModels = simulation.getNumDataModels();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);
            List<DataModel> dataModels = new ArrayList<>();
            int randomSelectionSize = algorithmWrapper.getAlgorithmSpecificParameters().getInt("randomSelectionSize");
            for (int i = 0; i < Math.min(numDataModels, randomSelectionSize); i++) {
                dataModels.add(simulation.getDataModel(indices.get(i)));
            }
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }
    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;
    long stop = System.currentTimeMillis();
    long elapsed = stop - start;
    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);
    if (trueGraph != null) {
        // Re-map the estimate onto the true graph's node objects so the two can be compared.
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        // External algorithms report their own running time, which replaces the wall-clock
        // time measured above for the statistics (units as returned by getElapsedTime --
        // NOTE(review): the ElapsedTime statistic below divides by 1000, so milliseconds
        // appear to be assumed; confirm).
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }
    Graph[] est = new Graph[numGraphTypes];
    Graph comparison;
    if (trueGraph == null) {
        // Without a true graph there is nothing to compare against. Leave the comparison
        // graph null so the guards below skip the statistics, instead of feeding null
        // into the graph conversions.
        comparison = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        comparison = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        comparison = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
        comparison = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    } else {
        throw new IllegalArgumentException("Unrecognized graph type.");
    }
    est[0] = out;
    graphTypeUsed[0] = true;
    // For mixed data, slots 1..3 hold subgraphs restricted by variable type (see getSubgraph).
    // NOTE(review): this assumes numGraphTypes >= 4 whenever the data are mixed.
    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);
        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }
    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = comparison;
    if (data.isMixed() && comparison != null) {
        truth[1] = getSubgraph(comparison, true, true, data);
        truth[2] = getSubgraph(comparison, true, false, data);
        truth[3] = getSubgraph(comparison, false, false, data);
    }
    if (comparison != null) {
        for (int u = 0; u < numGraphTypes; u++) {
            if (!graphTypeUsed[u]) {
                continue;
            }
            int statIndex = -1;
            for (Statistic _stat : statistics.getStatistics()) {
                // Increment before any 'continue' so statIndex stays aligned with the
                // position of _stat in the statistics list.
                statIndex++;
                if (_stat instanceof ParameterColumn) {
                    continue;
                }
                double stat;
                if (_stat instanceof ElapsedTime) {
                    stat = elapsed / 1000.0;
                } else {
                    stat = _stat.getValue(truth[u], est[u]);
                }
                allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
            }
        }
    }
}
Also used : ElapsedTime(edu.cmu.tetrad.algcomparison.statistic.ElapsedTime) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) ParameterColumn(edu.cmu.tetrad.algcomparison.statistic.ParameterColumn) HasParameters(edu.cmu.tetrad.algcomparison.utils.HasParameters) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) DagToPag(edu.cmu.tetrad.search.DagToPag) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation)

Aggregations

Simulation (edu.cmu.tetrad.algcomparison.simulation.Simulation)12 Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm)9 ExternalAlgorithm (edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm)8 MultiDataSetAlgorithm (edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm)8 Statistic (edu.cmu.tetrad.algcomparison.statistic.Statistic)8 Parameters (edu.cmu.tetrad.util.Parameters)5 RandomForward (edu.cmu.tetrad.algcomparison.graph.RandomForward)4 SemSimulation (edu.cmu.tetrad.algcomparison.simulation.SemSimulation)4 HasParameters (edu.cmu.tetrad.algcomparison.utils.HasParameters)4 Comparison (edu.cmu.tetrad.algcomparison.Comparison)3 IndependenceWrapper (edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper)3 ScoreWrapper (edu.cmu.tetrad.algcomparison.score.ScoreWrapper)3 FileNotFoundException (java.io.FileNotFoundException)3 IOException (java.io.IOException)3 DecimalFormat (java.text.DecimalFormat)3 ArrayList (java.util.ArrayList)3 ExecutionException (java.util.concurrent.ExecutionException)3 TimeoutException (java.util.concurrent.TimeoutException)3 ElapsedTime (edu.cmu.tetrad.algcomparison.statistic.ElapsedTime)2 ParameterColumn (edu.cmu.tetrad.algcomparison.statistic.ParameterColumn)2