Search in sources :

Example 1 with PatternToDag

use of edu.cmu.tetrad.search.PatternToDag in project tetrad by cmu-phil.

the class IonSearchEditor method addSpecialMenus.

/**
 * Adds the menus specific to this search editor to the given menu bar.
 *
 * <p>An "Independence" menu of independence-test choices is added unless the algorithm runner
 * is an {@code IGesRunner}. A "Graph" menu follows with graph-property, paths and triples
 * actions, DAG selection, score-based reorientation, navigation through the graph history,
 * a display of the result graph, and edge-selection actions. A {@code LayoutMenu} is added last.
 *
 * @param menuBar the menu bar that receives the menus
 */
protected void addSpecialMenus(JMenuBar menuBar) {
    // The independence-test menu is skipped for IGesRunner-based searches.
    if (!(getAlgorithmRunner() instanceof IGesRunner)) {
        JMenu test = new JMenu("Independence");
        menuBar.add(test);
        IndTestMenuItems.addIndependenceTestChoices(test, this);
    }

    JMenu graph = new JMenu("Graph");
    JMenuItem showDags = new JMenuItem("Show DAGs in forbid_latent_common_causes");
    JMenuItem dagInPattern = new JMenuItem("Choose DAG in forbid_latent_common_causes");
    JMenuItem gesOrient = new JMenuItem("Global Score-based Reorientation");
    JMenuItem nextGraph = new JMenuItem("Next Graph");
    JMenuItem previousGraph = new JMenuItem("Previous Graph");

    // The order of the add(...) calls below determines the on-screen order of the menu entries.
    graph.add(new GraphPropertiesAction(getWorkbench()));
    graph.add(new PathsAction(getWorkbench()));
    graph.add(new TriplesAction(getWorkbench().getGraph(), getAlgorithmRunner()));
    graph.addSeparator();
    graph.add(dagInPattern);
    graph.add(gesOrient);
    graph.addSeparator();
    graph.add(previousGraph);
    graph.add(nextGraph);
    graph.addSeparator();
    graph.add(showDags);
    graph.addSeparator();
    graph.add(new JMenuItem(new SelectBidirectedAction(getWorkbench())));
    graph.add(new JMenuItem(new SelectUndirectedAction(getWorkbench())));
    menuBar.add(graph);

    showDags.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            Window owner = (Window) getTopLevelAncestor();
            // NOTE(review): the WatchedProcess instance is not retained; presumably its
            // constructor arranges for watch() to run -- confirm.
            new WatchedProcess(owner) {

                public void watch() {
                    // Needs to be a pattern search; this isn't checked
                    // before running the algorithm because of allowable
                    // "slop"--e.g. bidirected edges.
                    AlgorithmRunner runner = getAlgorithmRunner();
                    Graph resultGraph = runner.getGraph();
                    if (resultGraph == null) {
                        JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), "No result graph.");
                        return;
                    }
                    PatternDisplay display = new PatternDisplay(resultGraph);
                    GraphWorkbench workbench = getWorkbench();
                    // NOTE(review): the window is titled "Independence Facts" although it hosts
                    // a PatternDisplay -- confirm the title is intended.
                    EditorWindow editorWindow = new EditorWindow(display, "Independence Facts", "Close", false, workbench);
                    DesktopController.getInstance().addEditorWindow(editorWindow, JLayeredPane.PALETTE_LAYER);
                    editorWindow.setVisible(true);
                }
            };
        }
    });

    dagInPattern.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            // Edges are removed from a new EdgeListGraph built from the current graph.
            Graph pattern = new EdgeListGraph(getGraph());
            // Remove bidirected edges from the pattern before selecting a DAG.
            // NOTE(review): edges are removed while iterating getEdges(); this relies on
            // getEdges() returning a snapshot rather than a live view -- confirm.
            for (Edge edge : pattern.getEdges()) {
                if (Edges.isBidirectedEdge(edge)) {
                    pattern.removeEdge(edge);
                }
            }
            PatternToDag search = new PatternToDag(new EdgeListGraphSingleConnections(pattern));
            Graph dag = search.patternToDagMeek();
            // Record the DAG in the history, display it, and make it the runner's result graph.
            getGraphHistory().add(dag);
            getWorkbench().setGraph(dag);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(dag);
            firePropertyChange("modelChanged", null, null);
        }
    });

    gesOrient.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            DataModel dataModel = getAlgorithmRunner().getDataModel();
            final Graph reoriented = SearchGraphUtils.reorient(getGraph(), dataModel, getKnowledge());
            // Unlike the other handlers, this one does not update the runner's result graph.
            getGraphHistory().add(reoriented);
            getWorkbench().setGraph(reoriented);
            firePropertyChange("modelChanged", null, null);
        }
    });

    nextGraph.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            Graph next = getGraphHistory().next();
            getWorkbench().setGraph(next);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(next);
            firePropertyChange("modelChanged", null, null);
        }
    });

    previousGraph.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            Graph previous = getGraphHistory().previous();
            getWorkbench().setGraph(previous);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(previous);
            firePropertyChange("modelChanged", null, null);
        }
    });

    menuBar.add(new LayoutMenu(this));
}
Also used : LayoutMenu(edu.cmu.tetradapp.workbench.LayoutMenu) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ActionEvent(java.awt.event.ActionEvent) WatchedProcess(edu.cmu.tetradapp.util.WatchedProcess) ActionListener(java.awt.event.ActionListener) GraphWorkbench(edu.cmu.tetradapp.workbench.GraphWorkbench) DataModel(edu.cmu.tetrad.data.DataModel)

