Search in sources:

Example 1 with SemBicScore

use of edu.cmu.tetrad.search.SemBicScore in project tetrad by cmu-phil.

the class TestAutisticClassification method testForBiwei.

// @Test
/**
 * Runs an adjacency search on every dataset loaded by {@code FaskGraphs} from a
 * hard-coded directory and writes one 0/1 adjacency matrix per dataset.
 *
 * <p>For each dataset a {@code SemBicScore} is built over an on-the-fly covariance
 * matrix, wrapped in an {@code IndTestScore} and handed to {@code Fas}. The
 * resulting graph is printed to standard output and its adjacency matrix is
 * written to {@code <filename>.graph.txt} in the output directory.
 */
public void testForBiwei() {
    // NOTE(review): these settings are populated but never handed to anything below;
    // FaskGraphs receives a fresh, empty Parameters instance instead. Confirm whether
    // `parameters` was meant to be passed.
    Parameters parameters = new Parameters();
    parameters.set("penaltyDiscount", 2);
    parameters.set("depth", -1);
    parameters.set("numRuns", 10);
    parameters.set("randomSelectionSize", 1);
    parameters.set("Structure", "Placeholder");
    FaskGraphs files = new FaskGraphs("/Users/jdramsey/Downloads/USM_ABIDE", new Parameters());
    List<DataSet> datasets = files.getDatasets();
    List<String> filenames = files.getFilenames();
    for (int i = 0; i < datasets.size(); i++) {
        DataSet dataSet = datasets.get(i);
        SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        Fas fas = new Fas(new IndTestScore(score));
        Graph graph = fas.search();
        System.out.println(graph);
        List<Node> nodes = graph.getNodes();
        // Build the adjacency matrix: one row per node, "1 " where the pair is adjacent.
        StringBuilder b = new StringBuilder();
        for (int j = 0; j < nodes.size(); j++) {
            for (int k = 0; k < nodes.size(); k++) {
                if (graph.isAdjacentTo(nodes.get(j), nodes.get(k))) {
                    b.append("1 ");
                } else {
                    b.append("0 ");
                }
            }
            b.append("\n");
        }
        File dir = new File("/Users/jdramsey/Downloads/biwei/USM_ABIDE");
        dir.mkdirs();
        File file = new File(dir, filenames.get(i) + ".graph.txt");
        // try-with-resources closes the stream even if writing fails part-way.
        try (PrintStream out = new PrintStream(file)) {
            out.println(b);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
    }
}
Also used : PrintStream(java.io.PrintStream) Parameters(edu.cmu.tetrad.util.Parameters) DataSet(edu.cmu.tetrad.data.DataSet) IndTestScore(edu.cmu.tetrad.search.IndTestScore) Node(edu.cmu.tetrad.graph.Node) FileNotFoundException(java.io.FileNotFoundException) Graph(edu.cmu.tetrad.graph.Graph) Fas(edu.cmu.tetrad.search.Fas) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) File(java.io.File) SemBicScore(edu.cmu.tetrad.search.SemBicScore)

Example 2 with SemBicScore

use of edu.cmu.tetrad.search.SemBicScore in project tetrad by cmu-phil.

the class HsimEvalFromData method main.

