Search in sources:

Example 1 with CovarianceMatrixOnTheFly

use of edu.cmu.tetrad.data.CovarianceMatrixOnTheFly in project tetrad by cmu-phil.

From class TestAutisticClassification, method testForBiwei.

// @Test
/**
 * Runs FAS (with a SEM BIC score used as the independence test) on every dataset loaded from the
 * USM_ABIDE directory and writes each resulting adjacency matrix to
 * {@code <filename>.graph.txt} in the output directory.
 *
 * <p>Row {@code j} of the written matrix corresponds to the {@code j}-th node of
 * {@code graph.getNodes()}; entry {@code k} is {@code 1} when nodes {@code j} and {@code k} are
 * adjacent and {@code 0} otherwise. Entries are space-separated.
 */
public void testForBiwei() {
    FaskGraphs files = new FaskGraphs("/Users/jdramsey/Downloads/USM_ABIDE", new Parameters());
    List<DataSet> datasets = files.getDatasets();
    List<String> filenames = files.getFilenames();

    // Every matrix goes to the same directory, so create it once instead of once per dataset.
    File dir = new File("/Users/jdramsey/Downloads/biwei/USM_ABIDE");
    dir.mkdirs();

    for (int i = 0; i < datasets.size(); i++) {
        DataSet dataSet = datasets.get(i);
        SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        Fas fas = new Fas(new IndTestScore(score));
        Graph graph = fas.search();
        System.out.println(graph);

        List<Node> nodes = graph.getNodes();
        StringBuilder b = new StringBuilder();
        for (Node row : nodes) {
            for (Node col : nodes) {
                b.append(graph.isAdjacentTo(row, col) ? "1 " : "0 ");
            }
            b.append("\n");
        }

        File file = new File(dir, filenames.get(i) + ".graph.txt");
        // try-with-resources guarantees the stream is closed even if writing fails.
        try (PrintStream out = new PrintStream(file)) {
            out.println(b);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
    }
}
Also used : PrintStream(java.io.PrintStream) Parameters(edu.cmu.tetrad.util.Parameters) DataSet(edu.cmu.tetrad.data.DataSet) IndTestScore(edu.cmu.tetrad.search.IndTestScore) Node(edu.cmu.tetrad.graph.Node) FileNotFoundException(java.io.FileNotFoundException) Graph(edu.cmu.tetrad.graph.Graph) Fas(edu.cmu.tetrad.search.Fas) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) File(java.io.File) SemBicScore(edu.cmu.tetrad.search.SemBicScore)

Example 2 with CovarianceMatrixOnTheFly

use of edu.cmu.tetrad.data.CovarianceMatrixOnTheFly in project tetrad by cmu-phil.

From class HsimEvalFromData, method main.