Example 2 with PatternToDag

use of edu.cmu.tetrad.search.PatternToDag in project tetrad by cmu-phil.

the class FciCcdSearchEditor method addSpecialMenus.

/**
 * Adds the menus specific to this search editor to the given menu bar.
 *
 * <p>An "Independence" menu of independence-test choices is added unless the algorithm runner
 * is an {@code IGesRunner}. A "Graph" menu follows with graph-property, paths and triples
 * actions, DAG selection, score-based reorientation, navigation through the graph history,
 * a display of the result graph, and edge-selection actions. A {@code LayoutMenu} is added last.
 *
 * @param menuBar the menu bar that receives the menus
 */
protected void addSpecialMenus(JMenuBar menuBar) {
    // The independence-test menu is skipped for IGesRunner-based searches.
    if (!(getAlgorithmRunner() instanceof IGesRunner)) {
        JMenu test = new JMenu("Independence");
        menuBar.add(test);
        IndTestMenuItems.addIndependenceTestChoices(test, this);
    }

    JMenu graph = new JMenu("Graph");
    JMenuItem showDags = new JMenuItem("Show DAGs in forbid_latent_common_causes");
    JMenuItem dagInPattern = new JMenuItem("Choose DAG in forbid_latent_common_causes");
    JMenuItem gesOrient = new JMenuItem("Global Score-based Reorientation");
    JMenuItem nextGraph = new JMenuItem("Next Graph");
    JMenuItem previousGraph = new JMenuItem("Previous Graph");

    // The order of the add(...) calls below determines the on-screen order of the menu entries.
    graph.add(new GraphPropertiesAction(getWorkbench()));
    graph.add(new PathsAction(getWorkbench()));
    graph.add(new TriplesAction(getWorkbench().getGraph(), getAlgorithmRunner()));
    graph.addSeparator();
    graph.add(dagInPattern);
    graph.add(gesOrient);
    graph.addSeparator();
    graph.add(previousGraph);
    graph.add(nextGraph);
    graph.addSeparator();
    graph.add(showDags);
    graph.addSeparator();
    graph.add(new JMenuItem(new SelectBidirectedAction(getWorkbench())));
    graph.add(new JMenuItem(new SelectUndirectedAction(getWorkbench())));
    menuBar.add(graph);

    showDags.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            Window owner = (Window) getTopLevelAncestor();
            // NOTE(review): the WatchedProcess instance is not retained; presumably its
            // constructor arranges for watch() to run -- confirm.
            new WatchedProcess(owner) {

                public void watch() {
                    // Needs to be a pattern search; this isn't checked
                    // before running the algorithm because of allowable
                    // "slop"--e.g. bidirected edges.
                    AlgorithmRunner runner = getAlgorithmRunner();
                    Graph resultGraph = runner.getGraph();
                    if (resultGraph == null) {
                        JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), "No result graph.");
                        return;
                    }
                    PatternDisplay display = new PatternDisplay(resultGraph);
                    GraphWorkbench workbench = getWorkbench();
                    // NOTE(review): the window is titled "Independence Facts" although it hosts
                    // a PatternDisplay -- confirm the title is intended.
                    EditorWindow editorWindow = new EditorWindow(display, "Independence Facts", "Close", false, workbench);
                    DesktopController.getInstance().addEditorWindow(editorWindow, JLayeredPane.PALETTE_LAYER);
                    editorWindow.setVisible(true);
                }
            };
        }
    });

    dagInPattern.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            // Edges are removed from a new EdgeListGraph built from the current graph.
            Graph pattern = new EdgeListGraph(getGraph());
            // Remove bidirected edges from the pattern before selecting a DAG.
            // NOTE(review): edges are removed while iterating getEdges(); this relies on
            // getEdges() returning a snapshot rather than a live view -- confirm.
            for (Edge edge : pattern.getEdges()) {
                if (Edges.isBidirectedEdge(edge)) {
                    pattern.removeEdge(edge);
                }
            }
            PatternToDag search = new PatternToDag(new EdgeListGraphSingleConnections(pattern));
            Graph dag = search.patternToDagMeek();
            // Record the DAG in the history, display it, and make it the runner's result graph.
            getGraphHistory().add(dag);
            getWorkbench().setGraph(dag);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(dag);
            firePropertyChange("modelChanged", null, null);
        }
    });

    gesOrient.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            DataModel dataModel = getAlgorithmRunner().getDataModel();
            final Graph reoriented = SearchGraphUtils.reorient(getGraph(), dataModel, getKnowledge());
            // Unlike the other handlers, this one does not update the runner's result graph.
            getGraphHistory().add(reoriented);
            getWorkbench().setGraph(reoriented);
            firePropertyChange("modelChanged", null, null);
        }
    });

    nextGraph.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            Graph next = getGraphHistory().next();
            getWorkbench().setGraph(next);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(next);
            firePropertyChange("modelChanged", null, null);
        }
    });

    previousGraph.addActionListener(new ActionListener() {

        @Override
        public void actionPerformed(ActionEvent e) {
            Graph previous = getGraphHistory().previous();
            getWorkbench().setGraph(previous);
            ((AbstractAlgorithmRunner) getAlgorithmRunner()).setResultGraph(previous);
            firePropertyChange("modelChanged", null, null);
        }
    });

    menuBar.add(new LayoutMenu(this));
}
Also used : LayoutMenu(edu.cmu.tetradapp.workbench.LayoutMenu) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ActionEvent(java.awt.event.ActionEvent) WatchedProcess(edu.cmu.tetradapp.util.WatchedProcess) ActionListener(java.awt.event.ActionListener) GraphWorkbench(edu.cmu.tetradapp.workbench.GraphWorkbench) DataModel(edu.cmu.tetrad.data.DataModel)

