Search in sources :

Example 11 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

the class TestDataSetCellProbs method testCreateUsingBayesIm.

@Test
public void testCreateUsingBayesIm() {
    Graph graph = GraphConverter.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4");
    Dag dag = new Dag(graph);
    BayesPm bayesPm = new BayesPm(dag);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    BayesImProbs bayesImProbs = new BayesImProbs(bayesIm);
    DataSet dataSet = bayesIm.simulateData(1000, false);
    CellTableProbs dataSetProbs = new CellTableProbs(dataSet);
    int[] cell = new int[4];
    for (int i = 0; i < 200; i++) {
        for (int j = 0; j < 4; j++) {
            cell[j] = RandomUtil.getInstance().nextInt(bayesIm.getNumColumns(j));
        }
        double count1 = bayesImProbs.getCellProb(cell);
        double count2 = dataSetProbs.getCellProb(cell);
        assertEquals(count1, count2, .05);
    }
}
Also used : Graph(edu.cmu.tetrad.graph.Graph) DataSet(edu.cmu.tetrad.data.DataSet) Dag(edu.cmu.tetrad.graph.Dag) Test(org.junit.Test)

Example 12 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

the class SemEstimatorWrapper method serializableInstance.

// public SemEstimatorWrapper(DataWrapper dataWrapper,
// SemPmWrapper semPmWrapper,
// SemImWrapper semImWrapper,
// Parameters params) {
// if (dataWrapper == null) {
// throw new NullPointerException();
// }
// 
// if (semPmWrapper == null) {
// throw new NullPointerException();
// }
// 
// if (semImWrapper == null) {
// throw new NullPointerException();
// }
// 
// DataSet dataSet =
// (DataSet) dataWrapper.getSelectedDataModel();
// SemPm semPm = semPmWrapper.getSemPm();
// SemIm semIm = semImWrapper.getSemIm();
// 
// this.semEstimator = new SemEstimator(dataSet, semPm, getOptimizer());
// if (!degreesOfFreedomCheck(semPm)) return;
// this.semEstimator.setTrueSemIm(semIm);
// this.semEstimator.setNumRestarts(getParams().getInt("numRestarts", 1));
// this.semEstimator.estimate();
// 
// this.params = params;
// 
// log();
// }
/**
 * Generates a simple exemplar of this class to test serialization.
 *
 * @see TetradSerializableUtils
 */
public static SemEstimatorWrapper serializableInstance() {
    List<Node> variables = new LinkedList<>();
    ContinuousVariable x = new ContinuousVariable("X");
    variables.add(x);
    DataSet dataSet = new ColtDataSet(10, variables);
    for (int i = 0; i < dataSet.getNumRows(); i++) {
        for (int j = 0; j < dataSet.getNumColumns(); j++) {
            dataSet.setDouble(i, j, RandomUtil.getInstance().nextDouble());
        }
    }
    Dag dag = new Dag();
    dag.addNode(x);
    SemPm pm = new SemPm(dag);
    Parameters params1 = new Parameters();
    return new SemEstimatorWrapper(dataSet, pm, params1);
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) Parameters(edu.cmu.tetrad.util.Parameters) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Dag(edu.cmu.tetrad.graph.Dag) LinkedList(java.util.LinkedList)

Example 13 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

the class HsimEvalFromData method main.