/**
 * Evaluates how well full resimulation (Fsim) and hybrid resimulation (Hsim) predict the
 * accuracy of FGES on a data set, and writes the results to {@code latexTable.txt} and
 * {@code HvsF-SimulationEvaluation.txt} in the working directory.
 *
 * <p>Each iteration does the following:
 * <ul>
 *   <li>Loads the true graph from {@code graph/graph.1.txt}.</li>
 *   <li>Loads tab-delimited continuous data from {@code data/data.1.txt}.</li>
 *   <li>Runs FGES on the data and scores the result against the true DAG to get the "target"
 *       errors.</li>
 *   <li>For every Fsim and Hsim repeat setting, records the squared difference between the
 *       predicted errors and the target errors.</li>
 * </ul>
 * The squared errors are then averaged across iterations into mean squared errors. Any
 * exception is printed and ends the run.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    long timestart = System.nanoTime();
    System.out.println("Beginning Evaluation");
    String nl = System.lineSeparator();
    StringBuilder output = new StringBuilder(
            "Simulation edu.cmu.tetrad.study output comparing Fsim and Hsim on predicting graph discovery accuracy")
            .append(nl);
    int iterations = 100;
    // vars/cases/edgeratio are overwritten from the data each iteration; these are only fallbacks
    // used in the MSE labels.
    int vars = 20;
    int cases = 500;
    int edgeratio = 3;
    List<Integer> hsimRepeat = Arrays.asList(40);
    List<Integer> fsimRepeat = Arrays.asList(40);

    // Generic arrays cannot be created directly; every slot is filled with a typed list below.
    @SuppressWarnings("unchecked")
    List<PRAOerrors>[] fsimErrsByPars = new ArrayList[fsimRepeat.size()];
    for (int j = 0; j < fsimErrsByPars.length; j++) {
        fsimErrsByPars[j] = new ArrayList<PRAOerrors>();
    }
    // The first dimension indexes the resimulation size; only rsize = 1 is evaluated.
    @SuppressWarnings("unchecked")
    List<PRAOerrors>[][] hsimErrsByPars = new ArrayList[1][hsimRepeat.size()];
    for (int k = 0; k < hsimRepeat.size(); k++) {
        hsimErrsByPars[0][k] = new ArrayList<PRAOerrors>();
    }

    try {
        for (int iterate = 0; iterate < iterations; iterate++) {
            System.out.println("iteration " + iterate);

            // Load the true graph and the observed data.
            Graph graph1 = GraphUtils.loadGraphTxt(new File("graph/graph.1.txt"));
            Dag odag = new Dag(graph1);
            Set<String> eVars = new HashSet<String>();
            eVars.add("MULT");
            Path dataFile = Paths.get("data/data.1.txt");
            TabularDataReader dataReader = new ContinuousTabularDataFileReader(dataFile.toFile(), Delimiter.TAB);
            DataSet data1 = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData(eVars));
            vars = data1.getNumColumns();
            cases = data1.getNumRows();
            edgeratio = 3;

            // Target errors: FGES on the real data, scored against the true DAG.
            ICovarianceMatrix newcov = new CovarianceMatrixOnTheFly(data1);
            SemBicScore oscore = new SemBicScore(newcov);
            Fges ofgs = new Fges(oscore);
            ofgs.setVerbose(false);
            ofgs.setNumPatternsToStore(0);
            Graph oFGSGraph = ofgs.search();
            PRAOerrors oErrors = new PRAOerrors(HsimUtils.errorEval(oFGSGraph, odag), "target errors");

            // Step 1: full resimulation (Fsim) for each repeat count.
            for (int whichFrepeat = 0; whichFrepeat < fsimRepeat.size(); whichFrepeat++) {
                ArrayList<PRAOerrors> errorsList = new ArrayList<PRAOerrors>();
                for (int r = 0; r < fsimRepeat.get(whichFrepeat); r++) {
                    // Pick a DAG from the estimated pattern, fit a linear SEM to the real data, and
                    // simulate a data set of the same size from the fitted model. (The original
                    // author noted that GeneralizedSemEstimator misbehaved here, so SemEstimator
                    // is used instead.)
                    PatternToDag pickdag = new PatternToDag(oFGSGraph);
                    Graph fgsDag = pickdag.patternToDagMeek();
                    Dag fgsdag2 = new Dag(fgsDag);
                    SemPm simSemPm = new SemPm(fgsdag2);
                    SemEstimator simSemEstimator = new SemEstimator(data1, simSemPm);
                    SemIm fittedIM = simSemEstimator.estimate();
                    DataSet simData = fittedIM.simulateData(data1.getNumRows(), false);

                    // Run FGES on the resimulated data and score it against the generating DAG.
                    ICovarianceMatrix simcov = new CovarianceMatrixOnTheFly(simData);
                    SemBicScore simscore = new SemBicScore(simcov);
                    Fges simfgs = new Fges(simscore);
                    simfgs.setVerbose(false);
                    simfgs.setNumPatternsToStore(0);
                    Graph simGraphOut = simfgs.search();
                    errorsList.add(new PRAOerrors(HsimUtils.errorEval(simGraphOut, fgsdag2), "Fsim errors " + r));
                }
                PRAOerrors avErrors = new PRAOerrors(errorsList, "Average errors for Fsim at repeat=" + fsimRepeat.get(whichFrepeat));
                PRAOerrors fsim2 = new PRAOerrors(squaredPraoErrors(avErrors, oErrors),
                        "squared errors for Fsim at repeat=" + fsimRepeat.get(whichFrepeat));
                fsimErrsByPars[whichFrepeat].add(fsim2);
            }

            // Step 2: hybrid resimulation (Hsim) for each repeat count (rsize fixed at 1).
            for (int whichHrepeat = 0; whichHrepeat < hsimRepeat.size(); whichHrepeat++) {
                HsimRepeatAC study = new HsimRepeatAC(data1);
                PRAOerrors hsimErrors = new PRAOerrors(study.run(1, hsimRepeat.get(whichHrepeat)),
                        "Hsim errors at rsize=" + 1 + " repeat=" + hsimRepeat.get(whichHrepeat));
                PRAOerrors hsim2 = new PRAOerrors(squaredPraoErrors(hsimErrors, oErrors),
                        "squared errors for Hsim, rsize=" + 1 + " repeat=" + hsimRepeat.get(whichHrepeat));
                hsimErrsByPars[0][whichHrepeat].add(hsim2);
            }
        }

        // Average the squared errors for each Fsim/Hsim setting across all iterations.
        PRAOerrors[] fMSE = new PRAOerrors[fsimRepeat.size()];
        PRAOerrors[][] hMSE = new PRAOerrors[1][hsimRepeat.size()];
        // One LaTeX row per Fsim setting, followed by one row per (rsize, Hsim repeat) setting.
        String[][] latexTableArray = new String[hMSE.length * hsimRepeat.size() + fsimRepeat.size()][5];
        for (int j = 0; j < fMSE.length; j++) {
            fMSE[j] = new PRAOerrors(fsimErrsByPars[j], "MSE for Fsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " frepeat=" + fsimRepeat.get(j) + " iterations=" + iterations);
            output.append(fMSE[j].allToString()).append(nl);
            latexTableArray[j] = prelimToPRAOtable(fMSE[j]);
        }
        for (int j = 0; j < hMSE.length; j++) {
            for (int k = 0; k < hMSE[j].length; k++) {
                hMSE[j][k] = new PRAOerrors(hsimErrsByPars[j][k], "MSE for Hsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " rsize=" + 1 + " repeat=" + hsimRepeat.get(k) + " iterations=" + iterations);
                output.append(hMSE[j][k].allToString()).append(nl);
                latexTableArray[fsimRepeat.size() + j * hMSE[j].length + k] = prelimToPRAOtable(hMSE[j][k]);
            }
        }

        // Record the LaTeX summary table and the full textual report.
        String latexTable = HsimUtils.makeLatexTable(latexTableArray);
        try (PrintWriter writer = new PrintWriter("latexTable.txt", "UTF-8")) {
            writer.println(latexTable);
        }
        try (PrintWriter writer2 = new PrintWriter("HvsF-SimulationEvaluation.txt", "UTF-8")) {
            writer2.println(output.toString());
        }
        long timestop = System.nanoTime();
        System.out.println("Evaluation Concluded. Duration: " + (timestop - timestart) / 1000000000 + "s");
    } catch (Exception e) {
        // Top-level boundary of a command-line tool: report and stop.
        e.printStackTrace();
    }
}

/**
 * Returns the squared differences between {@code predicted} and {@code target} for adjacency
 * recall, adjacency precision, orientation recall and orientation precision, in that order.
 */
