Search in sources :

Example 6 with LargeScaleSimulation

use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.

Method compare of class Comparison.

/**
 * Runs the configured search algorithm on simulated or file-loaded data and records
 * error statistics against the true graph.
 *
 * <p>If a data file is configured, the data set and the true graph are loaded from file.
 * Otherwise a random DAG is generated over {@code numVars} variables with
 * {@code numEdges} edges, and data are simulated from a model parameterizing it:
 * a linear (Fisher) simulation for continuous data, or a randomly parameterized
 * Bayes net with 3 categories per variable for discrete data.
 *
 * @param params the comparison configuration. Its data type is updated as a side effect
 *               when the chosen independence test or score implies one.
 * @return the comparison result, holding the estimated graph, the correct result derived
 *         from the true DAG (pattern or PAG, depending on the algorithm), the elapsed
 *         search time in milliseconds, and the true DAG
 * @throws IllegalArgumentException if a required parameter is missing or the parameters
 *                                  are mutually inconsistent
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet;
    Graph trueDag;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);

    if (params.getDataFile() != null) {
        dataSet = loadDataFile(params.getDataFile());
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        }
        trueDag = loadGraphFile(params.getGraphFile());
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }
        // Validate before dispatching on the data type, so that a missing data type is
        // reported as such rather than as "Unrecognized data type.".
        if (params.getDataType() == null) {
            throw new IllegalArgumentException("Data type not set or inferred.");
        }
        if (params.getSampleSize() == -1) {
            throw new IllegalArgumentException("Sample size not set.");
        }

        if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new ContinuousVariable("X" + (i + 1)));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            dataSet = sim.simulateDataFisher(params.getSampleSize());
        } else if (params.getDataType() == ComparisonParameters.DataType.Discrete) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new DiscreteVariable("X" + (i + 1), 3));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            // Each variable gets its own tier, in node order, for the Bayes simulation.
            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < nodes.size(); i++) {
                tiers[i] = i;
            }
            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        } else {
            throw new IllegalArgumentException("Unrecognized data type.");
        }
        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
    }

    // Build the independence test, if one is requested. The chosen test fixes the data type.
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    // Build the score, if one is requested. It must agree with the data type chosen so far.
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        BDeuScore bDeuScore = new BDeuScore(dataSet);
        bDeuScore.setSamplePrior(params.getSamplePrior());
        bDeuScore.setStructurePrior(params.getStructurePrior());
        score = bDeuScore;
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    if (params.getAlgorithm() == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }

    // nanoTime is monotonic, so it is the right clock for measuring elapsed time.
    long start = System.nanoTime();
    switch (params.getAlgorithm()) {
        case PC: {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            Pc search = new Pc(test);
            result.setResultGraph(search.search());
            result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
            break;
        }
        case CPC: {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            Cpc search = new Cpc(test);
            result.setResultGraph(search.search());
            result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
            break;
        }
        case PCLocal: {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            PcLocal search = new PcLocal(test);
            result.setResultGraph(search.search());
            result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
            break;
        }
        case PCStableMax: {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            PcStableMax search = new PcStableMax(test);
            result.setResultGraph(search.search());
            result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
            break;
        }
        case FGES:
        case FGES2: {
            // In this method, FGES2 is run exactly like FGES.
            if (score == null) {
                throw new IllegalArgumentException("Score not set.");
            }
            Fges search = new Fges(score);
            search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
            result.setResultGraph(search.search());
            result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
            break;
        }
        case FCI: {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            Fci search = new Fci(test);
            result.setResultGraph(search.search());
            result.setCorrectResult(new DagToPag(trueDag).convert());
            break;
        }
        case GFCI: {
            // GFCI uses both a test and a score, so both must be configured.
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            if (score == null) {
                throw new IllegalArgumentException("Score not set.");
            }
            GFci search = new GFci(test, score);
            result.setResultGraph(search.search());
            result.setCorrectResult(new DagToPag(trueDag).convert());
            break;
        }
        default:
            throw new IllegalArgumentException("Unrecognized algorithm.");
    }
    long elapsedMillis = (System.nanoTime() - start) / 1_000_000L;

    result.setElapsed(elapsedMillis);
    result.setTrueDag(trueDag);
    return result;
}
Also used: MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) List(java.util.List) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 7 with LargeScaleSimulation

use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.

Method test1 of class TestLargeSemSimulator.

/**
 * Checks that a Fisher-model simulation over a random 10-node graph
 * yields exactly as many rows as were requested.
 */