Example 3 with PatternToDag

use of edu.cmu.tetrad.search.PatternToDag in project tetrad by cmu-phil.

the class HsimAutoRun method run.

// ***********Public methods*************//
/**
 * Runs one hybrid-resimulation experiment on this object's {@code data}.
 *
 * <p>Steps: run FGES with a BDeu score on the data, pick a DAG from the result via
 * {@code PatternToDag}, grow a set of nodes outward from a randomly chosen centroid,
 * hybrid-resimulate those nodes with {@code Hsim}, optionally write the resimulated data to
 * {@code filenameOut}, rerun FGES on the resimulated data, and compare the two graphs
 * restricted to the resimulated nodes and their parents.
 *
 * @param resimSize the number of queue entries to collect for resimulation; because the queue
 *                  may contain duplicates, the resulting node set can be smaller
 * @return the array returned by {@code HsimUtils.errorEval}; if an exception is thrown before
 *         that call completes, a zero-filled array of length 5 (the exception's stack trace is
 *         printed and not rethrown)
 */
public double[] run(int resimSize) {
    // modify this so that verbose is a private data value, and so that data can be taken from either a dataset or a file.
    double[] output = new double[5];
    try {
        // ========first make the Dag for Hsim==========
        BDeuScore score = new BDeuScore(data);
        double penaltyDiscount = 2.0;
        Fges fges = new Fges(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setPenaltyDiscount(penaltyDiscount);
        Graph estGraph = fges.search();
        Graph estPattern = new EdgeListGraphSingleConnections(estGraph);
        PatternToDag patternToDag = new PatternToDag(estPattern);
        Graph estGraphDAG = patternToDag.patternToDagMeek();
        Dag estDAG = new Dag(estGraphDAG);

        // ===========Identify the nodes to be resimulated===========
        // select a random node as the centroid
        List<Node> allNodes = estGraph.getNodes();
        int size = allNodes.size();
        // One unseeded generator serves every random draw in this run.
        Random random = new Random();
        Node centroid = allNodes.get(random.nextInt(size));
        if (verbose) {
            System.out.println("the centroid is " + centroid);
        }
        List<Node> queue = new ArrayList<>();
        queue.add(centroid);
        while (queue.size() < resimSize) {
            // find nodes adjacent to nodes in current queue, add them to a queue without duplicating nodes
            int qsize = queue.size();
            for (int i = 0; i < qsize; i++) {
                // find set of adjacent nodes
                List<Node> queueAdd = estGraph.getAdjacentNodes(queue.get(i));
                // remove nodes that are already in queue
                queueAdd.removeAll(queue);
                // If nothing new is adjacent, add one randomly chosen node instead. That node
                // may already be in the queue, so the queue can contain duplicates.
                while (queueAdd.size() < 1) {
                    queueAdd.add(allNodes.get(random.nextInt(size)));
                }
                // add remaining nodes to queue
                queue.addAll(queueAdd);
                // break early when queue outgrows resimsize
                if (queue.size() >= resimSize) {
                    break;
                }
            }
        }
        // if queue is too big, remove nodes from the end until it is small enough.
        while (queue.size() > resimSize) {
            queue.remove(queue.size() - 1);
        }
        Set<Node> simnodes = new HashSet<Node>(queue);
        if (verbose) {
            System.out.println("the resimmed nodes are " + simnodes);
        }

        // ===========Apply the hybrid resimulation===============
        Hsim hsim = new Hsim(estDAG, simnodes, data);
        DataSet newDataSet = hsim.hybridsimulate();
        // write output to a new file
        if (write) {
            // try-with-resources closes the writer even when writing fails.
            try (FileWriter fileWriter = new FileWriter(filenameOut)) {
                DataWriter.writeRectangularData(newDataSet, fileWriter, delimiter);
            }
        }

        // =======Run FGES on the output data, and compare it to the original learned graph
        BDeuScore newscore = new BDeuScore(newDataSet);
        Fges fgesOut = new Fges(newscore);
        fgesOut.setVerbose(false);
        fgesOut.setNumPatternsToStore(0);
        // Same penalty discount as the first FGES run.
        fgesOut.setPenaltyDiscount(penaltyDiscount);
        Graph estGraphOut = fgesOut.search();
        // doing the replaceNodes trick to fix some bugs
        estGraphOut = GraphUtils.replaceNodes(estGraphOut, estDAG.getNodes());
        // restrict the comparison to the simnodes and edges to their parents
        Set<Node> allParents = HsimUtils.getAllParents(estGraphOut, simnodes);
        Set<Node> addParents = HsimUtils.getAllParents(estDAG, simnodes);
        allParents.addAll(addParents);
        Graph estEvalGraphOut = HsimUtils.evalEdges(estGraphOut, simnodes, allParents);
        Graph estEvalGraph = HsimUtils.evalEdges(estDAG, simnodes, allParents);
        estEvalGraphOut = GraphUtils.replaceNodes(estEvalGraphOut, estEvalGraph.getNodes());
        output = HsimUtils.errorEval(estEvalGraphOut, estEvalGraph);
        if (verbose) {
            System.out.println(output[0] + " " + output[1] + " " + output[2] + " " + output[3] + " " + output[4]);
        }
    } catch (Exception e) {
        // Best effort: report the failure and return whatever output holds at this point.
        e.printStackTrace();
    }
    return output;
}
Also used : PatternToDag(edu.cmu.tetrad.search.PatternToDag) FileWriter(java.io.FileWriter) Fges(edu.cmu.tetrad.search.Fges) BDeuScore(edu.cmu.tetrad.search.BDeuScore)

Example 4 with PatternToDag

use of edu.cmu.tetrad.search.PatternToDag in project tetrad by cmu-phil.

the class HsimEvalFromData method main.

/**
 * Compares full resimulation (Fsim) with hybrid resimulation (Hsim) as predictors of FGES
 * graph-discovery accuracy.
 *
 * <p>Each iteration loads the graph from {@code graph/graph.1.txt} and the data from
 * {@code data/data.1.txt}, computes the FGES errors against the loaded graph ("target errors"),
 * and records the squared difference between those target errors and the errors estimated by
 * Fsim and by Hsim. After all iterations the squared errors are aggregated per parameter
 * setting and written to {@code latexTable.txt} and {@code HvsF-SimulationEvaluation.txt}.
 * Any exception is reported with a stack trace and ends the evaluation.
 *
 * @param args command-line arguments (not used)
 */
public static void main(String[] args) {
    long timestart = System.nanoTime();
    System.out.println("Beginning Evaluation");
    String nl = System.lineSeparator();
    StringBuilder output = new StringBuilder("Simulation edu.cmu.tetrad.study output comparing Fsim and Hsim on predicting graph discovery accuracy" + nl);
    int iterations = 100;
    // vars and cases are overwritten from the loaded data set inside the loop.
    int vars = 20;
    int cases = 500;
    int edgeratio = 3;
    List<Integer> hsimRepeat = Arrays.asList(40);
    List<Integer> fsimRepeat = Arrays.asList(40);

    // One list of squared-error records per Fsim repeat setting.
    List<PRAOerrors>[] fsimErrsByPars = new ArrayList[fsimRepeat.size()];
    for (int i = 0; i < fsimErrsByPars.length; i++) {
        fsimErrsByPars[i] = new ArrayList<PRAOerrors>();
    }
    // One list per (resim size, Hsim repeat) setting; only a single resim size is used here.
    List<PRAOerrors>[][] hsimErrsByPars = new ArrayList[1][hsimRepeat.size()];
    for (int i = 0; i < hsimErrsByPars[0].length; i++) {
        hsimErrsByPars[0][i] = new ArrayList<PRAOerrors>();
    }

    try {
        for (int iterate = 0; iterate < iterations; iterate++) {
            System.out.println("iteration " + iterate);

            // ---- load the data and graph ----
            Graph graph1 = GraphUtils.loadGraphTxt(new File("graph/graph.1.txt"));
            Dag odag = new Dag(graph1);
            Set<String> eVars = new HashSet<String>();
            eVars.add("MULT");
            Path dataFile = Paths.get("data/data.1.txt");
            TabularDataReader dataReader = new ContinuousTabularDataFileReader(dataFile.toFile(), Delimiter.TAB);
            DataSet data1 = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData(eVars));
            vars = data1.getNumColumns();
            cases = data1.getNumRows();
            edgeratio = 3;

            // ---- calculate the target errors ----
            ICovarianceMatrix newcov = new CovarianceMatrixOnTheFly(data1);
            SemBicScore oscore = new SemBicScore(newcov);
            Fges ofgs = new Fges(oscore);
            ofgs.setVerbose(false);
            ofgs.setNumPatternsToStore(0);
            // This is the original FGS output on the data.
            Graph oFGSGraph = ofgs.search();
            PRAOerrors oErrors = new PRAOerrors(HsimUtils.errorEval(oFGSGraph, odag), "target errors");

            // ---- step 1: full resim; iterate through the estimator parameters (just repeat num) ----
            for (int f = 0; f < fsimRepeat.size(); f++) {
                int repeats = fsimRepeat.get(f);
                List<PRAOerrors> errorsList = new ArrayList<PRAOerrors>();
                for (int r = 0; r < repeats; r++) {
                    PatternToDag pickdag = new PatternToDag(oFGSGraph);
                    Graph fgsDag = pickdag.patternToDagMeek();
                    Dag fgsdag2 = new Dag(fgsDag);
                    // then fit an IM to this dag and the data. GeneralizedSemEstimator seems to bug out
                    SemPm simSemPm = new SemPm(fgsdag2);
                    SemEstimator simSemEstimator = new SemEstimator(data1, simSemPm);
                    SemIm fittedIM = simSemEstimator.estimate();
                    DataSet simData = fittedIM.simulateData(data1.getNumRows(), false);
                    // after making the full resim data (simData), run FGS on that
                    ICovarianceMatrix simcov = new CovarianceMatrixOnTheFly(simData);
                    SemBicScore simscore = new SemBicScore(simcov);
                    Fges simfgs = new Fges(simscore);
                    simfgs.setVerbose(false);
                    simfgs.setNumPatternsToStore(0);
                    Graph simGraphOut = simfgs.search();
                    PRAOerrors simErrors = new PRAOerrors(HsimUtils.errorEval(simGraphOut, fgsdag2), "Fsim errors " + r);
                    errorsList.add(simErrors);
                }
                PRAOerrors avErrors = new PRAOerrors(errorsList, "Average errors for Fsim at repeat=" + repeats);
                // calculate the squared errors of prediction, store all these errors in a list
                double fsimAR2 = squaredDifference(avErrors.getAdjRecall(), oErrors.getAdjRecall());
                double fsimAP2 = squaredDifference(avErrors.getAdjPrecision(), oErrors.getAdjPrecision());
                double fsimOR2 = squaredDifference(avErrors.getOrientRecall(), oErrors.getOrientRecall());
                double fsimOP2 = squaredDifference(avErrors.getOrientPrecision(), oErrors.getOrientPrecision());
                PRAOerrors fsim2 = new PRAOerrors(new double[] { fsimAR2, fsimAP2, fsimOR2, fsimOP2 }, "squared errors for Fsim at repeat=" + repeats);
                // add the fsim squared errors to the appropriate list
                fsimErrsByPars[f].add(fsim2);
            }

            // ---- step 2: hybrid sim; iterate through combos of params (repeat num, resimsize) ----
            for (int h = 0; h < hsimRepeat.size(); h++) {
                int repeats = hsimRepeat.get(h);
                HsimRepeatAC study = new HsimRepeatAC(data1);
                PRAOerrors hsimErrors = new PRAOerrors(study.run(1, repeats), "Hsim errors" + "at rsize=" + 1 + " repeat=" + repeats);
                // calculate the squared errors of prediction
                double hsimAR2 = squaredDifference(hsimErrors.getAdjRecall(), oErrors.getAdjRecall());
                double hsimAP2 = squaredDifference(hsimErrors.getAdjPrecision(), oErrors.getAdjPrecision());
                double hsimOR2 = squaredDifference(hsimErrors.getOrientRecall(), oErrors.getOrientRecall());
                double hsimOP2 = squaredDifference(hsimErrors.getOrientPrecision(), oErrors.getOrientPrecision());
                PRAOerrors hsim2 = new PRAOerrors(new double[] { hsimAR2, hsimAP2, hsimOR2, hsimOP2 }, "squared errors for Hsim, rsize=" + 1 + " repeat=" + repeats);
                hsimErrsByPars[0][h].add(hsim2);
            }
        }

        // Average the squared errors for each set of fsim/hsim params across all iterations
        PRAOerrors[] fMSE = new PRAOerrors[fsimRepeat.size()];
        PRAOerrors[][] hMSE = new PRAOerrors[1][hsimRepeat.size()];
        // Fsim rows come first in the table, followed by the Hsim rows.
        String[][] latexTableArray = new String[1 * hsimRepeat.size() + fsimRepeat.size()][5];
        for (int j = 0; j < fMSE.length; j++) {
            fMSE[j] = new PRAOerrors(fsimErrsByPars[j], "MSE for Fsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " frepeat=" + fsimRepeat.get(j) + " iterations=" + iterations);
            output.append(fMSE[j].allToString()).append(nl);
            latexTableArray[j] = prelimToPRAOtable(fMSE[j]);
        }
        for (int j = 0; j < hMSE.length; j++) {
            for (int k = 0; k < hMSE[j].length; k++) {
                hMSE[j][k] = new PRAOerrors(hsimErrsByPars[j][k], "MSE for Hsim at vars=" + vars + " edgeratio=" + edgeratio + " cases=" + cases + " rsize=" + 1 + " repeat=" + hsimRepeat.get(k) + " iterations=" + iterations);
                output.append(hMSE[j][k].allToString()).append(nl);
                latexTableArray[fsimRepeat.size() + j * hMSE[j].length + k] = prelimToPRAOtable(hMSE[j][k]);
            }
        }

        // record all the params, the base error values, and the fsim/hsim mean squared errors
        String latexTable = HsimUtils.makeLatexTable(latexTableArray);
        // try-with-resources closes each writer even when writing fails.
        try (PrintWriter writer = new PrintWriter("latexTable.txt", "UTF-8")) {
            writer.println(latexTable);
        }
        try (PrintWriter writer2 = new PrintWriter("HvsF-SimulationEvaluation.txt", "UTF-8")) {
            writer2.println(output.toString());
        }
        long timestop = System.nanoTime();
        System.out.println("Evaluation Concluded. Duration: " + (timestop - timestart) / 1000000000 + "s");
    } catch (Exception e) {
        // Any failure (I/O or otherwise) ends the evaluation; only the stack trace is reported.
        e.printStackTrace();
    }
}

/** Returns the square of the difference between {@code estimate} and {@code target}. */
private static double squaredDifference(double estimate, double target) {
    double difference = estimate - target;
    return difference * difference;
}
Also used : TabularDataReader(edu.pitt.dbmi.data.reader.tabular.TabularDataReader) DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) SemPm(edu.cmu.tetrad.sem.SemPm) SemEstimator(edu.cmu.tetrad.sem.SemEstimator) PrintWriter(java.io.PrintWriter) Path(java.nio.file.Path) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ContinuousTabularDataFileReader(edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader) Dag(edu.cmu.tetrad.graph.Dag) Fges(edu.cmu.tetrad.search.Fges) Graph(edu.cmu.tetrad.graph.Graph) CovarianceMatrixOnTheFly(edu.cmu.tetrad.data.CovarianceMatrixOnTheFly) File(java.io.File) SemIm(edu.cmu.tetrad.sem.SemIm) SemBicScore(edu.cmu.tetrad.search.SemBicScore)