private static double[] squaredPraoErrors(PRAOerrors predicted, PRAOerrors target) {
    double adjRecall = predicted.getAdjRecall() - target.getAdjRecall();
    double adjPrecision = predicted.getAdjPrecision() - target.getAdjPrecision();
    double orientRecall = predicted.getOrientRecall() - target.getOrientRecall();
    double orientPrecision = predicted.getOrientPrecision() - target.getOrientPrecision();
    return new double[] {
        adjRecall * adjRecall,
        adjPrecision * adjPrecision,
        orientRecall * orientRecall,
        orientPrecision * orientPrecision
    };
}
Also used: TabularDataReader(edu.pitt.dbmi.data.reader.tabular.TabularDataReader) DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) SemPm(edu.cmu.tetrad.sem.SemPm) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) PrintWriter(java.io.PrintWriter) Path(java.nio.file.Path) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ContinuousTabularDataFileReader(edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader) Dag(edu.cmu.tetrad.graph.Dag) Fges(edu.cmu.tetrad.search.Fges) Graph(edu.cmu.tetrad.graph.Graph) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) File(java.io.File) SemIm(edu.cmu.tetrad.sem.SemIm) SemBicScore(edu.cmu.tetrad.search.SemBicScore)

Example 3 with CovarianceMatrixOnTheFly

use of edu.cmu.tetrad.data.CovarianceMatrixOnTheFly in project tetrad by cmu-phil.

From class TestSimulatedFmri, method testClark.

// @Test
/**
 * Simulates two three-variable linear models 100 times each and prints the graph FASK recovers
 * for each simulated data set.
 *
 * <p>Both models use skewed, non-Gaussian errors of the form {@code pow(Uniform(0, 1), f)}:
 * <ul>
 *   <li>Model 1: X -> Y, Z -> X, Z -> Y, with coefficients 0.5.</li>
 *   <li>Model 2: X -> Y, X -> Z, Y -> Z, with coefficients 0.4.</li>
 * </ul>
 */