@Test
public void test1() {
    final int numNodes = 10;
    final int sampleSize = 1000;

    List<Node> variables = new ArrayList<>();
    for (int index = 0; index < numNodes; index++) {
        variables.add(new ContinuousVariable("X" + (index + 1)));
    }

    Graph dag = GraphUtils.randomGraph(variables, 0, 10, 5, 5, 5, false);
    DataSet simulated = new LargeScaleSimulation(dag).simulateDataFisher(sampleSize);

    assertEquals(sampleSize, simulated.getNumRows());
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) DataSet(edu.cmu.tetrad.data.DataSet) ArrayList(java.util.ArrayList) Test(org.junit.Test)

Example 8 with LargeScaleSimulation

use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.

Method createData of class LinearFisherModel.

/**
 * Simulates {@code numRuns} data sets from linear Fisher models.
 *
 * <p>The results are stored in {@code dataSets}, and the graphs they were simulated from
 * are stored in {@code graphs}.
 *
 * <p>When a non-empty list of shock data sets is present, the graph's nodes are replaced
 * by the variables of the first shock data set, and run {@code i} is driven by shock data
 * set {@code i}. Otherwise the simulator generates its own shocks.
 *
 * <p>Depending on {@code parameters}, each simulated data set may then:
 * <ul>
 *   <li>receive Gaussian measurement noise;</li>
 *   <li>have a percentage of its variables discretized into equal intervals;</li>
 *   <li>have its columns reordered.</li>
 * </ul>
 *
 * @param parameters simulation parameters. When shocks are present, {@code numVars} is
 *                   overwritten with the number of variables in the first shock data set.
 * @throws IllegalArgumentException if fewer shock data sets than {@code numRuns} are present
 */
@Override
public void createData(Parameters parameters) {
    boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
    // Treat an empty shocks list like "no shocks". Otherwise shocks.get(i) below would throw.
    boolean hasShocks = shocks != null && shocks.size() > 0;
    int numRuns = parameters.getInt("numRuns");

    dataSets = new ArrayList<>();
    graphs = new ArrayList<>();

    // Set numVars before the *first* createGraph call as well. The graph then has as many
    // nodes as the shock data has variables by the time setNodes is called below.
    if (hasShocks) {
        parameters.set("numVars", shocks.get(0).getVariables().size());
    }

    Graph graph = randomGraph.createGraph(parameters);
    System.out.println("degree = " + GraphUtils.getDegree(graph));

    for (int i = 0; i < numRuns; i++) {
        System.out.println("Simulating dataset #" + (i + 1));

        if (parameters.getBoolean("differentGraphs") && i > 0) {
            graph = randomGraph.createGraph(parameters);
        }
        if (hasShocks) {
            graph.setNodes(shocks.get(0).getVariables());
        }
        graphs.add(graph);

        // One tier per node, in node order.
        int[] tiers = new int[graph.getNodes().size()];
        for (int j = 0; j < tiers.length; j++) {
            tiers[j] = j;
        }

        LargeScaleSimulation simulator = new LargeScaleSimulation(graph, graph.getNodes(), tiers);
        simulator.setCoefRange(parameters.getDouble("coefLow"), parameters.getDouble("coefHigh"));
        simulator.setVarRange(parameters.getDouble("varLow"), parameters.getDouble("varHigh"));
        simulator.setIncludePositiveCoefs(parameters.getBoolean("includePositiveCoefs"));
        simulator.setIncludeNegativeCoefs(parameters.getBoolean("includeNegativeCoefs"));
        simulator.setBetaLeftValue(parameters.getDouble("betaLeftValue"));
        simulator.setBetaRightValue(parameters.getDouble("betaRightValue"));
        simulator.setSelfLoopCoef(parameters.getDouble("selfLoopCoef"));
        simulator.setMeanRange(parameters.getDouble("meanLow"), parameters.getDouble("meanHigh"));
        simulator.setErrorsNormal(parameters.getBoolean("errorsNormal"));
        simulator.setVerbose(parameters.getBoolean("verbose"));

        DataSet dataSet;
        if (!hasShocks) {
            dataSet = simulator.simulateDataFisher(parameters.getInt("intervalBetweenShocks"), parameters.getInt("intervalBetweenRecordings"), parameters.getInt("sampleSize"), parameters.getDouble("fisherEpsilon"), saveLatentVars);
        } else {
            if (i >= shocks.size()) {
                throw new IllegalArgumentException("Only " + shocks.size()
                        + " shock data sets were supplied, but numRuns = " + numRuns + ".");
            }
            DataSet _shocks = (DataSet) shocks.get(i);
            dataSet = simulator.simulateDataFisher(_shocks.getDoubleData().toArray(), parameters.getInt("intervalBetweenShocks"), parameters.getDouble("fisherEpsilon"));
        }

        // Optional additive Gaussian measurement noise, drawn independently for each cell.
        double variance = parameters.getDouble("measurementVariance");
        if (variance > 0) {
            double sd = Math.sqrt(variance);
            for (int k = 0; k < dataSet.getNumRows(); k++) {
                for (int j = 0; j < dataSet.getNumColumns(); j++) {
                    double d = dataSet.getDouble(k, j);
                    double delta = RandomUtil.getInstance().nextNormal(0, sd);
                    dataSet.setDouble(k, j, d + delta);
                }
            }
        }

        dataSet.setName(String.valueOf(i + 1));

        double percentDiscrete = parameters.getDouble("percentDiscrete");
        if (percentDiscrete > 0.0) {
            // Shuffle once and reuse the same order for later runs.
            if (this.shuffledOrder == null) {
                List<Node> shuffledNodes = new ArrayList<>(dataSet.getVariables());
                Collections.shuffle(shuffledNodes);
                this.shuffledOrder = shuffledNodes;
            }

            // ceil(...) matches the original loop bound "k < size * pct * 0.01". The clamp
            // keeps a percentage above 100 from indexing past the end of shuffledOrder.
            int numToDiscretize = (int) Math.min(shuffledOrder.size(),
                    Math.ceil(shuffledOrder.size() * percentDiscrete * 0.01));

            Discretizer discretizer = new Discretizer(dataSet);
            for (int k = 0; k < numToDiscretize; k++) {
                discretizer.equalIntervals(dataSet.getVariable(shuffledOrder.get(k).getName()), parameters.getInt("numCategories"));
            }
            // Restore the name on the discretized copy.
            String name = dataSet.getName();
            dataSet = discretizer.discretize();
            dataSet.setName(name);
        }

        if (parameters.getBoolean("randomizeColumns")) {
            dataSet = DataUtils.reorderColumns(dataSet);
        }

        dataSets.add(saveLatentVars ? dataSet : DataUtils.restrictToMeasured(dataSet));
    }
}
Also used : Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Graph(edu.cmu.tetrad.graph.Graph) RandomGraph(edu.cmu.tetrad.algcomparison.graph.RandomGraph) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation)