Example 5 with PatternToDag

use of edu.cmu.tetrad.search.PatternToDag in project tetrad by cmu-phil.

the class HsimRobustCompare method run.

// *************Public Methods*****************//
/**
 * Runs a single comparison of FGES error statistics on (a) an original simulated data set,
 * (b) a data set fully re-simulated from a model fitted to the FGES output, and
 * (c) the result of an {@code HsimRepeatAutoRun} study on the original data.
 *
 * <p>The shared {@code RandomUtil} instance is re-seeded with a fixed constant on every call.
 *
 * @param numVars         number of {@code ContinuousVariable} nodes ("X0", "X1", ...) to create
 * @param edgesPerNode    multiplied by {@code numVars} and truncated to an int to get the edge count
 * @param numCases        number of cases simulated for both the original and the re-simulated data set
 * @param penaltyDiscount penalty discount passed to both FGES searches
 * @param resimSize       forwarded as the first argument of {@code HsimRepeatAutoRun.run}
 * @param repeat          forwarded as the second argument of {@code HsimRepeatAutoRun.run}
 * @param verbose         when true, prints the first FGES output graph and the error arrays to stdout
 * @return a list of three arrays, in this order: errors of FGES on the original data measured
 *         against the generating graph, errors of FGES on the fully re-simulated data measured
 *         against the DAG picked from the first FGES output, and the errors returned by the
 *         hsim study. The meaning of the individual entries is defined by
 *         {@code HsimUtils.errorEval} / {@code HsimRepeatAutoRun.run}, which are not shown here;
 *         the verbose output reads indices 0 through 4 of each array.
 */