public void testClark() {
    double f = .1;
    int N = 512;
    double alpha = 1.0;
    double penaltyDiscount = 1.0;
    for (int i = 0; i < 100; i++) {
        runFaskOnThreeVariableModel(
                new String[][] {{"X", "Y"}, {"Z", "X"}, {"Z", "Y"}},
                "0.5 * Z + E_X", "0.5 * X + 0.5 * Z + E_Y", "E_Z",
                f, N, alpha, penaltyDiscount);
        runFaskOnThreeVariableModel(
                new String[][] {{"X", "Y"}, {"X", "Z"}, {"Y", "Z"}},
                "E_X", "0.4 * X + E_Y", "0.4 * X + 0.4 * Y + E_Z",
                f, N, alpha, penaltyDiscount);
    }
}

/**
 * Builds a generalized SEM over continuous variables X, Y, Z, simulates data from it, runs FASK,
 * and prints the resulting graph.
 *
 * @param edges           directed edges as {@code {from, to}} variable-name pairs, added in order
 * @param xExpression     structural expression for X
 * @param yExpression     structural expression for Y
 * @param zExpression     structural expression for Z
 * @param f               exponent applied to {@code Uniform(0, 1)} in each error term
 * @param sampleSize      number of rows to simulate
 * @param alpha           alpha passed to FASK
 * @param penaltyDiscount penalty discount for both the SEM BIC score and FASK
 * @throws IllegalStateException if one of the SEM expressions cannot be parsed
 */
private void runFaskOnThreeVariableModel(String[][] edges, String xExpression, String yExpression,
        String zExpression, double f, int sampleSize, double alpha, double penaltyDiscount) {
    Graph g = new EdgeListGraph();
    g.addNode(new ContinuousVariable("X"));
    g.addNode(new ContinuousVariable("Y"));
    g.addNode(new ContinuousVariable("Z"));
    for (String[] edge : edges) {
        g.addDirectedEdge(g.getNode(edge[0]), g.getNode(edge[1]));
    }

    GeneralizedSemPm pm = new GeneralizedSemPm(g);
    try {
        pm.setNodeExpression(g.getNode("X"), xExpression);
        pm.setNodeExpression(g.getNode("Y"), yExpression);
        pm.setNodeExpression(g.getNode("Z"), zExpression);
        String error = "pow(Uniform(0, 1), " + f + ")";
        for (String name : new String[] {"X", "Y", "Z"}) {
            pm.setNodeExpression(pm.getErrorNode(g.getNode(name)), error);
        }
    } catch (ParseException e) {
        // The expressions are fixed strings, so a parse failure is a bug. Continuing would
        // simulate from a partially configured model and print misleading results.
        throw new IllegalStateException("Could not parse SEM expression", e);
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    DataSet data = im.simulateData(sampleSize, false);
    edu.cmu.tetrad.search.SemBicScore score =
            new edu.cmu.tetrad.search.SemBicScore(new CovarianceMatrixOnTheFly(data, false));
    score.setPenaltyDiscount(penaltyDiscount);
    Fask fask = new Fask(data, score);
    fask.setPenaltyDiscount(penaltyDiscount);
    fask.setAlpha(alpha);
    Graph out = fask.search();
    System.out.println(out);
}
Also used: DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Fask(edu.cmu.tetrad.search.Fask) Graph(edu.cmu.tetrad.graph.Graph) GeneralizedSemIm(edu.cmu.tetrad.sem.GeneralizedSemIm) ParseException(java.text.ParseException) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) GeneralizedSemPm(edu.cmu.tetrad.sem.GeneralizedSemPm) SemBicScore(edu.cmu.tetrad.algcomparison.score.SemBicScore)

Example 4 with CovarianceMatrixOnTheFly

use of edu.cmu.tetrad.data.CovarianceMatrixOnTheFly in project tetrad by cmu-phil.

From class TestLingamPattern, method test1.

@Test
/**
 * Checks LiNGAM-pattern p-values against hard-coded expected values.
 *
 * <p>The data come from a random 6-node DAG with a fixed seed. Five variables get Normal(0, 1)
 * errors and the fourth gets Uniform(-1, 1) errors. FGES estimates a pattern, and LiNGAM is run
 * over that pattern.
 */
