Search in sources :

Example 1 with BDeuScore

use of edu.cmu.tetrad.search.BDeuScore in project tetrad by cmu-phil.

The run method of the class HsimAutoRun.

// ***********Public methods*************//
/**
 * Runs one hybrid-resimulation experiment on the instance's {@code data}.
 *
 * <p>Steps, in order: (1) search {@code data} with FGES using a BDeu score and turn the result
 * into a DAG; (2) pick a random centroid node and grow a set of {@code resimSize} nodes around it
 * by following adjacencies; (3) resimulate those nodes with {@link Hsim}; (4) optionally write the
 * new data set to {@code filenameOut}; (5) rerun FGES on the new data and compare the two graphs,
 * restricted to the resimulated nodes and their parents, via {@code HsimUtils.errorEval}.
 *
 * <p>Any exception raised along the way is printed and swallowed, as before, so callers always get
 * an array back.
 *
 * @param resimSize number of nodes to resimulate; if it exceeds the number of nodes in the learned
 *                  graph, every node is resimulated
 * @return the array returned by {@code HsimUtils.errorEval}, or a five-element array of zeros if
 *         the experiment failed before the evaluation step
 */
public double[] run(int resimSize) {
    // modify this so that verbose is a private data value, and so that data can be taken from either a dataset or a file.
    double[] output = new double[5];
    // Both FGES runs must use the same penalty so that their graphs are comparable.
    final double penaltyDiscount = 2.0;
    try {
        // ========first make the Dag for Hsim==========
        BDeuScore score = new BDeuScore(data);
        Fges fges = new Fges(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setPenaltyDiscount(penaltyDiscount);
        Graph estGraph = fges.search();
        Graph estPattern = new EdgeListGraphSingleConnections(estGraph);
        PatternToDag patternToDag = new PatternToDag(estPattern);
        Graph estGraphDAG = patternToDag.patternToDagMeek();
        Dag estDAG = new Dag(estGraphDAG);

        // ===========Identify the nodes to be resimulated===========
        // select a random node as the centroid
        List<Node> allNodes = estGraph.getNodes();
        int size = allNodes.size();
        Random random = new Random();
        Node centroid = allNodes.get(random.nextInt(size));
        if (verbose) {
            System.out.println("the centroid is " + centroid);
        }
        List<Node> queue = new ArrayList<>();
        queue.add(centroid);
        // The queue holds distinct nodes only, so it can never grow past the size of the graph.
        int target = Math.min(resimSize, size);
        while (queue.size() < target) {
            // Only expand the nodes that were queued when this sweep started.
            int qsize = queue.size();
            for (int i = 0; i < qsize; i++) {
                // Work on a copy so that the removal below cannot touch a list owned by the graph.
                List<Node> queueAdd = new ArrayList<>(estGraph.getAdjacentNodes(queue.get(i)));
                // remove nodes that are already in queue
                queueAdd.removeAll(queue);
                if (queueAdd.isEmpty()) {
                    // Dead end: jump to a random node that is not queued yet. Drawing from the
                    // unvisited nodes (rather than from all nodes) keeps duplicates out of the
                    // queue; the list is non-empty here because queue.size() < target <= size.
                    List<Node> unvisited = new ArrayList<>(allNodes);
                    unvisited.removeAll(queue);
                    queueAdd.add(unvisited.get(random.nextInt(unvisited.size())));
                }
                queue.addAll(queueAdd);
                // break early once enough nodes have been collected
                if (queue.size() >= target) {
                    break;
                }
            }
        }
        // if queue is too big, remove nodes from the end until it is small enough.
        while (queue.size() > resimSize) {
            queue.remove(queue.size() - 1);
        }
        Set<Node> simnodes = new HashSet<>(queue);
        if (verbose) {
            System.out.println("the resimmed nodes are " + simnodes);
        }

        // ===========Apply the hybrid resimulation===============
        Hsim hsim = new Hsim(estDAG, simnodes, data);
        DataSet newDataSet = hsim.hybridsimulate();
        // write output to a new file
        if (write) {
            // try-with-resources closes the file even when writing fails part-way through.
            try (FileWriter fileWriter = new FileWriter(filenameOut)) {
                DataWriter.writeRectangularData(newDataSet, fileWriter, delimiter);
            }
        }

        // =======Run FGES on the output data, and compare it to the original learned graph
        BDeuScore newscore = new BDeuScore(newDataSet);
        Fges fgesOut = new Fges(newscore);
        fgesOut.setVerbose(false);
        fgesOut.setNumPatternsToStore(0);
        fgesOut.setPenaltyDiscount(penaltyDiscount);
        Graph estGraphOut = fgesOut.search();
        // doing the replaceNodes trick to fix some bugs
        estGraphOut = GraphUtils.replaceNodes(estGraphOut, estDAG.getNodes());
        // restrict the comparison to the simnodes and edges to their parents
        Set<Node> allParents = HsimUtils.getAllParents(estGraphOut, simnodes);
        Set<Node> addParents = HsimUtils.getAllParents(estDAG, simnodes);
        allParents.addAll(addParents);
        Graph estEvalGraphOut = HsimUtils.evalEdges(estGraphOut, simnodes, allParents);
        Graph estEvalGraph = HsimUtils.evalEdges(estDAG, simnodes, allParents);
        estEvalGraphOut = GraphUtils.replaceNodes(estEvalGraphOut, estEvalGraph.getNodes());
        output = HsimUtils.errorEval(estEvalGraphOut, estEvalGraph);
        if (verbose) {
            System.out.println(output[0] + " " + output[1] + " " + output[2] + " " + output[3] + " " + output[4]);
        }
    } catch (Exception e) {
        // Best-effort by design: report the failure and fall through to return what we have.
        e.printStackTrace();
    }
    return output;
}
Also used : PatternToDag(edu.cmu.tetrad.search.PatternToDag) FileWriter(java.io.FileWriter) PatternToDag(edu.cmu.tetrad.search.PatternToDag) Fges(edu.cmu.tetrad.search.Fges) BDeuScore(edu.cmu.tetrad.search.BDeuScore)

Example 2 with BDeuScore

use of edu.cmu.tetrad.search.BDeuScore in project tetrad by cmu-phil.

The run method of the class HsimRobustCompare.

// *************Public Methods*****************//
/**
 * Compares FGES error statistics on original data, on fully resimulated data and on hybrid
 * resimulated data.
 *
 * <p>A random DAG over {@code numVars} variables is generated, data are simulated from a random
 * Bayes IM on it, and FGES (BDeu score) is run on that data. A DAG taken from the FGES result is
 * then fitted to the data with a Dirichlet estimator and used to simulate a second data set, on
 * which FGES is run again. In between, {@link HsimRepeatAutoRun} is run on the original data.
 *
 * @param numVars         number of variables in the generated graph
 * @param edgesPerNode    multiplied by {@code numVars} (and truncated) to get the number of edges
 * @param numCases        number of rows to simulate for each data set
 * @param penaltyDiscount penalty discount passed to every FGES run
 * @param resimSize       forwarded to {@code HsimRepeatAutoRun.run}
 * @param repeat          forwarded to {@code HsimRepeatAutoRun.run}
 * @param verbose         when {@code true}, prints the learned graph and the error arrays
 * @return a list of three arrays, in this order: errors of FGES on the original data against the
 *         generating DAG, errors of FGES on the fully resimulated data against the DAG it was
 *         simulated from, and the errors reported by the hybrid resimulation study
 */
public static List<double[]> run(int numVars, double edgesPerNode, int numCases, double penaltyDiscount, int resimSize, int repeat, boolean verbose) {
    // Seed the shared RandomUtil with a fixed value before generating anything.
    RandomUtil.getInstance().setSeed(1450184147770L);
    final int numEdges = (int) (numVars * edgesPerNode);

    // first generate the data
    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }
    Graph odag = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);
    BayesPm bayesPm = new BayesPm(odag, 2, 2);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    // oData is the original data set, and odag is the original dag.
    DataSet oData = bayesIm.simulateData(numCases, false);

    // then run FGES and calculate its errors against the generating dag
    Graph oGraphOut = searchWithFges(oData, penaltyDiscount);
    if (verbose) {
        System.out.println(oGraphOut);
    }
    double[] oErrors = HsimUtils.errorEval(oGraphOut, odag);
    if (verbose) {
        System.out.println(describeErrors(oErrors));
    }

    // create various simulated data sets
    // let's do the full simulated data set first: a dag in the FGES pattern fit to the data set.
    PatternToDag pickdag = new PatternToDag(oGraphOut);
    Graph fgesDag = pickdag.patternToDagMeek();
    Dag fgesdag2 = new Dag(fgesDag);
    BayesPm simBayesPm = new BayesPm(fgesdag2, bayesPm);
    DirichletBayesIm simIM = DirichletBayesIm.symmetricDirichletIm(simBayesPm, 1.0);
    DirichletEstimator simEstimator = new DirichletEstimator();
    DirichletBayesIm fittedIM = simEstimator.estimate(simIM, oData);
    DataSet simData = fittedIM.simulateData(numCases, false);

    // next let's do a schedule of small hsims
    HsimRepeatAutoRun study = new HsimRepeatAutoRun(oData);
    double[] hsimErrors = study.run(resimSize, repeat);

    // calculate errors for the full simulation; this search deliberately stays after the hsim
    // study so that the order of the original computation is unchanged.
    Graph simGraphOut = searchWithFges(simData, penaltyDiscount);
    double[] simErrors = HsimUtils.errorEval(simGraphOut, fgesdag2);

    // first, let's just see what the errors are.
    if (verbose) {
        System.out.println("Original erors are: " + describeErrors(oErrors));
        System.out.println("Full resim errors are: " + describeErrors(simErrors));
        System.out.println("HSim errors are: " + describeErrors(hsimErrors));
    }

    List<double[]> output = new ArrayList<>();
    output.add(oErrors);
    output.add(simErrors);
    output.add(hsimErrors);
    return output;
}