public static List<double[]> run(int numVars, double edgesPerNode, int numCases, double penaltyDiscount, int resimSize, int repeat, boolean verbose) {
    // Re-seed the shared random source with a fixed constant before anything is generated.
    RandomUtil.getInstance().setSeed(1450184147770L);
    final int numEdges = (int) (numVars * edgesPerNode);

    // first generate the data
    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }
    Graph odag = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);
    BayesPm bayesPm = new BayesPm(odag, 2, 2);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    // oData is the original data set, and odag is the original dag.
    DataSet oData = bayesIm.simulateData(numCases, false);

    // then run FGES on the original data
    BDeuScore oscore = new BDeuScore(oData);
    Fges fges = new Fges(oscore);
    fges.setVerbose(false);
    fges.setNumPatternsToStore(0);
    fges.setPenaltyDiscount(penaltyDiscount);
    Graph oGraphOut = fges.search();
    if (verbose) {
        System.out.println(oGraphOut);
    }

    // calculate FGES errors against the generating graph
    double[] oErrors = HsimUtils.errorEval(oGraphOut, odag);
    if (verbose) {
        System.out.println(oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + " " + oErrors[3] + " " + oErrors[4]);
    }

    // create various simulated data sets
    // full simulated data set first: a dag in the FGES pattern fit to the data set.
    PatternToDag pickdag = new PatternToDag(oGraphOut);
    Graph fgesDag = pickdag.patternToDagMeek();
    Dag fgesdag2 = new Dag(fgesDag);
    BayesPm simBayesPm = new BayesPm(fgesdag2, bayesPm);
    DirichletBayesIm simIM = DirichletBayesIm.symmetricDirichletIm(simBayesPm, 1.0);
    DirichletEstimator simEstimator = new DirichletEstimator();
    DirichletBayesIm fittedIM = simEstimator.estimate(simIM, oData);
    // Simulated before the hsim study runs; the order is kept because both steps follow the fixed seed.
    DataSet simData = fittedIM.simulateData(numCases, false);

    // next, a schedule of small hsims
    HsimRepeatAutoRun study = new HsimRepeatAutoRun(oData);
    double[] hsimErrors = study.run(resimSize, repeat);

    // calculate errors for the full simulation: FGES on simData, scored against the picked DAG
    BDeuScore simscore = new BDeuScore(simData);
    Fges simfges = new Fges(simscore);
    simfges.setVerbose(false);
    simfges.setNumPatternsToStore(0);
    simfges.setPenaltyDiscount(penaltyDiscount);
    Graph simGraphOut = simfges.search();
    double[] simErrors = HsimUtils.errorEval(simGraphOut, fgesdag2);

    // first, let's just see what the errors are.
    if (verbose) {
        System.out.println("Original erors are: " + oErrors[0] + " " + oErrors[1] + " " + oErrors[2] + " " + oErrors[3] + " " + oErrors[4]);
        System.out.println("Full resim errors are: " + simErrors[0] + " " + simErrors[1] + " " + simErrors[2] + " " + simErrors[3] + " " + simErrors[4]);
        System.out.println("HSim errors are: " + hsimErrors[0] + " " + hsimErrors[1] + " " + hsimErrors[2] + " " + hsimErrors[3] + " " + hsimErrors[4]);
    }

    // Result order is part of the contract: original, full resimulation, hsim.
    List<double[]> output = new ArrayList<>();
    output.add(oErrors);
    output.add(simErrors);
    output.add(hsimErrors);
    return output;
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) PatternToDag(edu.cmu.tetrad.search.PatternToDag) ArrayList(java.util.ArrayList) PatternToDag(edu.cmu.tetrad.search.PatternToDag) Fges(edu.cmu.tetrad.search.Fges) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) BDeuScore(edu.cmu.tetrad.search.BDeuScore)