public static void main(String[] args) {
    long timestart = System.nanoTime();
    System.out.println("Beginning Evaluation");
    String nl = System.lineSeparator();
    String output = "Simulation edu.cmu.tetrad.study output comparing Fsim and Hsim on predicting graph discovery accuracy" + nl;
    int iterations = 100;
    int vars = 20;
    int cases = 500;
    int edgeratio = 3;
    List<Integer> hsimRepeat = Arrays.asList(40);
    List<Integer> fsimRepeat = Arrays.asList(40);
    List<PRAOerrors>[] fsimErrsByPars = new ArrayList[fsimRepeat.size()];
    int whichFrepeat = 0;
    for (int frepeat : fsimRepeat) {
        fsimErrsByPars[whichFrepeat] = new ArrayList<PRAOerrors>();
        whichFrepeat++;
    }
    List<PRAOerrors>[][] hsimErrsByPars = new ArrayList[1][hsimRepeat.size()];
    // System.out.println(resimSize.size()+" "+hsimRepeat.size());
    int whichHrepeat;
    whichHrepeat = 0;
    for (int hrepeat : hsimRepeat) {
        // System.out.println(whichrsize+" "+whichHrepeat);
        hsimErrsByPars[0][whichHrepeat] = new ArrayList<PRAOerrors>();
        whichHrepeat++;
    }
    // !(*%(@!*^!($%!^ START ITERATING HERE !#$%(*$#@!^(*!$*%(!$#
    try {
        for (int iterate = 0; iterate < iterations; iterate++) {
            System.out.println("iteration " + iterate);
            // @#$%@$%^@$^@$^@%$%@$#^ LOADING THE DATA AND GRAPH @$#%%*#^##*^$#@%$
            DataSet data1;
            Graph graph1 = GraphUtils.loadGraphTxt(new File("graph/graph.1.txt"));
            Dag odag = new Dag(graph1);
            Set<String> eVars = new HashSet<String>();
            eVars.add("MULT");
            Path dataFile = Paths.get("data/data.1.txt");
            TabularDataReader dataReader = new ContinuousTabularDataFileReader(dataFile.toFile(), Delimiter.TAB);
            data1 = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData(eVars));
            vars = data1.getNumColumns();
            cases = data1.getNumRows();
            edgeratio = 3;
            // !#@^$@&%^!#$!&@^ CALCULATING TARGET ERRORS $%$#@^@!%!#^$!%$#%
            ICovarianceMatrix newcov = new CovarianceMatrixOnTheFly(data1);
            SemBicScore oscore = new SemBicScore(newcov);
            Fges ofgs = new Fges(oscore);
            ofgs.setVerbose(false);
            ofgs.setNumPatternsToStore(0);
            // ***********This is the original FGS output on the data
            Graph oFGSGraph = ofgs.search();
            PRAOerrors oErrors = new PRAOerrors(HsimUtils.errorEval(oFGSGraph, odag), "target errors");
            // **then step 1: full resim. iterate through the combinations of estimator parameters (just repeat num)
            for (whichFrepeat = 0; whichFrepeat < fsimRepeat.size(); whichFrepeat++) {
                ArrayList<PRAOerrors> errorsList = new ArrayList<PRAOerrors>();
                for (int r = 0; r < fsimRepeat.get(whichFrepeat); r++) {
                    PatternToDag pickdag = new PatternToDag(oFGSGraph);
                    Graph fgsDag = pickdag.patternToDagMeek();
                    Dag fgsdag2 = new Dag(fgsDag);
                    // then fit an IM to this dag and the data. GeneralizedSemEstimator seems to bug out
                    // GeneralizedSemPm simSemPm = new GeneralizedSemPm(fgsdag2);
                    // GeneralizedSemEstimator gsemEstimator = new GeneralizedSemEstimator();
                    // GeneralizedSemIm fittedIM = gsemEstimator.estimate(simSemPm, oData);
                    SemPm simSemPm = new SemPm(fgsdag2);
                    // BayesPm simBayesPm = new BayesPm(fgsdag2, bayesPm);
                    SemEstimator simSemEstimator = new SemEstimator(data1, simSemPm);
                    SemIm fittedIM = simSemEstimator.estimate();
                    DataSet simData = fittedIM.simulateData(data1.getNumRows(), false);
                    // after making the full resim data (simData), run FGS on that
                    ICovarianceMatrix simcov = new CovarianceMatrixOnTheFly(simData);
                    SemBicScore simscore = new SemBicScore(simcov);
                    Fges simfgs = new Fges(simscore);
                    simfgs.setVerbose(false);
                    simfgs.setNumPatternsToStore(0);
                    Graph simGraphOut = simfgs.search();
                    PRAOerrors simErrors = new PRAOerrors(HsimUtils.errorEval(simGraphOut, fgsdag2), "Fsim errors " + r);
                    errorsList.add(simErrors);
                }
                PRAOerrors avErrors = new PRAOerrors(errorsList, "Average errors for Fsim at repeat=" + fsimRepeat.get(whichFrepeat));
                // if (verbosity>3) System.out.println(avErrors.allToString());
                // ****calculate the squared errors of prediction, store all these errors in a list
                double FsimAR2 = (avErrors.getAdjRecall() - oErrors.getAdjRecall()) * (avErrors.getAdjRecall() - oErrors.getAdjRecall());
                double FsimAP2 = (avErrors.getAdjPrecision() - oErrors.getAdjPrecision()) * (avErrors.getAdjPrecision() - oErrors.getAdjPrecision());
                double FsimOR2 = (avErrors.getOrientRecall() - oErrors.getOrientRecall()) * (avErrors.getOrientRecall() - oErrors.getOrientRecall());
                double FsimOP2 = (avErrors.getOrientPrecision() - oErrors.getOrientPrecision()) * (avErrors.getOrientPrecision() - oErrors.getOrientPrecision());
                PRAOerrors Fsim2 = new PRAOerrors(new double[] { FsimAR2, FsimAP2, FsimOR2, FsimOP2 }, "squared errors for Fsim at repeat=" + fsimRepeat.get(whichFrepeat));
                // add the fsim squared errors to the appropriate list
                fsimErrsByPars[whichFrepeat].add(Fsim2);
            }
            // **then step 2: hybrid sim. iterate through combos of params (repeat num, resimsize)
            for (whichHrepeat = 0; whichHrepeat < hsimRepeat.size(); whichHrepeat++) {
                HsimRepeatAC study = new HsimRepeatAC(data1);
                PRAOerrors HsimErrors = new PRAOerrors(study.run(1, hsimRepeat.get(whichHrepeat)), "Hsim errors" + "at rsize=" + 1 + " repeat=" + hsimRepeat.get(whichHrepeat));
                // ****calculate the squared errors of prediction
                double HsimAR2 = (HsimErrors.getAdjRecall() - oErrors.getAdjRecall()) * (HsimErrors.getAdjRecall() - oErrors.getAdjRecall());
                double HsimAP2 = (HsimErrors.getAdjPrecision() - oErrors.getAdjPrecision()) * (HsimErrors.getAdjPrecision() - oErrors.getAdjPrecision());
                double HsimOR2 = (HsimErrors.getOrientRecall() - oErrors.getOrientRecall()) * (HsimErrors.getOrientRecall() - oErrors.getOrientRecall());
                double HsimOP2 = (HsimErrors.getOrientPrecision() - oErrors.getOrientPrecision()) * (HsimErrors.getOrientPrecision() - oErrors.getOrientPrecision());
                PRAOerrors Hsim2 = new PRAOerrors(new double[] { HsimAR2, HsimAP2, HsimOR2, HsimOP2 }, "squared errors for Hsim, rsize=" + 1 + " repeat=" + hsimRepeat.get(whichHrepeat));
                hsimErrsByPars[0][whichHrepeat].add(Hsim2);
            }
        }
        // Average the squared errors for each set of fsim/hsim params across all iterations
        PRAOerrors[] fMSE = new PRAOerrors[fsimRepeat.size()];
        PRAOerrors[][] hMSE = new PRAOerrors[1][hsimRepeat.size()];
        String[][] latexTableArray = new String[1 * hsimRepeat.size() + fsimRepeat.size()][5];
        for (int j = 0; j < fMSE.length; j++) {
            fMSE[j] = new PRAOerrors(fsimErrsByPars[j], "MSE for Fsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " frepeat=" + fsimRepeat.get(j) + " iterations=" + iterations);
            // if(verbosity>0){System.out.println(fMSE[j].allToString());}
            output = output + fMSE[j].allToString() + nl;
            latexTableArray[j] = prelimToPRAOtable(fMSE[j]);
        }
        for (int j = 0; j < hMSE.length; j++) {
            for (int k = 0; k < hMSE[j].length; k++) {
                hMSE[j][k] = new PRAOerrors(hsimErrsByPars[j][k], "MSE for Hsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " rsize=" + 1 + " repeat=" + hsimRepeat.get(k) + " iterations=" + iterations);
                // if(verbosity>0){System.out.println(hMSE[j][k].allToString());}
                output = output + hMSE[j][k].allToString() + nl;
                latexTableArray[fsimRepeat.size() + j * hMSE[j].length + k] = prelimToPRAOtable(hMSE[j][k]);
            }
        }
        // record all the params, the base error values, and the fsim/hsim mean squared errors
        String latexTable = HsimUtils.makeLatexTable(latexTableArray);
        PrintWriter writer = new PrintWriter("latexTable.txt", "UTF-8");
        writer.println(latexTable);
        writer.close();
        PrintWriter writer2 = new PrintWriter("HvsF-SimulationEvaluation.txt", "UTF-8");
        writer2.println(output);
        writer2.close();
        long timestop = System.nanoTime();
        System.out.println("Evaluation Concluded. Duration: " + (timestop - timestart) / 1000000000 + "s");
    } catch (Exception IOException) {
        IOException.printStackTrace();
    }
}
Also used : TabularDataReader(edu.pitt.dbmi.data.reader.tabular.TabularDataReader) DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) SemPm(edu.cmu.tetrad.sem.SemPm) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) PrintWriter(java.io.PrintWriter) Path(java.nio.file.Path) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ContinuousTabularDataFileReader(edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader) PatternToDag(edu.cmu.tetrad.search.PatternToDag) Dag(edu.cmu.tetrad.graph.Dag) Fges(edu.cmu.tetrad.search.Fges) Graph(edu.cmu.tetrad.graph.Graph) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) File(java.io.File) SemIm(edu.cmu.tetrad.sem.SemIm) SemBicScore(edu.cmu.tetrad.search.SemBicScore)