Example 9 with LargeScaleSimulation

use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.

Method compare of class Comparison2.

/**
 * Simulates data from model parameterizing the given DAG, and runs the
 * algorithm on that data, printing out error statistics.
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet = null;
    Graph trueDag = null;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);
    if (params.isDataFromFile()) {
        /**
         * Set path to the data directory *
         */
        String path = "/Users/dmalinsky/Documents/research/data/danexamples";
        File dir = new File(path);
        File[] files = dir.listFiles();
        if (files == null) {
            throw new NullPointerException("No files in " + path);
        }
        for (File file : files) {
            if (file.getName().startsWith("graph") && file.getName().contains(String.valueOf(params.getGraphNum())) && file.getName().endsWith(".g.txt")) {
                params.setGraphFile(file.getName());
                trueDag = GraphUtils.loadGraphTxt(file);
                break;
            }
        }
        String trialGraph = String.valueOf(params.getGraphNum()).concat("-").concat(String.valueOf(params.getTrial())).concat(".dat.txt");
        for (File file : files) {
            if (file.getName().startsWith("graph") && file.getName().endsWith(trialGraph)) {
                Path dataFile = Paths.get(path.concat("/").concat(file.getName()));
                Delimiter delimiter = Delimiter.TAB;
                if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
                    try {
                        TabularDataReader dataReader = new ContinuousTabularDataFileReader(dataFile.toFile(), delimiter);
                        dataSet = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData());
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                    params.setDataFile(file.getName());
                    break;
                } else {
                    try {
                        TabularDataReader dataReader = new VerticalDiscreteTabularDataReader(dataFile.toFile(), delimiter);
                        dataSet = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData());
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                    params.setDataFile(file.getName());
                    break;
                }
            }
        }
        System.out.println("current graph file = " + params.getGraphFile());
        System.out.println("current data set file = " + params.getDataFile());
    }
    if (params.isNoData()) {
        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < params.getNumVars(); i++) {
            nodes.add(new ContinuousVariable("X" + (i + 1)));
        }
        trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
        /**
         * added 5.25.16 for tsFCI *
         */
        if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            trueDag = TimeSeriesUtils.graphToLagGraph(trueDag, 2);
            System.out.println("Creating Time Lag Graph : " + trueDag);
        }
        /**
         * ************************
         */
        test = new IndTestDSep(trueDag);
        score = new GraphScore(trueDag);
        if (params.getAlgorithm() == null) {
            throw new IllegalArgumentException("Algorithm not set.");
        }
        long time1 = System.currentTimeMillis();
        if (params.getAlgorithm() == ComparisonParameters.Algorithm.PC) {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            Pc search = new Pc(test);
            result.setResultGraph(search.search());
            result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
        } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            Cpc search = new Cpc(test);
            result.setResultGraph(search.search());
            result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
        } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            PcLocal search = new PcLocal(test);
            result.setResultGraph(search.search());
            result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
        } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCStableMax) {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            PcStableMax search = new PcStableMax(test);
            result.setResultGraph(search.search());
            result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
        } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) {
            if (score == null) {
                throw new IllegalArgumentException("Score not set.");
            }
            Fges search = new Fges(score);
            // search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
            result.setResultGraph(search.search());
            result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
        } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            Fci search = new Fci(test);
            result.setResultGraph(search.search());
            result.setCorrectResult(new DagToPag(trueDag).convert());
        } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.GFCI) {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            GFci search = new GFci(test, score);
            result.setResultGraph(search.search());
            result.setCorrectResult(new DagToPag(trueDag).convert());
        } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
            if (test == null) {
                throw new IllegalArgumentException("Test not set.");
            }
            TsFci search = new TsFci(test);
            IKnowledge knowledge = getKnowledge(trueDag);
            search.setKnowledge(knowledge);
            result.setResultGraph(search.search());
            result.setCorrectResult(new TsDagToPag(trueDag).convert());
            System.out.println("Correct result for trial = " + result.getCorrectResult());
            System.out.println("Search result for trial = " + result.getResultGraph());
        } else {
            throw new IllegalArgumentException("Unrecognized algorithm.");
        }
        long time2 = System.currentTimeMillis();
        long elapsed = time2 - time1;
        result.setElapsed(elapsed);
        result.setTrueDag(trueDag);
        return result;
    } else if (params.getDataFile() != null) {
        // dataSet = loadDataFile(params.getDataFile());
        System.out.println("Using data from file... ");
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        } else {
            System.out.println("Using graph from file... ");
        // trueDag = GraphUtils.loadGraph(File params.getGraphFile());
        }
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }
        if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new ContinuousVariable("X" + (i + 1)));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            /**
             * added 6.08.16 for tsFCI *
             */
            if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
                trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
                trueDag = TimeSeriesUtils.graphToLagGraph(trueDag, 2);
                System.out.println("Creating Time Lag Graph : " + trueDag);
            }
            if (params.getDataType() == null) {
                throw new IllegalArgumentException("Data type not set or inferred.");
            }
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            /**
             * added 6.08.16 for tsFCI *
             */
            if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
                sim.setCoefRange(0.20, 0.50);
            }
            /**
             * added 6.08.16 for tsFCI *
             */
            if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
                // //                    System.out.println("Coefs matrix : " + sim.getCoefs());
                // System.out.println(MatrixUtils.toString(sim.getCoefficientMatrix()));
                // //                    System.out.println("dim = " + sim.getCoefs()[1][1]);
                // boolean isStableTetradMatrix = allEigenvaluesAreSmallerThanOneInModulus(new TetradMatrix(sim.getCoefficientMatrix()));
                // //this TetradMatrix needs to be the matrix of coefficients from the SEM!
                // if (!isStableTetradMatrix) {
                // System.out.println("%%%%%%%%%% WARNING %%%%%%%%% not a stable set of eigenvalues for data generation");
                // System.out.println("Skipping this attempt!");
                // sim.setCoefRange(0.2, 0.5);
                // dataSet = sim.simulateDataAcyclic(params.getSampleSize());
                // }
                // 
                // /***************************/
                boolean isStableTetradMatrix;
                int attempt = 1;
                int tierSize = params.getNumVars();
                int[] sub = new int[tierSize];
                int[] sub2 = new int[tierSize];
                for (int i = 0; i < tierSize; i++) {
                    sub[i] = i;
                    sub2[i] = tierSize + i;
                }
                do {
                    dataSet = sim.simulateDataFisher(params.getSampleSize());
                    // System.out.println("Variable Nodes : " + sim.getVariableNodes());
                    // System.out.println(MatrixUtils.toString(sim.getCoefficientMatrix()));
                    TetradMatrix coefMat = new TetradMatrix(sim.getCoefficientMatrix());
                    TetradMatrix B = coefMat.getSelection(sub, sub);
                    TetradMatrix Gamma1 = coefMat.getSelection(sub2, sub);
                    TetradMatrix Gamma0 = TetradMatrix.identity(tierSize).minus(B);
                    TetradMatrix A1 = Gamma0.inverse().times(Gamma1);
                    // TetradMatrix B2 = coefMat.getSelection(sub2, sub2);
                    // System.out.println("B matrix : " + B);
                    // System.out.println("B2 matrix : " + B2);
                    // System.out.println("Gamma1 matrix : " + Gamma1);
                    // isStableTetradMatrix = allEigenvaluesAreSmallerThanOneInModulus(new TetradMatrix(sim.getCoefficientMatrix()));
                    isStableTetradMatrix = TimeSeriesUtils.allEigenvaluesAreSmallerThanOneInModulus(A1);
                    System.out.println("isStableTetradMatrix? : " + isStableTetradMatrix);
                    attempt++;
                } while ((!isStableTetradMatrix) && attempt <= 5);
                if (!isStableTetradMatrix) {
                    System.out.println("%%%%%%%%%% WARNING %%%%%%%% not a stable coefficient matrix, forcing coefs to [0.15,0.3]");
                    System.out.println("Made " + (attempt - 1) + " attempts to get stable matrix.");
                    sim.setCoefRange(0.15, 0.3);
                    dataSet = sim.simulateDataFisher(params.getSampleSize());
                } else {
                    System.out.println("Coefficient matrix is stable.");
                }
            }
        } else if (params.getDataType() == ComparisonParameters.DataType.Discrete) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new DiscreteVariable("X" + (i + 1), 3));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            if (params.getDataType() == null) {
                throw new IllegalArgumentException("Data type not set or inferred.");
            }
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < nodes.size(); i++) {
                tiers[i] = i;
            }
            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        } else {
            throw new IllegalArgumentException("Unrecognized data type.");
        }
        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
    }
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        score = new BDeuScore(dataSet);
        ((BDeuScore) score).setSamplePrior(params.getSamplePrior());
        ((BDeuScore) score).setStructurePrior(params.getStructurePrior());
        params.setDataType(ComparisonParameters.DataType.Discrete);
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }
    if (params.getAlgorithm() == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }
    long time1 = System.currentTimeMillis();
    if (params.getAlgorithm() == ComparisonParameters.Algorithm.PC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Pc search = new Pc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Cpc search = new Cpc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcLocal search = new PcLocal(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCStableMax) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcStableMax search = new PcStableMax(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) {
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        Fges search = new Fges(score);
        // search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Fci search = new Fci(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.GFCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        GFci search = new GFci(test, score);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        TsFci search = new TsFci(test);
        IKnowledge knowledge = getKnowledge(trueDag);
        search.setKnowledge(knowledge);
        result.setResultGraph(search.search());
        result.setCorrectResult(new TsDagToPag(trueDag).convert());
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }
    long time2 = System.currentTimeMillis();
    long elapsed = time2 - time1;
    result.setElapsed(elapsed);
    result.setTrueDag(trueDag);
    return result;
}
Also used: MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm), Node (edu.cmu.tetrad.graph.Node), ArrayList (java.util.ArrayList), List (java.util.List), EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph), Graph (edu.cmu.tetrad.graph.Graph), TabularDataReader (edu.pitt.dbmi.data.reader.tabular.TabularDataReader), VerticalDiscreteTabularDataReader (edu.pitt.dbmi.data.reader.tabular.VerticalDiscreteTabularDataReader), LargeScaleSimulation (edu.cmu.tetrad.sem.LargeScaleSimulation), Path (java.nio.file.Path), Delimiter (edu.pitt.dbmi.data.Delimiter), ContinuousTabularDataFileReader (edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader), TetradMatrix (edu.cmu.tetrad.util.TetradMatrix), BayesPm (edu.cmu.tetrad.bayes.BayesPm)