/**
 * Evaluates how well full resimulation (Fsim) and hybrid resimulation (Hsim)
 * predict the graph-discovery errors of FGES on a fixed data set.
 *
 * <p>On every iteration the graph is loaded from {@code graph/graph.1.txt} and the
 * data from {@code data/data.1.txt}; FGES is run on the data and its errors against
 * the loaded DAG are the "target errors". Fsim and Hsim error estimates are then
 * computed and their squared differences from the target errors are collected.
 * After all iterations the collected squared errors are aggregated per parameter
 * setting and written to {@code latexTable.txt} and
 * {@code HvsF-SimulationEvaluation.txt} in the working directory.
 *
 * <p>Any exception raised during the evaluation is caught and its stack trace is
 * printed; the method then returns normally.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    long timestart = System.nanoTime();
    System.out.println("Beginning Evaluation");
    String nl = System.lineSeparator();
    String output = "Simulation edu.cmu.tetrad.study output comparing Fsim and Hsim on predicting graph discovery accuracy" + nl;
    int iterations = 100;
    int vars = 20;
    int cases = 500;
    int edgeratio = 3;
    List<Integer> hsimRepeat = Arrays.asList(40);
    List<Integer> fsimRepeat = Arrays.asList(40);
    // One list of squared-error records per Fsim repeat setting.
    List<PRAOerrors>[] fsimErrsByPars = new ArrayList[fsimRepeat.size()];
    for (int f = 0; f < fsimErrsByPars.length; f++) {
        fsimErrsByPars[f] = new ArrayList<PRAOerrors>();
    }
    // One list of squared-error records per Hsim repeat setting (single resim size).
    List<PRAOerrors>[][] hsimErrsByPars = new ArrayList[1][hsimRepeat.size()];
    for (int h = 0; h < hsimErrsByPars[0].length; h++) {
        hsimErrsByPars[0][h] = new ArrayList<PRAOerrors>();
    }
    // START ITERATING HERE
    try {
        for (int iterate = 0; iterate < iterations; iterate++) {
            System.out.println("iteration " + iterate);
            // LOADING THE DATA AND GRAPH
            DataSet data1;
            Graph graph1 = GraphUtils.loadGraphTxt(new File("graph/graph.1.txt"));
            Dag odag = new Dag(graph1);
            Set<String> eVars = new HashSet<String>();
            eVars.add("MULT");
            Path dataFile = Paths.get("data/data.1.txt");
            TabularDataReader dataReader = new ContinuousTabularDataFileReader(dataFile.toFile(), Delimiter.TAB);
            data1 = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData(eVars));
            vars = data1.getNumColumns();
            cases = data1.getNumRows();
            edgeratio = 3;
            // CALCULATING TARGET ERRORS
            ICovarianceMatrix newcov = new CovarianceMatrixOnTheFly(data1);
            SemBicScore oscore = new SemBicScore(newcov);
            Fges ofgs = new Fges(oscore);
            ofgs.setVerbose(false);
            ofgs.setNumPatternsToStore(0);
            // This is the original FGS output on the data
            Graph oFGSGraph = ofgs.search();
            PRAOerrors oErrors = new PRAOerrors(HsimUtils.errorEval(oFGSGraph, odag), "target errors");
            // Step 1: full resim. Iterate through the combinations of estimator parameters (just repeat num)
            for (int whichFrepeat = 0; whichFrepeat < fsimRepeat.size(); whichFrepeat++) {
                ArrayList<PRAOerrors> errorsList = new ArrayList<PRAOerrors>();
                for (int r = 0; r < fsimRepeat.get(whichFrepeat); r++) {
                    PatternToDag pickdag = new PatternToDag(oFGSGraph);
                    Graph fgsDag = pickdag.patternToDagMeek();
                    Dag fgsdag2 = new Dag(fgsDag);
                    // Fit an IM to this dag and the data, then simulate a data set of the same size.
                    SemPm simSemPm = new SemPm(fgsdag2);
                    SemEstimator simSemEstimator = new SemEstimator(data1, simSemPm);
                    SemIm fittedIM = simSemEstimator.estimate();
                    DataSet simData = fittedIM.simulateData(data1.getNumRows(), false);
                    // After making the full resim data (simData), run FGS on that
                    ICovarianceMatrix simcov = new CovarianceMatrixOnTheFly(simData);
                    SemBicScore simscore = new SemBicScore(simcov);
                    Fges simfgs = new Fges(simscore);
                    simfgs.setVerbose(false);
                    simfgs.setNumPatternsToStore(0);
                    Graph simGraphOut = simfgs.search();
                    PRAOerrors simErrors = new PRAOerrors(HsimUtils.errorEval(simGraphOut, fgsdag2), "Fsim errors " + r);
                    errorsList.add(simErrors);
                }
                PRAOerrors avErrors = new PRAOerrors(errorsList, "Average errors for Fsim at repeat=" + fsimRepeat.get(whichFrepeat));
                // Calculate the squared errors of prediction, store all these errors in a list
                double FsimAR2 = (avErrors.getAdjRecall() - oErrors.getAdjRecall()) * (avErrors.getAdjRecall() - oErrors.getAdjRecall());
                double FsimAP2 = (avErrors.getAdjPrecision() - oErrors.getAdjPrecision()) * (avErrors.getAdjPrecision() - oErrors.getAdjPrecision());
                double FsimOR2 = (avErrors.getOrientRecall() - oErrors.getOrientRecall()) * (avErrors.getOrientRecall() - oErrors.getOrientRecall());
                double FsimOP2 = (avErrors.getOrientPrecision() - oErrors.getOrientPrecision()) * (avErrors.getOrientPrecision() - oErrors.getOrientPrecision());
                PRAOerrors Fsim2 = new PRAOerrors(new double[] { FsimAR2, FsimAP2, FsimOR2, FsimOP2 }, "squared errors for Fsim at repeat=" + fsimRepeat.get(whichFrepeat));
                // Add the fsim squared errors to the appropriate list
                fsimErrsByPars[whichFrepeat].add(Fsim2);
            }
            // Step 2: hybrid sim. Iterate through combos of params (repeat num, resimsize)
            for (int whichHrepeat = 0; whichHrepeat < hsimRepeat.size(); whichHrepeat++) {
                HsimRepeatAC study = new HsimRepeatAC(data1);
                PRAOerrors HsimErrors = new PRAOerrors(study.run(1, hsimRepeat.get(whichHrepeat)), "Hsim errors" + "at rsize=" + 1 + " repeat=" + hsimRepeat.get(whichHrepeat));
                // Calculate the squared errors of prediction
                double HsimAR2 = (HsimErrors.getAdjRecall() - oErrors.getAdjRecall()) * (HsimErrors.getAdjRecall() - oErrors.getAdjRecall());
                double HsimAP2 = (HsimErrors.getAdjPrecision() - oErrors.getAdjPrecision()) * (HsimErrors.getAdjPrecision() - oErrors.getAdjPrecision());
                double HsimOR2 = (HsimErrors.getOrientRecall() - oErrors.getOrientRecall()) * (HsimErrors.getOrientRecall() - oErrors.getOrientRecall());
                double HsimOP2 = (HsimErrors.getOrientPrecision() - oErrors.getOrientPrecision()) * (HsimErrors.getOrientPrecision() - oErrors.getOrientPrecision());
                PRAOerrors Hsim2 = new PRAOerrors(new double[] { HsimAR2, HsimAP2, HsimOR2, HsimOP2 }, "squared errors for Hsim, rsize=" + 1 + " repeat=" + hsimRepeat.get(whichHrepeat));
                hsimErrsByPars[0][whichHrepeat].add(Hsim2);
            }
        }
        // Aggregate the squared errors for each set of fsim/hsim params across all iterations
        PRAOerrors[] fMSE = new PRAOerrors[fsimRepeat.size()];
        PRAOerrors[][] hMSE = new PRAOerrors[1][hsimRepeat.size()];
        String[][] latexTableArray = new String[1 * hsimRepeat.size() + fsimRepeat.size()][5];
        for (int j = 0; j < fMSE.length; j++) {
            fMSE[j] = new PRAOerrors(fsimErrsByPars[j], "MSE for Fsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " frepeat=" + fsimRepeat.get(j) + " iterations=" + iterations);
            output = output + fMSE[j].allToString() + nl;
            latexTableArray[j] = prelimToPRAOtable(fMSE[j]);
        }
        for (int j = 0; j < hMSE.length; j++) {
            for (int k = 0; k < hMSE[j].length; k++) {
                hMSE[j][k] = new PRAOerrors(hsimErrsByPars[j][k], "MSE for Hsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " rsize=" + 1 + " repeat=" + hsimRepeat.get(k) + " iterations=" + iterations);
                output = output + hMSE[j][k].allToString() + nl;
                // Hsim rows are placed after all Fsim rows in the table.
                latexTableArray[fsimRepeat.size() + j * hMSE[j].length + k] = prelimToPRAOtable(hMSE[j][k]);
            }
        }
        // Record all the params, the base error values, and the fsim/hsim mean squared errors.
        // try-with-resources guarantees each writer is closed even if a write fails.
        String latexTable = HsimUtils.makeLatexTable(latexTableArray);
        try (PrintWriter writer = new PrintWriter("latexTable.txt", "UTF-8")) {
            writer.println(latexTable);
        }
        try (PrintWriter writer2 = new PrintWriter("HvsF-SimulationEvaluation.txt", "UTF-8")) {
            writer2.println(output);
        }
        long timestop = System.nanoTime();
        // Integer division: the duration is reported in whole seconds.
        System.out.println("Evaluation Concluded. Duration: " + (timestop - timestart) / 1000000000 + "s");
    } catch (Exception e) {
        // Catches every exception type, not only I/O failures.
        e.printStackTrace();
    }
}
Also used : TabularDataReader(edu.pitt.dbmi.data.reader.tabular.TabularDataReader) DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) SemPm(edu.cmu.tetrad.sem.SemPm) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) PrintWriter(java.io.PrintWriter) Path(java.nio.file.Path) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ContinuousTabularDataFileReader(edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader) PatternToDag(edu.cmu.tetrad.search.PatternToDag) Dag(edu.cmu.tetrad.graph.Dag) Fges(edu.cmu.tetrad.search.Fges) Graph(edu.cmu.tetrad.graph.Graph) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) File(java.io.File) SemIm(edu.cmu.tetrad.sem.SemIm) SemBicScore(edu.cmu.tetrad.search.SemBicScore)

Example 3 with SemBicScore

use of edu.cmu.tetrad.search.SemBicScore in project tetrad by cmu-phil.

the class TestFges method explore1.

// private OutputStream out =
// @Test
/**
 * Simulates data from a seeded random DAG over ten continuous variables, runs
 * FGES with a SEM BIC score on it, and asserts that the comparison counts
 * between the estimated and true patterns equal a hard-coded table.
 */