Example 14 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

the class TestLogisticRegression method test1.

@Test
public void test1() {
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < 5; i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }
    Graph graph = new Dag(GraphUtils.randomGraph(nodes, 0, 5, 3, 3, 3, false));
    System.out.println(graph);
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateDataRecursive(1000, false);
    Node x1 = data.getVariable("X1");
    Node x2 = data.getVariable("X2");
    Node x3 = data.getVariable("X3");
    Node x4 = data.getVariable("X4");
    Node x5 = data.getVariable("X5");
    Discretizer discretizer = new Discretizer(data);
    discretizer.equalCounts(x1, 2);
    DataSet d2 = discretizer.discretize();
    LogisticRegression regression = new LogisticRegression(d2);
    List<Node> regressors = new ArrayList<>();
    regressors.add(x2);
    regressors.add(x3);
    regressors.add(x4);
    regressors.add(x5);
    DiscreteVariable x1b = (DiscreteVariable) d2.getVariable("X1");
    regression.regress(x1b, regressors);
    System.out.println(regression);
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Dag(edu.cmu.tetrad.graph.Dag) Discretizer(edu.cmu.tetrad.data.Discretizer) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) Graph(edu.cmu.tetrad.graph.Graph) SemPm(edu.cmu.tetrad.sem.SemPm) LogisticRegression(edu.cmu.tetrad.regression.LogisticRegression) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 15 with Dag

