Example 1 with Fges

Use of edu.cmu.tetrad.search.Fges in project tetrad by cmu-phil, in the class FgesMeasurement, method search.

@Override
public Graph search(DataModel dataModel, Parameters parameters) {
    if (parameters.getInt("bootstrapSampleSize") < 1) {
        DataSet dataSet = DataUtils.getContinuousDataSet(dataModel);
        dataSet = dataSet.copy();
        dataSet = DataUtils.standardizeData(dataSet);
        double variance = parameters.getDouble("measurementVariance");
        if (variance > 0) {
            for (int i = 0; i < dataSet.getNumRows(); i++) {
                for (int j = 0; j < dataSet.getNumColumns(); j++) {
                    double d = dataSet.getDouble(i, j);
                    double norm = RandomUtil.getInstance().nextNormal(0, Math.sqrt(variance));
                    dataSet.setDouble(i, j, d + norm);
                }
            }
        }
        Fges search = new Fges(score.getScore(dataSet, parameters));
        search.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed"));
        search.setKnowledge(knowledge);
        search.setVerbose(parameters.getBoolean("verbose"));
        return search.search();
    } else {
        FgesMeasurement fgesMeasurement = new FgesMeasurement(score, algorithm);
        // fgesMeasurement.setKnowledge(knowledge);
        if (initialGraph != null) {
            fgesMeasurement.setInitialGraph(initialGraph);
        }
        DataSet data = (DataSet) dataModel;
        GeneralBootstrapTest search = new GeneralBootstrapTest(data, fgesMeasurement, parameters.getInt("bootstrapSampleSize"));
        search.setKnowledge(knowledge);
        BootstrapEdgeEnsemble edgeEnsemble = BootstrapEdgeEnsemble.Highest;
        switch(parameters.getInt("bootstrapEnsemble", 1)) {
            case 0:
                edgeEnsemble = BootstrapEdgeEnsemble.Preserved;
                break;
            case 1:
                edgeEnsemble = BootstrapEdgeEnsemble.Highest;
                break;
            case 2:
                edgeEnsemble = BootstrapEdgeEnsemble.Majority;
        }
        search.setEdgeEnsemble(edgeEnsemble);
        search.setParameters(parameters);
        search.setVerbose(parameters.getBoolean("verbose"));
        return search.search();
    }
}
Also used : GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) Fges(edu.cmu.tetrad.search.Fges)
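
The non-bootstrap branch above perturbs every cell of the standardized data with zero-mean Gaussian noise before running FGES. A minimal sketch of that step, assuming plain `double[][]` data instead of Tetrad's `DataSet`; the class name `MeasurementNoiseSketch` and the method `addMeasurementNoise` are hypothetical and are not part of Tetrad. The original passes `Math.sqrt(variance)` as the second argument of `nextNormal`, which suggests a standard deviation, and the sketch follows that reading.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical stand-in for the noise loop above; uses plain arrays, not Tetrad's DataSet.
class MeasurementNoiseSketch {

    // Returns a copy of data with independent N(0, variance) noise added to each cell.
    // A non-positive variance leaves the values unchanged, mirroring the guard above.
    static double[][] addMeasurementNoise(double[][] data, double variance, Random rng) {
        double sd = variance > 0 ? Math.sqrt(variance) : 0.0;
        double[][] noisy = new double[data.length][];
        for (int i = 0; i < data.length; i++) {
            noisy[i] = data[i].clone();
            for (int j = 0; sd > 0 && j < noisy[i].length; j++) {
                noisy[i][j] += sd * rng.nextGaussian();
            }
        }
        return noisy;
    }

    public static void main(String[] args) {
        double[][] data = {{0.0, 1.0}, {2.0, 3.0}};
        double[][] noisy = addMeasurementNoise(data, 0.25, new Random(42L));
        System.out.println(Arrays.deepToString(noisy));
    }
}
```

Returning a copy rather than mutating the input plays the same role as the `dataSet.copy()` call in the original: the caller's data stays untouched.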

Example 2 with Fges

Use of edu.cmu.tetrad.search.Fges in project tetrad by cmu-phil, in the class MixedFgesDiscretingContinuousVariables, method search.

public Graph search(DataModel dataSet, Parameters parameters) {
    if (parameters.getInt("bootstrapSampleSize") < 1) {
        Discretizer discretizer = new Discretizer(DataUtils.getContinuousDataSet(dataSet));
        List<Node> nodes = dataSet.getVariables();
        for (Node node : nodes) {
            if (node instanceof ContinuousVariable) {
                discretizer.equalIntervals(node, parameters.getInt("numCategories"));
            }
        }
        dataSet = discretizer.discretize();
        DataSet _dataSet = DataUtils.getDiscreteDataSet(dataSet);
        Fges fges = new Fges(score.getScore(_dataSet, parameters));
        Graph p = fges.search();
        return convertBack(_dataSet, p);
    } else {
        MixedFgesDiscretingContinuousVariables algorithm = new MixedFgesDiscretingContinuousVariables(score);
        DataSet data = (DataSet) dataSet;
        GeneralBootstrapTest search = new GeneralBootstrapTest(data, algorithm, parameters.getInt("bootstrapSampleSize"));
        BootstrapEdgeEnsemble edgeEnsemble = BootstrapEdgeEnsemble.Highest;
        switch(parameters.getInt("bootstrapEnsemble", 1)) {
            case 0:
                edgeEnsemble = BootstrapEdgeEnsemble.Preserved;
                break;
            case 1:
                edgeEnsemble = BootstrapEdgeEnsemble.Highest;
                break;
            case 2:
                edgeEnsemble = BootstrapEdgeEnsemble.Majority;
        }
        search.setEdgeEnsemble(edgeEnsemble);
        search.setParameters(parameters);
        search.setVerbose(parameters.getBoolean("verbose"));
        return search.search();
    }
}
Also used : EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) Node(edu.cmu.tetrad.graph.Node) Fges(edu.cmu.tetrad.search.Fges)
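
The method above discretizes each continuous variable into `numCategories` equal-width intervals before scoring. The idea can be sketched as follows, assuming a single column held in a `double[]`; `EqualIntervalSketch` and `equalIntervals` are hypothetical names, and Tetrad's `Discretizer` may choose boundaries or handle edge cases differently.

```java
import java.util.Arrays;

// Hypothetical illustration of equal-width binning; not Tetrad's Discretizer.
class EqualIntervalSketch {

    // Maps each value to a category index in [0, numCategories - 1] using
    // equal-width intervals between the column minimum and maximum.
    static int[] equalIntervals(double[] values, int numCategories) {
        if (numCategories < 1) {
            throw new IllegalArgumentException("numCategories must be at least 1");
        }
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        double width = (max - min) / numCategories;
        int[] categories = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            // A constant column has zero width; put everything in category 0.
            int index = width > 0 ? (int) ((values[i] - min) / width) : 0;
            // The maximum value would land one past the last bin, so clamp it.
            categories[i] = Math.min(index, numCategories - 1);
        }
        return categories;
    }

    public static void main(String[] args) {
        double[] column = {0.0, 5.0, 10.0};
        System.out.println(Arrays.toString(equalIntervals(column, 2)));
    }
}
```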

Example 3 with Fges

Use of edu.cmu.tetrad.search.Fges in project tetrad by cmu-phil, in the class HsimAutoRun, method run.

// ***********Public methods*************//
public double[] run(int resimSize) {
    // modify this so that verbose is a private data value, and so that data can be taken from either a dataset or a file.
    // ===========read data from file=============
    Set<String> eVars = new HashSet<String>();
    eVars.add("MULT");
    double[] output;
    output = new double[5];
    try {
        // ==== try with BigDataSetUtility ==============
        // DataSet regularDataSet = BigDataSetUtility.readInDiscreteData(new File(readfilename), delimiter, eVars);
        // ======done with BigDataSetUtility=============
        // if (verbose) System.out.println("Regular cols: " + regularDataSet.getNumColumns() + " rows: " + regularDataSet.getNumRows());
        // testing the read file
        // DataWriter.writeRectangularData(dataSet, new FileWriter("dataOut2.txt"), '\t');
        // apply Hsim to data, with whatever parameters
        // ========first make the Dag for Hsim==========
        BDeuScore score = new BDeuScore(data);
        // ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(dataSet);
        double penaltyDiscount = 2.0;
        Fges fges = new Fges(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setPenaltyDiscount(penaltyDiscount);
        Graph estGraph = fges.search();
        // if (verbose) System.out.println(estGraph);
        Graph estPattern = new EdgeListGraphSingleConnections(estGraph);
        PatternToDag patternToDag = new PatternToDag(estPattern);
        Graph estGraphDAG = patternToDag.patternToDagMeek();
        Dag estDAG = new Dag(estGraphDAG);
        // ===========Identify the nodes to be resimulated===========
        // select a random node as the centroid
        List<Node> allNodes = estGraph.getNodes();
        int size = allNodes.size();
        int randIndex = new Random().nextInt(size);
        Node centroid = allNodes.get(randIndex);
        if (verbose) {
            System.out.println("the centroid is " + centroid);
        }
        List<Node> queue = new ArrayList<>();
        queue.add(centroid);
        List<Node> queueAdd = new ArrayList<Node>();
        // if (verbose) System.out.println(queue);
        while (queue.size() < resimSize) {
            // if (verbose) System.out.println(queue.size() + " vs " + resimSize);
            // find nodes adjacent to nodes in current queue, add them to a queue without duplicating nodes
            int qsize = queue.size();
            for (int i = 0; i < qsize; i++) {
                // find set of adjacent nodes
                queueAdd = estGraph.getAdjacentNodes(queue.get(i));
                // remove nodes that are already in queue
                queueAdd.removeAll(queue);
                // //**** If queueAdd is empty at this stage, randomly select a node to add
                while (queueAdd.size() < 1) {
                    queueAdd.add(allNodes.get(new Random().nextInt(size)));
                }
                // add remaining nodes to queue
                queue.addAll(queueAdd);
                // break early when queue outgrows resimsize
                if (queue.size() >= resimSize) {
                    break;
                }
            }
        }
        // if queue is too big, remove nodes from the end until it is small enough.
        while (queue.size() > resimSize) {
            queue.remove(queue.size() - 1);
            // if (verbose) System.out.println(queue);
        }
        Set<Node> simnodes = new HashSet<Node>(queue);
        if (verbose) {
            System.out.println("the resimmed nodes are " + simnodes);
        }
        // ===========Apply the hybrid resimulation===============
        // regularDataSet
        Hsim hsim = new Hsim(estDAG, simnodes, data);
        DataSet newDataSet = hsim.hybridsimulate();
        // write output to a new file
        if (write) {
            // try-with-resources closes the writer even if writing fails
            try (FileWriter fileWriter = new FileWriter(filenameOut)) {
                DataWriter.writeRectangularData(newDataSet, fileWriter, delimiter);
            }
        }
        // =======Run FGES on the output data, and compare it to the original learned graph
        // Path dataFileOut = Paths.get(filenameOut);
        // edu.cmu.tetrad.io.DataReader dataReaderOut = new VerticalTabularDiscreteDataReader(dataFileOut, delimiter);
        // DataSet dataSetOut = dataReaderOut.readInData(eVars);
        BDeuScore newscore = new BDeuScore(newDataSet);
        Fges fgesOut = new Fges(newscore);
        fgesOut.setVerbose(false);
        fgesOut.setNumPatternsToStore(0);
        fgesOut.setPenaltyDiscount(2.0);
        // fgesOut.setOut(out);
        // fgesOut.setFaithfulnessAssumed(true);
        // fgesOut.setMaxIndegree(1);
        // fgesOut.setCycleBound(5);
        Graph estGraphOut = fgesOut.search();
        // if (verbose) System.out.println(" bugchecking: fges estGraphOut: " + estGraphOut);
        // doing the replaceNodes trick to fix some bugs
        estGraphOut = GraphUtils.replaceNodes(estGraphOut, estDAG.getNodes());
        // restrict the comparison to the simnodes and edges to their parents
        Set<Node> allParents = HsimUtils.getAllParents(estGraphOut, simnodes);
        Set<Node> addParents = HsimUtils.getAllParents(estDAG, simnodes);
        allParents.addAll(addParents);
        Graph estEvalGraphOut = HsimUtils.evalEdges(estGraphOut, simnodes, allParents);
        Graph estEvalGraph = HsimUtils.evalEdges(estDAG, simnodes, allParents);
        // SearchGraphUtils.graphComparison(estGraph, estGraphOut, System.out);
        estEvalGraphOut = GraphUtils.replaceNodes(estEvalGraphOut, estEvalGraph.getNodes());
        // if (verbose) System.out.println(estEvalGraph);
        // if (verbose) System.out.println(estEvalGraphOut);
        // SearchGraphUtils.graphComparison(estEvalGraphOut, estEvalGraph, System.out);
        output = HsimUtils.errorEval(estEvalGraphOut, estEvalGraph);
        if (verbose) {
            System.out.println(output[0] + " " + output[1] + " " + output[2] + " " + output[3] + " " + output[4]);
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
    return output;
}
Also used : PatternToDag(edu.cmu.tetrad.search.PatternToDag) FileWriter(java.io.FileWriter) Fges(edu.cmu.tetrad.search.Fges) BDeuScore(edu.cmu.tetrad.search.BDeuScore)
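
The node-selection step in the method above grows a set outward from a random centroid until it holds `resimSize` nodes, falling back to a random node when no new neighbors are available. A minimal sketch of that idea, assuming integer node ids and an adjacency map in which every node appears as a key; `NeighborhoodSketch` and `grow` are hypothetical names. Unlike the original, which collects nodes in a list and trims the excess afterwards, this version uses a set so no node is picked twice and the growth stops exactly at the target size.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

// Hypothetical illustration of growing a node neighborhood; not Tetrad code.
class NeighborhoodSketch {

    // Breadth-first growth from start until targetSize nodes are selected.
    // Assumes every node, including start, is a key of the adjacency map.
    static Set<Integer> grow(Map<Integer, List<Integer>> adjacency, int start,
                             int targetSize, Random rng) {
        if (targetSize > adjacency.size()) {
            throw new IllegalArgumentException("targetSize exceeds number of nodes");
        }
        Set<Integer> selected = new LinkedHashSet<>();
        Deque<Integer> frontier = new ArrayDeque<>();
        selected.add(start);
        frontier.add(start);
        while (selected.size() < targetSize) {
            if (frontier.isEmpty()) {
                // Component exhausted: jump to a random node not yet selected.
                List<Integer> remaining = new ArrayList<>(adjacency.keySet());
                remaining.removeAll(selected);
                if (remaining.isEmpty()) {
                    break;
                }
                int next = remaining.get(rng.nextInt(remaining.size()));
                selected.add(next);
                frontier.add(next);
                continue;
            }
            int current = frontier.poll();
            List<Integer> neighbors =
                    adjacency.getOrDefault(current, Collections.<Integer>emptyList());
            for (int neighbor : neighbors) {
                if (selected.size() >= targetSize) {
                    break;
                }
                if (selected.add(neighbor)) {
                    frontier.add(neighbor);
                }
            }
        }
        return selected;
    }

    public static void main(String[] args) {
        Map<Integer, List<Integer>> adjacency = new HashMap<>();
        adjacency.put(0, Arrays.asList(1));
        adjacency.put(1, Arrays.asList(0, 2));
        adjacency.put(2, Arrays.asList(1));
        System.out.println(grow(adjacency, 0, 2, new Random(7L)));
    }
}
```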

Example 4 with Fges

Use of edu.cmu.tetrad.search.Fges in project tetrad by cmu-phil, in the class HsimEvalFromData, method main.

public static void main(String[] args) {
    long timestart = System.nanoTime();
    System.out.println("Beginning Evaluation");
    String nl = System.lineSeparator();
    String output = "Simulation edu.cmu.tetrad.study output comparing Fsim and Hsim on predicting graph discovery accuracy" + nl;
    int iterations = 100;
    int vars = 20;
    int cases = 500;
    int edgeratio = 3;
    List<Integer> hsimRepeat = Arrays.asList(40);
    List<Integer> fsimRepeat = Arrays.asList(40);
    List<PRAOerrors>[] fsimErrsByPars = new ArrayList[fsimRepeat.size()];
    int whichFrepeat = 0;
    for (int frepeat : fsimRepeat) {
        fsimErrsByPars[whichFrepeat] = new ArrayList<PRAOerrors>();
        whichFrepeat++;
    }
    List<PRAOerrors>[][] hsimErrsByPars = new ArrayList[1][hsimRepeat.size()];
    // System.out.println(resimSize.size()+" "+hsimRepeat.size());
    int whichHrepeat;
    whichHrepeat = 0;
    for (int hrepeat : hsimRepeat) {
        // System.out.println(whichrsize+" "+whichHrepeat);
        hsimErrsByPars[0][whichHrepeat] = new ArrayList<PRAOerrors>();
        whichHrepeat++;
    }
    // ===== START ITERATING HERE =====
    try {
        for (int iterate = 0; iterate < iterations; iterate++) {
            System.out.println("iteration " + iterate);
            // ===== LOADING THE DATA AND GRAPH =====
            DataSet data1;
            Graph graph1 = GraphUtils.loadGraphTxt(new File("graph/graph.1.txt"));
            Dag odag = new Dag(graph1);
            Set<String> eVars = new HashSet<String>();
            eVars.add("MULT");
            Path dataFile = Paths.get("data/data.1.txt");
            TabularDataReader dataReader = new ContinuousTabularDataFileReader(dataFile.toFile(), Delimiter.TAB);
            data1 = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData(eVars));
            vars = data1.getNumColumns();
            cases = data1.getNumRows();
            edgeratio = 3;
            // ===== CALCULATING TARGET ERRORS =====
            ICovarianceMatrix newcov = new CovarianceMatrixOnTheFly(data1);
            SemBicScore oscore = new SemBicScore(newcov);
            Fges ofgs = new Fges(oscore);
            ofgs.setVerbose(false);
            ofgs.setNumPatternsToStore(0);
            // ***********This is the original FGS output on the data
            Graph oFGSGraph = ofgs.search();
            PRAOerrors oErrors = new PRAOerrors(HsimUtils.errorEval(oFGSGraph, odag), "target errors");
            // **then step 1: full resim. iterate through the combinations of estimator parameters (just repeat num)
            for (whichFrepeat = 0; whichFrepeat < fsimRepeat.size(); whichFrepeat++) {
                ArrayList<PRAOerrors> errorsList = new ArrayList<PRAOerrors>();
                for (int r = 0; r < fsimRepeat.get(whichFrepeat); r++) {
                    PatternToDag pickdag = new PatternToDag(oFGSGraph);
                    Graph fgsDag = pickdag.patternToDagMeek();
                    Dag fgsdag2 = new Dag(fgsDag);
                    // then fit an IM to this dag and the data. GeneralizedSemEstimator seems to bug out
                    // GeneralizedSemPm simSemPm = new GeneralizedSemPm(fgsdag2);
                    // GeneralizedSemEstimator gsemEstimator = new GeneralizedSemEstimator();
                    // GeneralizedSemIm fittedIM = gsemEstimator.estimate(simSemPm, oData);
                    SemPm simSemPm = new SemPm(fgsdag2);
                    // BayesPm simBayesPm = new BayesPm(fgsdag2, bayesPm);
                    SemEstimator simSemEstimator = new SemEstimator(data1, simSemPm);
                    SemIm fittedIM = simSemEstimator.estimate();
                    DataSet simData = fittedIM.simulateData(data1.getNumRows(), false);
                    // after making the full resim data (simData), run FGS on that
                    ICovarianceMatrix simcov = new CovarianceMatrixOnTheFly(simData);
                    SemBicScore simscore = new SemBicScore(simcov);
                    Fges simfgs = new Fges(simscore);
                    simfgs.setVerbose(false);
                    simfgs.setNumPatternsToStore(0);
                    Graph simGraphOut = simfgs.search();
                    PRAOerrors simErrors = new PRAOerrors(HsimUtils.errorEval(simGraphOut, fgsdag2), "Fsim errors " + r);
                    errorsList.add(simErrors);
                }
                PRAOerrors avErrors = new PRAOerrors(errorsList, "Average errors for Fsim at repeat=" + fsimRepeat.get(whichFrepeat));
                // if (verbosity>3) System.out.println(avErrors.allToString());
                // ****calculate the squared errors of prediction, store all these errors in a list
                double FsimAR2 = (avErrors.getAdjRecall() - oErrors.getAdjRecall()) * (avErrors.getAdjRecall() - oErrors.getAdjRecall());
                double FsimAP2 = (avErrors.getAdjPrecision() - oErrors.getAdjPrecision()) * (avErrors.getAdjPrecision() - oErrors.getAdjPrecision());
                double FsimOR2 = (avErrors.getOrientRecall() - oErrors.getOrientRecall()) * (avErrors.getOrientRecall() - oErrors.getOrientRecall());
                double FsimOP2 = (avErrors.getOrientPrecision() - oErrors.getOrientPrecision()) * (avErrors.getOrientPrecision() - oErrors.getOrientPrecision());
                PRAOerrors Fsim2 = new PRAOerrors(new double[] { FsimAR2, FsimAP2, FsimOR2, FsimOP2 }, "squared errors for Fsim at repeat=" + fsimRepeat.get(whichFrepeat));
                // add the fsim squared errors to the appropriate list
                fsimErrsByPars[whichFrepeat].add(Fsim2);
            }
            // **then step 2: hybrid sim. iterate through combos of params (repeat num, resimsize)
            for (whichHrepeat = 0; whichHrepeat < hsimRepeat.size(); whichHrepeat++) {
                HsimRepeatAC study = new HsimRepeatAC(data1);
                PRAOerrors HsimErrors = new PRAOerrors(study.run(1, hsimRepeat.get(whichHrepeat)), "Hsim errors" + "at rsize=" + 1 + " repeat=" + hsimRepeat.get(whichHrepeat));
                // ****calculate the squared errors of prediction
                double HsimAR2 = (HsimErrors.getAdjRecall() - oErrors.getAdjRecall()) * (HsimErrors.getAdjRecall() - oErrors.getAdjRecall());
                double HsimAP2 = (HsimErrors.getAdjPrecision() - oErrors.getAdjPrecision()) * (HsimErrors.getAdjPrecision() - oErrors.getAdjPrecision());
                double HsimOR2 = (HsimErrors.getOrientRecall() - oErrors.getOrientRecall()) * (HsimErrors.getOrientRecall() - oErrors.getOrientRecall());
                double HsimOP2 = (HsimErrors.getOrientPrecision() - oErrors.getOrientPrecision()) * (HsimErrors.getOrientPrecision() - oErrors.getOrientPrecision());
                PRAOerrors Hsim2 = new PRAOerrors(new double[] { HsimAR2, HsimAP2, HsimOR2, HsimOP2 }, "squared errors for Hsim, rsize=" + 1 + " repeat=" + hsimRepeat.get(whichHrepeat));
                hsimErrsByPars[0][whichHrepeat].add(Hsim2);
            }
        }
        // Average the squared errors for each set of fsim/hsim params across all iterations
        PRAOerrors[] fMSE = new PRAOerrors[fsimRepeat.size()];
        PRAOerrors[][] hMSE = new PRAOerrors[1][hsimRepeat.size()];
        String[][] latexTableArray = new String[1 * hsimRepeat.size() + fsimRepeat.size()][5];
        for (int j = 0; j < fMSE.length; j++) {
            fMSE[j] = new PRAOerrors(fsimErrsByPars[j], "MSE for Fsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " frepeat=" + fsimRepeat.get(j) + " iterations=" + iterations);
            // if(verbosity>0){System.out.println(fMSE[j].allToString());}
            output = output + fMSE[j].allToString() + nl;
            latexTableArray[j] = prelimToPRAOtable(fMSE[j]);
        }
        for (int j = 0; j < hMSE.length; j++) {
            for (int k = 0; k < hMSE[j].length; k++) {
                hMSE[j][k] = new PRAOerrors(hsimErrsByPars[j][k], "MSE for Hsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " rsize=" + 1 + " repeat=" + hsimRepeat.get(k) + " iterations=" + iterations);
                // if(verbosity>0){System.out.println(hMSE[j][k].allToString());}
                output = output + hMSE[j][k].allToString() + nl;
                latexTableArray[fsimRepeat.size() + j * hMSE[j].length + k] = prelimToPRAOtable(hMSE[j][k]);
            }
        }
        // record all the params, the base error values, and the fsim/hsim mean squared errors
        String latexTable = HsimUtils.makeLatexTable(latexTableArray);
        // try-with-resources closes each writer even if printing fails
        try (PrintWriter writer = new PrintWriter("latexTable.txt", "UTF-8")) {
            writer.println(latexTable);
        }
        try (PrintWriter writer2 = new PrintWriter("HvsF-SimulationEvaluation.txt", "UTF-8")) {
            writer2.println(output);
        }
        long timestop = System.nanoTime();
        System.out.println("Evaluation Concluded. Duration: " + (timestop - timestart) / 1000000000 + "s");
    } catch (Exception e) {
        e.printStackTrace();
    }
}
Also used : TabularDataReader(edu.pitt.dbmi.data.reader.tabular.TabularDataReader) DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) SemPm(edu.cmu.tetrad.sem.SemPm) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) PrintWriter(java.io.PrintWriter) Path(java.nio.file.Path) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ContinuousTabularDataFileReader(edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader) Dag(edu.cmu.tetrad.graph.Dag) Fges(edu.cmu.tetrad.search.Fges) Graph(edu.cmu.tetrad.graph.Graph) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) File(java.io.File) SemIm(edu.cmu.tetrad.sem.SemIm) SemBicScore(edu.cmu.tetrad.search.SemBicScore)
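
The evaluation above averages the error metrics over repeated resimulations and then records the squared difference between that average and the target errors, metric by metric. A minimal sketch of the arithmetic, assuming each run is a plain `double[]` of metrics in a fixed order; `SquaredErrorSketch`, `mean` and `squaredErrors` are hypothetical names and do not reproduce the `PRAOerrors` class.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration of the averaging and squared-error steps; not PRAOerrors.
class SquaredErrorSketch {

    // Element-wise mean of several metric vectors of equal length.
    static double[] mean(List<double[]> runs) {
        if (runs.isEmpty()) {
            throw new IllegalArgumentException("need at least one run");
        }
        double[] average = new double[runs.get(0).length];
        for (double[] run : runs) {
            for (int k = 0; k < average.length; k++) {
                average[k] += run[k];
            }
        }
        for (int k = 0; k < average.length; k++) {
            average[k] /= runs.size();
        }
        return average;
    }

    // Element-wise squared difference between predicted and target metrics.
    static double[] squaredErrors(double[] predicted, double[] target) {
        double[] squared = new double[predicted.length];
        for (int k = 0; k < predicted.length; k++) {
            double diff = predicted[k] - target[k];
            squared[k] = diff * diff;
        }
        return squared;
    }

    public static void main(String[] args) {
        List<double[]> runs = Arrays.asList(new double[] {0.25, 1.0}, new double[] {0.75, 0.0});
        double[] target = {1.0, 0.5};
        System.out.println(Arrays.toString(squaredErrors(mean(runs), target)));
    }
}
```

Averaging these squared errors across the outer iterations, as the original does when it builds `fMSE` and `hMSE`, gives a mean squared error per metric.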

Example 5 with Fges

Use of edu.cmu.tetrad.search.Fges in project tetrad by cmu-phil, in the class HsimRobustCompare, method run.

// *************Public Methods*****************//
public static List<double[]> run(int numVars, double edgesPerNode, int numCases, double penaltyDiscount, int resimSize, int repeat, boolean verbose) {
    // public static void main(String[] args) {
    // first generate the data
    RandomUtil.getInstance().setSeed(1450184147770L);
    // '\t';
    char delimiter = ',';
    final int numEdges = (int) (numVars * edgesPerNode);
    List<Node> vars = new ArrayList<>();
    double[] oErrors = new double[5];
    double[] hsimErrors = new double[5];
    double[] simErrors = new double[5];
    List<double[]> output = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }
    Graph odag = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);
    BayesPm bayesPm = new BayesPm(odag, 2, 2);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    // oData is the original data set, and odag is the original dag.
    DataSet oData = bayesIm.simulateData(numCases, false);
    // System.out.println(oData);
    // System.out.println(odag);
    // then run FGES
    BDeuScore oscore = new BDeuScore(oData);
    Fges fges = new Fges(oscore);
    fges.setVerbose(false);
    fges.setNumPatternsToStore(0);
    fges.setPenaltyDiscount(penaltyDiscount);
    Graph oGraphOut = fges.search();
    if (verbose)
        System.out.println(oGraphOut);
    // calculate FGES errors
    oErrors = new double[5];
    oErrors = HsimUtils.errorEval(oGraphOut, odag);
    if (verbose)
        System.out.println(oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + " " + oErrors[3] + " " + oErrors[4]);
    // create various simulated data sets
    // //let's do the full simulated data set first: a dag in the FGES pattern fit to the data set.
    PatternToDag pickdag = new PatternToDag(oGraphOut);
    Graph fgesDag = pickdag.patternToDagMeek();
    Dag fgesdag2 = new Dag(fgesDag);
    BayesPm simBayesPm = new BayesPm(fgesdag2, bayesPm);
    DirichletBayesIm simIM = DirichletBayesIm.symmetricDirichletIm(simBayesPm, 1.0);
    DirichletEstimator simEstimator = new DirichletEstimator();
    DirichletBayesIm fittedIM = simEstimator.estimate(simIM, oData);
    DataSet simData = fittedIM.simulateData(numCases, false);
    // //next let's do a schedule of small hsims
    HsimRepeatAutoRun study = new HsimRepeatAutoRun(oData);
    hsimErrors = study.run(resimSize, repeat);
    // calculate errors for all simulated output graphs
    // //full simulation errors first
    BDeuScore simscore = new BDeuScore(simData);
    Fges simfges = new Fges(simscore);
    simfges.setVerbose(false);
    simfges.setNumPatternsToStore(0);
    simfges.setPenaltyDiscount(penaltyDiscount);
    Graph simGraphOut = simfges.search();
    // simErrors = new double[5];
    simErrors = HsimUtils.errorEval(simGraphOut, fgesdag2);
    // first, let's just see what the errors are.
    if (verbose)
        System.out.println("Original errors are: " + oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + " " + oErrors[3] + " " + oErrors[4]);
    if (verbose)
        System.out.println("Full resim errors are: " + simErrors[0] + " " + simErrors[1] + " " + simErrors[2] + " " + simErrors[3] + " " + simErrors[4]);
    if (verbose)
        System.out.println("HSim errors are: " + hsimErrors[0] + " " + hsimErrors[1] + " " + hsimErrors[2] + " " + hsimErrors[3] + " " + hsimErrors[4]);
    // then, let's try to squeeze these numbers down into something more tractable.
    // double[] ErrorDifferenceDifferences;
    // ErrorDifferenceDifferences = new double[5];
    // ErrorDifferenceDifferences[0] = Math.abs(oErrors[0]-simErrors[0])-Math.abs(oErrors[0]-hsimErrors[0]);
    // ErrorDifferenceDifferences[1] = Math.abs(oErrors[1]-simErrors[1])-Math.abs(oErrors[1]-hsimErrors[1]);
    // ErrorDifferenceDifferences[2] = Math.abs(oErrors[2]-simErrors[2])-Math.abs(oErrors[2]-hsimErrors[2]);
    // ErrorDifferenceDifferences[3] = Math.abs(oErrors[3]-simErrors[3])-Math.abs(oErrors[3]-hsimErrors[3]);
    // ErrorDifferenceDifferences[4] = Math.abs(oErrors[4]-simErrors[4])-Math.abs(oErrors[4]-hsimErrors[4]);
    // System.out.println("resim error errors - hsim error errors: " + ErrorDifferenceDifferences[0] + " " + ErrorDifferenceDifferences[1] + " " + ErrorDifferenceDifferences[2] + " " + ErrorDifferenceDifferences[3] + " " + ErrorDifferenceDifferences[4]);
    output.add(oErrors);
    output.add(simErrors);
    output.add(hsimErrors);
    return output;
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ArrayList(java.util.ArrayList) Fges(edu.cmu.tetrad.search.Fges) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) BDeuScore(edu.cmu.tetrad.search.BDeuScore)

Aggregations

Fges (edu.cmu.tetrad.search.Fges) 25
RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph) 15
SemBicDTest (edu.cmu.tetrad.algcomparison.independence.SemBicDTest) 11
SemBicTest (edu.cmu.tetrad.algcomparison.independence.SemBicTest) 11
Test (org.junit.Test) 11
SemBicScore (edu.cmu.tetrad.search.SemBicScore) 9
Graph (edu.cmu.tetrad.graph.Graph) 4
PatternToDag (edu.cmu.tetrad.search.PatternToDag) 3
BootstrapEdgeEnsemble (edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) 3
GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) 3
DataSet (edu.cmu.tetrad.data.DataSet) 2
EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph) 2
BDeuScore (edu.cmu.tetrad.search.BDeuScore) 2
Pc (edu.cmu.tetrad.search.Pc) 2
ArrayList (java.util.ArrayList) 2
Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm) 1
RandomForward (edu.cmu.tetrad.algcomparison.graph.RandomForward) 1
FisherZ (edu.cmu.tetrad.algcomparison.independence.FisherZ) 1
IndependenceWrapper (edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper) 1
ScoreWrapper (edu.cmu.tetrad.algcomparison.score.ScoreWrapper) 1