Search in sources:

Example 1 with ParameterColumn

Use of edu.cmu.tetrad.algcomparison.statistic.ParameterColumn in project tetrad by cmu-phil.

In class TimeoutComparison, method calcStatTables:

/**
 * Collapses the per-run statistics into summary tables, producing one value per
 * (table, algorithm/simulation pairing, statistic) cell.
 *
 * <p>For a {@link ParameterColumn} statistic, the cell holds the value of the parameter whose name
 * equals the column's abbreviation. The parameter is looked up first among the simulation's
 * parameters and then, if that still leaves NaN, among the algorithm's parameters. Boolean
 * parameters are encoded as {@code +Infinity} (true) or {@code -Infinity} (false). A parameter
 * found in neither place leaves the cell at NaN.
 *
 * <p>Every other statistic is summarized across its per-run values according to {@code mode}:
 * the mean for Average, the minimum for WorstCase, or the standard deviation for
 * StandardDeviation.
 *
 * @param allStats   raw statistics indexed [table][pairing][statistic][run]
 * @param mode       how to summarize non-parameter statistics across runs
 * @param numTables  number of tables to produce
 * @param wrappers   the algorithm/simulation pairings; one row per wrapper
 * @param numStats   number of statistic columns to fill
 * @param statistics the statistics, positionally matching the columns
 * @return tables sized [numTables][wrappers.size()][numStats + 1]; this method never writes the
 *         final column, so it stays 0.0
 * @throws IllegalStateException if a non-parameter statistic is reached and {@code mode} is none
 *         of Average, WorstCase or StandardDeviation
 */
private double[][][] calcStatTables(double[][][][] allStats, Mode mode, int numTables, List<AlgorithmSimulationWrapper> wrappers, int numStats, Statistics statistics) {
    final int numRows = wrappers.size();
    double[][][] tables = new double[numTables][numRows][numStats + 1];

    for (int table = 0; table < numTables; table++) {
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < numStats; col++) {
                if (!(statistics.getStatistics().get(col) instanceof ParameterColumn)) {
                    // Ordinary statistic: summarize its per-run values.
                    if (mode == Mode.Average) {
                        tables[table][row][col] = StatUtils.mean(allStats[table][row][col]);
                    } else if (mode == Mode.WorstCase) {
                        tables[table][row][col] = StatUtils.min(allStats[table][row][col]);
                    } else if (mode == Mode.StandardDeviation) {
                        tables[table][row][col] = StatUtils.sd(allStats[table][row][col]);
                    } else {
                        throw new IllegalStateException();
                    }
                    continue;
                }

                // Parameter column: report the configured value of the named parameter.
                String statName = statistics.getStatistics().get(col).getAbbreviation();
                SimulationWrapper simulationWrapper = wrappers.get(row).getSimulationWrapper();
                AlgorithmWrapper algorithmWrapper = wrappers.get(row).getAlgorithmWrapper();

                double value = Double.NaN;
                List<String> simNames = simulationWrapper.getParameters();
                Parameters simParams = simulationWrapper.getSimulationSpecificParameters();
                for (String candidate : simNames) {
                    if (!candidate.equals(statName)) {
                        continue;
                    }
                    if (simParams.get(candidate) instanceof Boolean) {
                        value = simParams.getBoolean(candidate) ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
                    } else {
                        value = simParams.getDouble(candidate);
                    }
                    break;
                }

                // Fall back to the algorithm's parameters when the simulation lookup left NaN.
                if (Double.isNaN(value)) {
                    List<String> algNames = algorithmWrapper.getParameters();
                    Parameters algParams = algorithmWrapper.parameters;
                    for (String candidate : algNames) {
                        if (!candidate.equals(statName)) {
                            continue;
                        }
                        try {
                            value = algParams.getDouble(candidate);
                        } catch (Exception e) {
                            // getDouble failed; read the parameter as a boolean instead.
                            value = algParams.getBoolean(candidate) ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
                        }
                        break;
                    }
                }

                tables[table][row][col] = value;
            }
        }
    }
    return tables;
}
Also used : Parameters(edu.cmu.tetrad.util.Parameters) HasParameters(edu.cmu.tetrad.algcomparison.utils.HasParameters) TimeoutException(java.util.concurrent.TimeoutException) FileNotFoundException(java.io.FileNotFoundException) IOException(java.io.IOException) ExecutionException(java.util.concurrent.ExecutionException) ParameterColumn(edu.cmu.tetrad.algcomparison.statistic.ParameterColumn)

Example 2 with ParameterColumn

Use of edu.cmu.tetrad.algcomparison.statistic.ParameterColumn in project tetrad by cmu-phil.

In class Comparison, method calcStatTables:

/**
 * Collapses the per-run statistics into summary tables, producing one value per
 * (table, algorithm/simulation pairing, statistic) cell.
 *
 * <p>For a {@link ParameterColumn} statistic, the cell holds the value of the parameter whose name
 * equals the column's abbreviation. The parameter is looked up first among the simulation's
 * parameters and then, if that still leaves NaN, among the algorithm's parameters. Boolean
 * parameters are encoded as {@code +Infinity} (true) or {@code -Infinity} (false). A parameter
 * found in neither place leaves the cell at NaN.
 *
 * <p>Every other statistic is summarized across its per-run values according to {@code mode}:
 * the mean for Average, the minimum for WorstCase, or the standard deviation for
 * StandardDeviation.
 *
 * @param allStats   raw statistics indexed [table][pairing][statistic][run]
 * @param mode       how to summarize non-parameter statistics across runs
 * @param numTables  number of tables to produce
 * @param wrappers   the algorithm/simulation pairings; one row per wrapper
 * @param numStats   number of statistic columns to fill
 * @param statistics the statistics, positionally matching the columns
 * @return tables sized [numTables][wrappers.size()][numStats + 1]; this method never writes the
 *         final column, so it stays 0.0
 * @throws IllegalStateException if a non-parameter statistic is reached and {@code mode} is none
 *         of Average, WorstCase or StandardDeviation
 */
private double[][][] calcStatTables(double[][][][] allStats, Mode mode, int numTables, List<AlgorithmSimulationWrapper> wrappers, int numStats, Statistics statistics) {
    final int numRows = wrappers.size();
    double[][][] tables = new double[numTables][numRows][numStats + 1];

    for (int table = 0; table < numTables; table++) {
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < numStats; col++) {
                if (!(statistics.getStatistics().get(col) instanceof ParameterColumn)) {
                    // Ordinary statistic: summarize its per-run values.
                    if (mode == Mode.Average) {
                        tables[table][row][col] = StatUtils.mean(allStats[table][row][col]);
                    } else if (mode == Mode.WorstCase) {
                        tables[table][row][col] = StatUtils.min(allStats[table][row][col]);
                    } else if (mode == Mode.StandardDeviation) {
                        tables[table][row][col] = StatUtils.sd(allStats[table][row][col]);
                    } else {
                        throw new IllegalStateException();
                    }
                    continue;
                }

                // Parameter column: report the configured value of the named parameter.
                String statName = statistics.getStatistics().get(col).getAbbreviation();
                SimulationWrapper simulationWrapper = wrappers.get(row).getSimulationWrapper();
                AlgorithmWrapper algorithmWrapper = wrappers.get(row).getAlgorithmWrapper();

                double value = Double.NaN;
                List<String> simNames = simulationWrapper.getParameters();
                Parameters simParams = simulationWrapper.getSimulationSpecificParameters();
                for (String candidate : simNames) {
                    if (!candidate.equals(statName)) {
                        continue;
                    }
                    if (simParams.get(candidate) instanceof Boolean) {
                        value = simParams.getBoolean(candidate) ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
                    } else {
                        value = simParams.getDouble(candidate);
                    }
                    break;
                }

                // Fall back to the algorithm's parameters when the simulation lookup left NaN.
                if (Double.isNaN(value)) {
                    List<String> algNames = algorithmWrapper.getParameters();
                    Parameters algParams = algorithmWrapper.parameters;
                    for (String candidate : algNames) {
                        if (!candidate.equals(statName)) {
                            continue;
                        }
                        try {
                            value = algParams.getDouble(candidate);
                        } catch (Exception e) {
                            // getDouble failed; read the parameter as a boolean instead.
                            value = algParams.getBoolean(candidate) ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
                        }
                        break;
                    }
                }

                tables[table][row][col] = value;
            }
        }
    }
    return tables;
}
Also used : HasParameters(edu.cmu.tetrad.algcomparison.utils.HasParameters) ParameterColumn(edu.cmu.tetrad.algcomparison.statistic.ParameterColumn)

Example 3 with ParameterColumn

Use of edu.cmu.tetrad.algcomparison.statistic.ParameterColumn in project tetrad by cmu-phil.

In class TimeoutComparison, method doRun:

/**
 * Executes one run: applies one algorithm to one simulated data set, saves the estimated graph,
 * and records every non-{@link ParameterColumn} statistic into {@code allStats}.
 *
 * <p>If the algorithm search throws, the failure is printed to standard output and the run is
 * abandoned without recording statistics. If the simulation supplies no true graph, no
 * comparison graph is built and no statistics are recorded for this run.
 *
 * @param algorithmSimulationWrappers all algorithm/simulation pairings, indexed by {@code run.getAlgSimIndex()}
 * @param algorithmWrappers           all algorithm wrappers; used to derive the 1-based algorithm index for saving
 * @param simulationWrappers          all simulation wrappers; used to derive the simulation index
 * @param statistics                  statistics to compute; {@link ParameterColumn}s are skipped here
 * @param numGraphTypes               number of graph types (slot 0 is the full graph; slots 1-3 are
 *                                    the subgraphs computed for mixed data)
 * @param allStats                    output, indexed [graphType][algSim][statistic][run]
 * @param run                         identifies the pairing and the data-set index
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();
    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());
    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());
    long start = System.currentTimeMillis();
    Graph out;
    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();
        // Pass the simulation's background knowledge to the algorithm when both support it.
        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }
        if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
            external.setSimulation(simulationWrapper.getSimulation());
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }
        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Feed a random subset (at most randomSelectionSize) of the simulation's data sets.
            List<Integer> indices = new ArrayList<>();
            int numDataModels = simulationWrapper.getSimulation().getNumDataModels();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);
            List<DataModel> dataModels = new ArrayList<>();
            int randomSelectionSize = algorithmWrapper.getAlgorithmSpecificParameters().getInt("randomSelectionSize");
            for (int i = 0; i < Math.min(numDataModels, randomSelectionSize); i++) {
                dataModels.add(simulationWrapper.getSimulation().getDataModel(indices.get(i)));
            }
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }
    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;
    long stop = System.currentTimeMillis();
    long elapsed = stop - start;
    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);
    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        // External algorithms report their own elapsed time, which replaces the wall-clock value.
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }
    Graph[] est = new Graph[numGraphTypes];
    // With no true graph there is nothing to compare against. Leave comparisonGraph null so the
    // null checks below skip the comparison, instead of passing null into the graph constructors.
    Graph comparisonGraph;
    if (trueGraph == null) {
        comparisonGraph = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        comparisonGraph = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        comparisonGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
        comparisonGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    } else {
        throw new IllegalArgumentException("Unrecognized graph type.");
    }
    est[0] = out;
    graphTypeUsed[0] = true;
    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);
        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }
    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = comparisonGraph;
    if (data.isMixed() && comparisonGraph != null) {
        truth[1] = getSubgraph(comparisonGraph, true, true, data);
        truth[2] = getSubgraph(comparisonGraph, true, false, data);
        truth[3] = getSubgraph(comparisonGraph, false, false, data);
    }
    if (comparisonGraph != null) {
        for (int u = 0; u < numGraphTypes; u++) {
            // This method never resets graphTypeUsed. A subgraph type switched on by an earlier
            // mixed-data run may have no estimate or truth in this run, so skip it rather than
            // hand null graphs to the statistics.
            if (!graphTypeUsed[u] || est[u] == null || truth[u] == null) {
                continue;
            }
            int statIndex = -1;
            for (Statistic _stat : statistics.getStatistics()) {
                statIndex++;
                // Parameter columns are filled from parameters later, not from graphs.
                if (_stat instanceof ParameterColumn) {
                    continue;
                }
                double stat;
                if (_stat instanceof ElapsedTime) {
                    stat = elapsed / 1000.0;
                } else {
                    stat = _stat.getValue(truth[u], est[u]);
                }
                allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
            }
        }
    }
}
Also used : ArrayList(java.util.ArrayList) ElapsedTime(edu.cmu.tetrad.algcomparison.statistic.ElapsedTime) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) ParameterColumn(edu.cmu.tetrad.algcomparison.statistic.ParameterColumn) Parameters(edu.cmu.tetrad.util.Parameters) HasParameters(edu.cmu.tetrad.algcomparison.utils.HasParameters) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) TimeoutException(java.util.concurrent.TimeoutException) FileNotFoundException(java.io.FileNotFoundException) IOException(java.io.IOException) ExecutionException(java.util.concurrent.ExecutionException) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Graph(edu.cmu.tetrad.graph.Graph) DagToPag(edu.cmu.tetrad.search.DagToPag) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation) DataModel(edu.cmu.tetrad.data.DataModel)

Example 4 with ParameterColumn

Use of edu.cmu.tetrad.algcomparison.statistic.ParameterColumn in project tetrad by cmu-phil.

In class Comparison, method doRun:

/**
 * Executes one run: applies one algorithm to one simulated data set, saves the estimated graph,
 * and records every non-{@link ParameterColumn} statistic into {@code allStats}.
 *
 * <p>If the algorithm search throws, the failure is printed to standard output and the run is
 * abandoned without recording statistics. If the simulation supplies no true graph, no
 * comparison graph is built and no statistics are recorded for this run.
 *
 * @param algorithmSimulationWrappers all algorithm/simulation pairings, indexed by {@code run.getAlgSimIndex()}
 * @param algorithmWrappers           all algorithm wrappers; used to derive the 1-based algorithm index for saving
 * @param simulationWrappers          all simulation wrappers; used to derive the simulation index
 * @param statistics                  statistics to compute; {@link ParameterColumn}s are skipped here
 * @param numGraphTypes               number of graph types (slot 0 is the full graph; slots 1-3 are
 *                                    the subgraphs computed for mixed data)
 * @param allStats                    output, indexed [graphType][algSim][statistic][run]
 * @param run                         identifies the pairing and the data-set index
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();
    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());
    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());
    long start = System.currentTimeMillis();
    Graph out;
    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();
        // Pass the simulation's background knowledge to the algorithm when both support it.
        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }
        if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
            external.setSimulation(simulationWrapper.getSimulation());
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }
        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Feed a random subset (at most randomSelectionSize) of the simulation's data sets.
            List<Integer> indices = new ArrayList<>();
            int numDataModels = simulationWrapper.getSimulation().getNumDataModels();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);
            List<DataModel> dataModels = new ArrayList<>();
            int randomSelectionSize = algorithmWrapper.getAlgorithmSpecificParameters().getInt("randomSelectionSize");
            for (int i = 0; i < Math.min(numDataModels, randomSelectionSize); i++) {
                dataModels.add(simulationWrapper.getSimulation().getDataModel(indices.get(i)));
            }
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }
    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;
    long stop = System.currentTimeMillis();
    long elapsed = stop - start;
    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);
    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        // External algorithms report their own elapsed time, which replaces the wall-clock value.
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }
    Graph[] est = new Graph[numGraphTypes];
    // With no true graph there is nothing to compare against. Leave comparisonGraph null so the
    // null checks below skip the comparison, instead of passing null into the graph constructors.
    Graph comparisonGraph;
    if (trueGraph == null) {
        comparisonGraph = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        comparisonGraph = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        comparisonGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
        comparisonGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    } else {
        throw new IllegalArgumentException("Unrecognized graph type.");
    }
    est[0] = out;
    graphTypeUsed[0] = true;
    if (data.isMixed()) {
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);
        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }
    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = comparisonGraph;
    if (data.isMixed() && comparisonGraph != null) {
        truth[1] = getSubgraph(comparisonGraph, true, true, data);
        truth[2] = getSubgraph(comparisonGraph, true, false, data);
        truth[3] = getSubgraph(comparisonGraph, false, false, data);
    }
    if (comparisonGraph != null) {
        for (int u = 0; u < numGraphTypes; u++) {
            // This method never resets graphTypeUsed. A subgraph type switched on by an earlier
            // mixed-data run may have no estimate or truth in this run, so skip it rather than
            // hand null graphs to the statistics.
            if (!graphTypeUsed[u] || est[u] == null || truth[u] == null) {
                continue;
            }
            int statIndex = -1;
            for (Statistic _stat : statistics.getStatistics()) {
                statIndex++;
                // Parameter columns are filled from parameters later, not from graphs.
                if (_stat instanceof ParameterColumn) {
                    continue;
                }
                double stat;
                if (_stat instanceof ElapsedTime) {
                    stat = elapsed / 1000.0;
                } else {
                    stat = _stat.getValue(truth[u], est[u]);
                }
                allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
            }
        }
    }
}
Also used : ElapsedTime(edu.cmu.tetrad.algcomparison.statistic.ElapsedTime) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) Statistic(edu.cmu.tetrad.algcomparison.statistic.Statistic) ParameterColumn(edu.cmu.tetrad.algcomparison.statistic.ParameterColumn) HasParameters(edu.cmu.tetrad.algcomparison.utils.HasParameters) ExternalAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) DagToPag(edu.cmu.tetrad.search.DagToPag) Simulation(edu.cmu.tetrad.algcomparison.simulation.Simulation)

Aggregations

ParameterColumn (edu.cmu.tetrad.algcomparison.statistic.ParameterColumn): 4, HasParameters (edu.cmu.tetrad.algcomparison.utils.HasParameters): 4, Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm): 2, ExternalAlgorithm (edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm): 2, MultiDataSetAlgorithm (edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm): 2, Simulation (edu.cmu.tetrad.algcomparison.simulation.Simulation): 2, ElapsedTime (edu.cmu.tetrad.algcomparison.statistic.ElapsedTime): 2, Statistic (edu.cmu.tetrad.algcomparison.statistic.Statistic): 2, HasKnowledge (edu.cmu.tetrad.algcomparison.utils.HasKnowledge): 2, TakesInitialGraph (edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph): 2, DagToPag (edu.cmu.tetrad.search.DagToPag): 2, Parameters (edu.cmu.tetrad.util.Parameters): 2, FileNotFoundException (java.io.FileNotFoundException): 2, IOException (java.io.IOException): 2, ExecutionException (java.util.concurrent.ExecutionException): 2, TimeoutException (java.util.concurrent.TimeoutException): 2, DataModel (edu.cmu.tetrad.data.DataModel): 1, EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 1, Graph (edu.cmu.tetrad.graph.Graph): 1, ArrayList (java.util.ArrayList): 1