public void test1() {
    RandomUtil.getInstance().setSeed(4938492L);
    int sampleSize = 1000;
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 6; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }
    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 0, 6, 4, 4, 4, false));

    // One error distribution per variable, in node order; only X4 is non-Gaussian.
    List<Distribution> variableDistributions = new ArrayList<>();
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Uniform(-1, 1));
    variableDistributions.add(new Normal(0, 1));
    variableDistributions.add(new Normal(0, 1));

    SemPm semPm = new SemPm(graph);
    SemIm semIm = new SemIm(semPm);
    DataSet dataSet = simulateDataNonNormal(semIm, sampleSize, variableDistributions);
    Score score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
    Graph estPattern = new Fges(score).search();
    LingamPattern lingam = new LingamPattern(estPattern, dataSet);
    lingam.search();
    double[] pvals = lingam.getPValues();

    double[] expectedPVals = { 0.18, 0.29, 0.88, 0.00, 0.01, 0.58 };
    // Without this check, a short (or empty) p-value array would make the loop below pass
    // vacuously, and a longer one would fail with an index error instead of a clear message.
    assertEquals(expectedPVals.length, pvals.length);
    for (int i = 0; i < pvals.length; i++) {
        assertEquals(expectedPVals[i], pvals[i], 0.01);
    }
}
Also used : ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Uniform(edu.cmu.tetrad.util.dist.Uniform) Normal(edu.cmu.tetrad.util.dist.Normal) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Distribution(edu.cmu.tetrad.util.dist.Distribution) SemPm(edu.cmu.tetrad.sem.SemPm) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 5 with CovarianceMatrixOnTheFly

use of edu.cmu.tetrad.data.CovarianceMatrixOnTheFly in project tetrad by cmu-phil.

From class MbfsRunner, method execute.

// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Runs an FGES-MB search and stores the laid-out graph as the result graph.
 *
 * <p>The search runs on the first data set in the data model list, around the variable named by
 * the {@code targetName} parameter.
 *
 * <p>Layout preference:
 * <ul>
 *   <li>the source graph's layout, if a source graph is present;</li>
 *   <li>otherwise knowledge tiers, if the knowledge requests a tiered layout;</li>
 *   <li>otherwise a circle layout.</li>
 * </ul>
 *
 * @throws IllegalArgumentException if {@code targetName} is unset or does not name a variable in
 *     the data set
 */
public void execute() {
    IKnowledge knowledge = (IKnowledge) getParams().get("knowledge", new Knowledge2());
    String targetName = getParams().getString("targetName", null);
    DataSet dataSet = (DataSet) getDataModelList().get(0);

    // Fail fast with a clear message instead of handing a null target to the search.
    if (targetName == null || dataSet.getVariable(targetName) == null) {
        throw new IllegalArgumentException("No target variable named '" + targetName + "' in the data set.");
    }

    SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
    // NOTE(review): the "alpha" parameter is used as the BIC penalty discount here; confirm that
    // this key is intended and not a leftover from the earlier PC-style MBFS implementation.
    score.setPenaltyDiscount(getParams().getDouble("alpha", 0.001));
    FgesMb search = new FgesMb(score);
    search.setFaithfulnessAssumed(true);
    Graph searchGraph = search.search(dataSet.getVariable(targetName));

    if (getSourceGraph() != null) {
        GraphUtils.arrangeBySourceGraph(searchGraph, getSourceGraph());
    } else if (knowledge.isDefaultToKnowledgeLayout()) {
        SearchGraphUtils.arrangeByKnowledgeTiers(searchGraph, knowledge);
    } else {
        GraphUtils.circleLayout(searchGraph, 200, 200, 150);
    }
    setResultGraph(searchGraph);
}
Also used : IKnowledge(edu.cmu.tetrad.data.IKnowledge) Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) Knowledge2(edu.cmu.tetrad.data.Knowledge2) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly)

Aggregations

CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly): 7; DataSet (edu.cmu.tetrad.data.DataSet): 6; Graph (edu.cmu.tetrad.graph.Graph): 6; ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 3; Node (edu.cmu.tetrad.graph.Node): 3; SemBicScore (edu.cmu.tetrad.algcomparison.score.SemBicScore): 2; EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 2; Fask (edu.cmu.tetrad.search.Fask): 2; SemBicScore (edu.cmu.tetrad.search.SemBicScore): 2; GeneralizedSemIm (edu.cmu.tetrad.sem.GeneralizedSemIm): 2; GeneralizedSemPm (edu.cmu.tetrad.sem.GeneralizedSemPm): 2; SemIm (edu.cmu.tetrad.sem.SemIm): 2; SemPm (edu.cmu.tetrad.sem.SemPm): 2; File (java.io.File): 2; ParseException (java.text.ParseException): 2; ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 1; ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix): 1; IKnowledge (edu.cmu.tetrad.data.IKnowledge): 1; Knowledge2 (edu.cmu.tetrad.data.Knowledge2): 1; Dag (edu.cmu.tetrad.graph.Dag): 1