Aggregations

PatternToDag (edu.cmu.tetrad.search.PatternToDag): 8; Fges (edu.cmu.tetrad.search.Fges): 3; WatchedProcess (edu.cmu.tetradapp.util.WatchedProcess): 3; GraphWorkbench (edu.cmu.tetradapp.workbench.GraphWorkbench): 3; LayoutMenu (edu.cmu.tetradapp.workbench.LayoutMenu): 3; ActionEvent (java.awt.event.ActionEvent): 3; ActionListener (java.awt.event.ActionListener): 3; DataModel (edu.cmu.tetrad.data.DataModel): 2; DataSet (edu.cmu.tetrad.data.DataSet): 2; BDeuScore (edu.cmu.tetrad.search.BDeuScore): 2; ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 1; CovarianceMatrixOnTheFly (edu.cmu.tetrad.data.CovarianceMatrixOnTheFly): 1; ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix): 1; Dag (edu.cmu.tetrad.graph.Dag): 1; Graph (edu.cmu.tetrad.graph.Graph): 1; SemBicScore (edu.cmu.tetrad.search.SemBicScore): 1; SemEstimator (edu.cmu.tetrad.sem.SemEstimator): 1; SemIm (edu.cmu.tetrad.sem.SemIm): 1; SemPm (edu.cmu.tetrad.sem.SemPm): 1; ContinuousTabularDataFileReader (edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader): 1