use of edu.cmu.tetrad.graph.Dag in project tetrad by cmu-phil.

the class TestProposition method sampleBayesIm2.

private BayesIm sampleBayesIm2() {
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Node c = new GraphNode("c");
    Dag graph;
    graph = new Dag();
    graph.addNode(a);
    graph.addNode(b);
    graph.addNode(c);
    graph.addDirectedEdge(a, b);
    graph.addDirectedEdge(a, c);
    graph.addDirectedEdge(b, c);
    BayesPm bayesPm = new BayesPm(graph);
    bayesPm.setNumCategories(b, 3);
    BayesIm bayesIm1 = new MlBayesIm(bayesPm);
    bayesIm1.setProbability(0, 0, 0, .3);
    bayesIm1.setProbability(0, 0, 1, .7);
    bayesIm1.setProbability(1, 0, 0, .3);
    bayesIm1.setProbability(1, 0, 1, .4);
    bayesIm1.setProbability(1, 0, 2, .3);
    bayesIm1.setProbability(1, 1, 0, .6);
    bayesIm1.setProbability(1, 1, 1, .1);
    bayesIm1.setProbability(1, 1, 2, .3);
    bayesIm1.setProbability(2, 0, 0, .9);
    bayesIm1.setProbability(2, 0, 1, .1);
    bayesIm1.setProbability(2, 1, 0, .1);
    bayesIm1.setProbability(2, 1, 1, .9);
    bayesIm1.setProbability(2, 2, 0, .5);
    bayesIm1.setProbability(2, 2, 1, .5);
    bayesIm1.setProbability(2, 3, 0, .2);
    bayesIm1.setProbability(2, 3, 1, .8);
    bayesIm1.setProbability(2, 4, 0, .6);
    bayesIm1.setProbability(2, 4, 1, .4);
    bayesIm1.setProbability(2, 5, 0, .7);
    bayesIm1.setProbability(2, 5, 1, .3);
    return bayesIm1;
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesIm(edu.cmu.tetrad.bayes.BayesIm) MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) GraphNode(edu.cmu.tetrad.graph.GraphNode) Dag(edu.cmu.tetrad.graph.Dag) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Aggregations

Dag (edu.cmu.tetrad.graph.Dag)53 Node (edu.cmu.tetrad.graph.Node)41 Graph (edu.cmu.tetrad.graph.Graph)21 GraphNode (edu.cmu.tetrad.graph.GraphNode)21 Test (org.junit.Test)18 ArrayList (java.util.ArrayList)12 SemIm (edu.cmu.tetrad.sem.SemIm)10 SemPm (edu.cmu.tetrad.sem.SemPm)10 DataSet (edu.cmu.tetrad.data.DataSet)7 BayesPm (edu.cmu.tetrad.bayes.BayesPm)6 MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm)6 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)6 BayesIm (edu.cmu.tetrad.bayes.BayesIm)5 DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)3 FruchtermanReingoldLayout (edu.cmu.tetrad.graph.FruchtermanReingoldLayout)2 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)2 LinkedList (java.util.LinkedList)2 StoredCellProbs (edu.cmu.tetrad.bayes.StoredCellProbs)1 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)1 CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly)1