public void explore1() {
    // Fixed seed so the generated graph and data are reproducible.
    RandomUtil.getInstance().setSeed(1450184147770L);

    final int variableCount = 10;
    final double avgEdgesPerNode = 1.0;
    final int sampleSize = 1000;
    final double discount = 2.0;
    final int edgeCount = (int) (variableCount * avgEdgesPerNode);

    List<Node> nodes = new ArrayList<>();
    int created = 0;
    while (created < variableCount) {
        nodes.add(new ContinuousVariable("X" + created));
        created++;
    }

    Graph trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, edgeCount, 30, 15, 15, false, true);

    // Identity ordering: position p maps to variable p.
    int[] ordering = new int[nodes.size()];
    for (int p = 0; p < ordering.length; p++) {
        ordering[p] = p;
    }

    LargeScaleSimulation generator = new LargeScaleSimulation(trueDag, nodes, ordering);
    generator.setOut(out);
    DataSet sample = generator.simulateDataFisher(sampleSize);

    ICovarianceMatrix covariance = new CovarianceMatrixOnTheFly(sample);
    SemBicScore bic = new SemBicScore(covariance);
    bic.setPenaltyDiscount(discount);

    Fges search = new Fges(bic);
    search.setVerbose(false);
    search.setNumPatternsToStore(0);
    search.setOut(out);
    search.setFaithfulnessAssumed(true);
    search.setCycleBound(5);
    Graph estimated = search.search();

    final Graph reference = SearchGraphUtils.patternForDag(trueDag);
    int[][] actual = SearchGraphUtils.graphComparison(estimated, reference, null);
    int[][] expected = {
        { 2, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 8, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 }
    };

    // Compare row by row, over as many rows as the comparison produced.
    int row = 0;
    while (row < actual.length) {
        assertTrue(Arrays.equals(actual[row], expected[row]));
        row++;
    }
}
Also used : Fges(edu.cmu.tetrad.search.Fges) RandomGraph(edu.cmu.tetrad.algcomparison.graph.RandomGraph) SemBicScore(edu.cmu.tetrad.search.SemBicScore)