/** Runs a non-verbose FGES search with a BDeu score on {@code data} and returns the resulting graph. */
private static Graph searchWithFges(DataSet data, double penaltyDiscount) {
    BDeuScore score = new BDeuScore(data);
    Fges fges = new Fges(score);
    fges.setVerbose(false);
    fges.setNumPatternsToStore(0);
    fges.setPenaltyDiscount(penaltyDiscount);
    return fges.search();
}

/** Joins the first five entries of {@code errors} with single spaces, for the verbose printout. */
private static String describeErrors(double[] errors) {
    return errors[0] + " " + errors[1] + " " + errors[2] + " " + errors[3] + " " + errors[4];
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ArrayList(java.util.ArrayList) PatternToDag(edu.cmu.tetrad.search.PatternToDag) Fges(edu.cmu.tetrad.search.Fges) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) BDeuScore(edu.cmu.tetrad.search.BDeuScore)

Aggregations

BDeuScore (edu.cmu.tetrad.search.BDeuScore)2 Fges (edu.cmu.tetrad.search.Fges)2 PatternToDag (edu.cmu.tetrad.search.PatternToDag)2 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)1 DataSet (edu.cmu.tetrad.data.DataSet)1 FileWriter (java.io.FileWriter)1 ArrayList (java.util.ArrayList)1