Example 10 with LargeScaleSimulation

use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.

the class PerformanceTests method testPcStable.

/**
 * Measures the performance of PC-Stable on data simulated from a random DAG and writes a
 * report of timings and accuracy.
 *
 * <p>Steps:
 * <ol>
 *   <li>Build a random DAG with {@code makeDag(numVars, edgeFactor)}.</li>
 *   <li>Simulate {@code numCases} rows with {@link LargeScaleSimulation#simulateDataFisher(int)}.</li>
 *   <li>Wrap the data in a {@link CovarianceMatrixOnTheFly}.</li>
 *   <li>Run {@link PcStable} with an {@link IndTestFisherZ} at {@code alpha}.</li>
 *   <li>Compare the estimated pattern against the pattern of the true DAG.</li>
 * </ol>
 *
 * <p>Report lines go to the {@code out} stream, which is presumably opened on the report file
 * by {@code init} (confirm). Progress messages go to {@code System.out}. The {@code out}
 * stream is closed before returning, including when an exception is thrown.
 *
 * @param numVars    number of variables in the random DAG (also part of the report file name)
 * @param edgeFactor edge factor passed to {@code makeDag} (also part of the report file name)
 * @param numCases   number of simulated data rows
 * @param alpha      significance level for the Fisher Z independence test
 */