Example 4 with SemBicScore

use of edu.cmu.tetrad.search.SemBicScore in project tetrad by cmu-phil.

the class TestFges method searchSemFges.

/**
 * Runs FGES with a SEM BIC score on the given data after converting numerical
 * discrete columns to continuous.
 *
 * @param Dk      the data set to search over
 * @param penalty the penalty discount applied to the SEM BIC score
 * @return the graph returned by the FGES search
 */
private Graph searchSemFges(DataSet Dk, double penalty) {
    DataSet continuousData = DataUtils.convertNumericalDiscreteToContinuous(Dk);
    CovarianceMatrixOnTheFly covariance = new CovarianceMatrixOnTheFly(continuousData);
    SemBicScore bic = new SemBicScore(covariance);
    bic.setPenaltyDiscount(penalty);
    return new Fges(bic).search();
}
Also used : Fges(edu.cmu.tetrad.search.Fges) SemBicScore(edu.cmu.tetrad.search.SemBicScore)

Example 5 with SemBicScore

use of edu.cmu.tetrad.search.SemBicScore in project tetrad by cmu-phil.

the class MixedFgesTreatingDiscreteAsContinuous method search.

/**
 * Searches the given data with FGES using a SEM BIC score, either directly or
 * through a bootstrap wrapper.
 *
 * <p>When the {@code bootstrapSampleSize} parameter is at least 1, the search is
 * delegated to a {@code GeneralBootstrapTest} around a fresh instance of this
 * algorithm. The edge ensemble is chosen by {@code bootstrapEnsemble}: 0 selects
 * Preserved, 2 selects Majority, and any other value (default 1) selects Highest.
 * Otherwise FGES is run once on the mixed data set with its numerical discrete
 * columns converted to continuous, and the result is passed through
 * {@code convertBack}.
 *
 * @param Dk         the data model to search over
 * @param parameters supplies {@code bootstrapSampleSize}, {@code penaltyDiscount},
 *                   {@code bootstrapEnsemble} and {@code verbose}
 * @return the graph produced by the search
 */
public Graph search(DataModel Dk, Parameters parameters) {
    if (parameters.getInt("bootstrapSampleSize") >= 1) {
        MixedFgesTreatingDiscreteAsContinuous delegate = new MixedFgesTreatingDiscreteAsContinuous();
        DataSet bootstrapData = (DataSet) Dk;
        GeneralBootstrapTest bootstrap = new GeneralBootstrapTest(bootstrapData, delegate, parameters.getInt("bootstrapSampleSize"));

        // Map the numeric ensemble code onto the enum; unknown codes fall back to Highest.
        int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        BootstrapEdgeEnsemble ensemble;
        if (ensembleCode == 0) {
            ensemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            ensemble = BootstrapEdgeEnsemble.Majority;
        } else {
            ensemble = BootstrapEdgeEnsemble.Highest;
        }

        bootstrap.setEdgeEnsemble(ensemble);
        bootstrap.setParameters(parameters);
        bootstrap.setVerbose(parameters.getBoolean("verbose"));
        return bootstrap.search();
    }

    DataSet continuousData = DataUtils.convertNumericalDiscreteToContinuous(DataUtils.getMixedDataSet(Dk));
    SemBicScore bic = new SemBicScore(new CovarianceMatrixOnTheFly(continuousData));
    bic.setPenaltyDiscount(parameters.getDouble("penaltyDiscount"));
    Graph pattern = new Fges(bic).search();
    return convertBack(continuousData, pattern);
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) Fges(edu.cmu.tetrad.search.Fges) SemBicScore(edu.cmu.tetrad.search.SemBicScore)

Aggregations

SemBicScore (edu.cmu.tetrad.search.SemBicScore) 9, Fges (edu.cmu.tetrad.search.Fges) 8, RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph) 4, Graph (edu.cmu.tetrad.graph.Graph) 4, SemBicDTest (edu.cmu.tetrad.algcomparison.independence.SemBicDTest) 2, SemBicTest (edu.cmu.tetrad.algcomparison.independence.SemBicTest) 2, CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) 2, DataSet (edu.cmu.tetrad.data.DataSet) 2, File (java.io.File) 2, Test (org.junit.Test) 2, ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix) 1, Dag (edu.cmu.tetrad.graph.Dag) 1, EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph) 1, Node (edu.cmu.tetrad.graph.Node) 1, Fas (edu.cmu.tetrad.search.Fas) 1, IndTestScore (edu.cmu.tetrad.search.IndTestScore) 1, PatternToDag (edu.cmu.tetrad.search.PatternToDag) 1, SemEstimator (edu.cmu.tetrad.sem.SemEstimator) 1, SemIm (edu.cmu.tetrad.sem.SemIm) 1, SemPm (edu.cmu.tetrad.sem.SemPm) 1