public void testPcStable(int numVars, double edgeFactor, int numCases, double alpha) {
    // Written to the report only. It is not passed to the search, so the value printed
    // below may not reflect the depth PcStable actually uses. NOTE(review): confirm.
    int depth = -1;
    init(new File("long.pcstable." + numVars + "." + edgeFactor + "." + alpha + ".txt"), "Tests performance of the PC Stable algorithm");

    try {
        long time1 = System.currentTimeMillis();
        Graph dag = makeDag(numVars, edgeFactor);
        System.out.println("Graph done");
        out.println("Graph done");

        System.out.println("Starting simulation");
        LargeScaleSimulation simulator = new LargeScaleSimulation(dag);
        simulator.setOut(out);
        DataSet data = simulator.simulateDataFisher(numCases);
        System.out.println("Finishing simulation");
        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

        System.out.println("Making covariance matrix");
        ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data);
        System.out.println("Covariance matrix done");
        long time3 = System.currentTimeMillis();
        out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms");

        IndTestFisherZ test = new IndTestFisherZ(cov, alpha);
        PcStable pcStable = new PcStable(test);
        Graph estPattern = pcStable.search();
        long time4 = System.currentTimeMillis();

        // Summary section. The two "Elapsed" lines written above are repeated here,
        // presumably so the summary reads as a self-contained block.
        out.println("# Cases = " + numCases);
        out.println("alpha = " + alpha);
        out.println("depth = " + depth);
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
        out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms");
        out.println("Elapsed (running PC-Stable) " + (time4 - time3) + " ms");
        out.println("Total elapsed (cov + PC-Stable) " + (time4 - time2) + " ms");

        final Graph truePattern = SearchGraphUtils.patternForDag(dag);
        System.out.println("# edges in true pattern = " + truePattern.getNumEdges());
        System.out.println("# edges in est pattern = " + estPattern.getNumEdges());
        SearchGraphUtils.graphComparison(estPattern, truePattern, out);
        out.println("seed = " + RandomUtil.getInstance().getSeed() + "L");
    } finally {
        // Close the report stream even when simulation, search, or comparison throws,
        // so the report file handle is not leaked.
        out.close();
    }
}
Also used : LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation)

Aggregations

LargeScaleSimulation (edu.cmu.tetrad.sem.LargeScaleSimulation): 21; Graph (edu.cmu.tetrad.graph.Graph): 6; ArrayList (java.util.ArrayList): 6; Test (org.junit.Test): 5; BayesPm (edu.cmu.tetrad.bayes.BayesPm): 4; MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm): 4; DataSet (edu.cmu.tetrad.data.DataSet): 4; Node (edu.cmu.tetrad.graph.Node): 3; Parameters (edu.cmu.tetrad.util.Parameters): 3; GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest): 3; Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm): 2; FisherZ (edu.cmu.tetrad.algcomparison.independence.FisherZ): 2; IndependenceWrapper (edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper): 2; ScoreWrapper (edu.cmu.tetrad.algcomparison.score.ScoreWrapper): 2; SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore): 2; EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 2; DagToPag (edu.cmu.tetrad.search.DagToPag): 2; List (java.util.List): 2; Fci (edu.cmu.tetrad.algcomparison.algorithm.oracle.pag.Fci): 1; Gfci (edu.cmu.tetrad.algcomparison.algorithm.oracle.pag.Gfci): 1