Search in sources :

Example 21 with BayesPm

Use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.

The class PerformanceTests, method testFgesMb.

/**
 * Benchmarks FGES-MB on one simulated DAG, searching around a random target variable per run.
 *
 * <p>A single DAG is made with {@code makeDag} and its pattern is computed. One data set is
 * simulated from it: via {@code LargeScaleSimulation.simulateDataFisher} when {@code continuous}
 * (scored with SEM BIC), otherwise from a randomly parameterized {@code MlBayesIm} over a
 * {@code BayesPm(dag, 3, 3)} (scored with BDeu). Each run picks a random target, calls
 * {@code FgesMb.search(target)}, and compares the result with the subgraph of the true pattern
 * induced by the target, its adjacents, and the parents of its directed children. Per-run and
 * averaged results are written to the report stream opened by {@code init(...)}, which is closed
 * at the end.
 *
 * @param numVars    number of variables passed to {@code makeDag}
 * @param edgeFactor passed to {@code makeDag}; the report lists {@code numVars * edgeFactor} edges
 * @param numCases   number of simulated rows
 * @param numRuns    number of target variables searched (one search per run)
 * @param continuous {@code true} for continuous simulation with SEM BIC, {@code false} for
 *                   discrete simulation with BDeu
 */
private void testFgesMb(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) {
    double penaltyDiscount = 4.0;
    int structurePrior = 10;
    int samplePrior = 10;
    int maxIndegree = -1;
    // boolean faithfulness = false;
    List<int[][]> allCounts = new ArrayList<>();
    List<double[]> comparisons = new ArrayList<>();
    List<Double> degrees = new ArrayList<>();
    List<Long> elapsedTimes = new ArrayList<>();

    // Open the report stream before anything is written to `out`; until init(...) runs, `out`
    // is whatever an earlier test left behind (possibly closed, or never set).
    String kind = continuous ? "continuous" : "discrete";
    init(new File("FgesMb.comparison." + kind + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
    out.println("Num vars = " + numVars);
    out.println("Num edges = " + (int) (numVars * edgeFactor));
    out.println("Num cases = " + numCases);
    if (continuous) {
        out.println("Penalty discount = " + penaltyDiscount);
    } else {
        out.println("Sample prior = " + samplePrior);
        out.println("Structure prior = " + structurePrior);
    }
    out.println("Depth = " + maxIndegree);
    out.println();

    System.out.println("Making dag");
    Graph dag = makeDag(numVars, edgeFactor);
    System.out.println(new Date());
    System.out.println("Calculating pattern for DAG");
    Graph pattern = SearchGraphUtils.patternForDag(dag);
    // One tier per node, in node order.
    int[] tiers = new int[dag.getNumNodes()];
    for (int i = 0; i < tiers.length; i++) {
        tiers[i] = i;
    }
    System.out.println("Graph done");
    out.println("Graph done");
    out.println(new Date());
    System.out.println(new Date());
    System.out.println("Starting simulation");
    long time1 = System.currentTimeMillis();

    FgesMb fges;
    List<Node> vars;
    if (continuous) {
        vars = dag.getNodes();
        LargeScaleSimulation simulator = new LargeScaleSimulation(dag, vars, tiers);
        simulator.setVerbose(false);
        simulator.setOut(out);
        DataSet data = simulator.simulateDataFisher(numCases);
        System.out.println("Finishing simulation");
        System.out.println(new Date());
        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
        System.out.println(new Date());
        System.out.println("Making covariance matrix");
        ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data, true);
        // Stamp after construction so the reported time actually covers building the matrix.
        long time3 = System.currentTimeMillis();
        System.out.println("Covariance matrix done");
        out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms\n");
        SemBicScore score = new SemBicScore(cov);
        score.setPenaltyDiscount(penaltyDiscount);
        System.out.println(new Date());
        System.out.println("\nStarting FGES-MB");
        fges = new FgesMb(score);
    } else {
        BayesPm pm = new BayesPm(dag, 3, 3);
        MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
        DataSet data = im.simulateData(numCases, false, tiers);
        vars = data.getVariables();
        // The simulated data has its own variable objects; re-key the true pattern onto them so
        // the target and the search result can be compared against it.
        pattern = GraphUtils.replaceNodes(pattern, vars);
        System.out.println("Finishing simulation");
        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
        long time3 = System.currentTimeMillis();
        BDeuScore score = new BDeuScore(data);
        score.setStructurePrior(structurePrior);
        score.setSamplePrior(samplePrior);
        System.out.println(new Date());
        System.out.println("\nStarting FGES");
        long time4 = System.currentTimeMillis();
        fges = new FgesMb(score);
        long timeb = System.currentTimeMillis();
        out.println("Time consructing BDeu score " + (time4 - time3) + " ms");
        out.println("Time for FGES-MB constructor " + (timeb - time4) + " ms");
        out.println();
    }

    // Search configuration shared by both scores.
    fges.setVerbose(false);
    fges.setNumPatternsToStore(0);
    fges.setOut(System.out);
    // fges.setHeuristicSpeedup(faithfulness);
    fges.setMaxIndegree(maxIndegree);
    fges.setCycleBound(-1);

    // Never incremented: the NaN-skipping logic below is disabled, so this always reports 0.
    int numSkipped = 0;
    for (int run = 0; run < numRuns; run++) {
        out.println("\n\n\n******************************** RUN " + (run + 1) + " ********************************\n\n");
        Node target = vars.get(RandomUtil.getInstance().nextInt(vars.size()));
        System.out.println("Target = " + target);
        long timea = System.currentTimeMillis();
        Graph estPattern = fges.search(target);
        long elapsed = System.currentTimeMillis() - timea;

        // Reference set from the true pattern: the target, its adjacents, and the parents of
        // its directed children.
        Set<Node> mb = new HashSet<>();
        mb.add(target);
        mb.addAll(pattern.getAdjacentNodes(target));
        for (Node child : pattern.getChildren(target)) {
            mb.addAll(pattern.getParents(child));
        }
        Graph trueMbGraph = pattern.subgraph(new ArrayList<>(mb));

        // Report the search time only, not the reference-graph bookkeeping above.
        out.println("Time for FGES-MB search " + elapsed + " ms");
        out.println();
        System.out.println("Done with FGES");
        System.out.println(new Date());

        double[] comparison = new double[4];
        System.out.println("Counting misclassifications.");
        // NOTE(review): the row/column indices below assume the layout of the table returned by
        // GraphUtils.edgeMisclassificationCounts, which is defined outside this file.
        int[][] counts = GraphUtils.edgeMisclassificationCounts(trueMbGraph, estPattern, false);
        allCounts.add(counts);
        System.out.println(new Date());
        int sumRow = counts[4][0] + counts[4][3] + counts[4][5];
        int sumCol = counts[0][3] + counts[4][3] + counts[5][3] + counts[7][3];
        int trueArrow = counts[4][3];
        int sumTrueAdjacencies = 0;
        for (int i = 0; i < 7; i++) {
            for (int j = 0; j < 5; j++) {
                sumTrueAdjacencies += counts[i][j];
            }
        }
        int falsePositiveAdjacencies = 0;
        for (int j = 0; j < 5; j++) {
            falsePositiveAdjacencies += counts[7][j];
        }
        int falseNegativeAdjacencies = 0;
        for (int i = 0; i < 5; i++) {
            falseNegativeAdjacencies += counts[i][5];
        }
        // Adjacency precision/recall, then arrowhead precision/recall (NaN when a denominator is 0).
        comparison[0] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falsePositiveAdjacencies);
        comparison[1] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falseNegativeAdjacencies);
        comparison[2] = trueArrow / (double) sumCol;
        comparison[3] = trueArrow / (double) sumRow;
        comparisons.add(comparison);
        out.println(GraphUtils.edgeMisclassifications(counts));
        out.println(precisionRecall(comparison));
        elapsedTimes.add(elapsed);
        out.println("\nElapsed: " + elapsed + " ms");
    }
    printAverageConfusion("Average", allCounts, new DecimalFormat("0.0"));
    printAveragePrecisionRecall(comparisons);
    out.println("Number of runs skipped because of undefined accuracies: " + numSkipped);
    printAverageStatistics(elapsedTimes, degrees);
    out.close();
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) DecimalFormat(java.text.DecimalFormat) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 22 with BayesPm

Use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.

The class PerformanceTests, method testFges.

/**
 * Benchmarks FGES over {@code numRuns} independently generated DAGs.
 *
 * <p>Each run makes a new DAG with {@code makeDag} and computes its pattern. It then simulates
 * {@code numCases} rows, via {@code LargeScaleSimulation.simulateDataFisher} when
 * {@code continuous} (scored with SEM BIC) or otherwise from a randomly parameterized
 * {@code MlBayesIm} over {@code BayesPm(dag, 3, 3)} (scored with BDeu). It runs FGES and compares
 * the estimated pattern with the true pattern. Per-run and averaged results go to the report
 * stream opened by {@code init(...)}, which is closed at the end.
 *
 * @param numVars    number of variables passed to {@code makeDag}
 * @param edgeFactor passed to {@code makeDag}; the report lists {@code numVars * edgeFactor} edges
 * @param numCases   number of simulated rows per run
 * @param numRuns    number of independent runs
 * @param continuous {@code true} for continuous simulation with SEM BIC, {@code false} for
 *                   discrete simulation with BDeu
 */
private void testFges(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) {
    // RandomUtil.getInstance().setSeed(4828384343999L);
    double penaltyDiscount = 4.0;
    // NOTE(review): maxIndegree is reported as "Depth" in the continuous header but is never
    // passed to Fges, and the discrete header prints a hard-coded 1 — confirm the intended depth.
    int maxIndegree = 5;
    boolean faithfulness = true;
    // RandomUtil.getInstance().setSeed(50304050454L);
    List<int[][]> allCounts = new ArrayList<>();
    List<double[]> comparisons = new ArrayList<>();
    List<Double> degrees = new ArrayList<>();
    List<Long> elapsedTimes = new ArrayList<>();

    // Open the report stream before anything is written to `out`; until init(...) runs, `out`
    // is whatever an earlier test left behind (possibly closed, or never set).
    String kind = continuous ? "continuous" : "discrete";
    init(new File("fges.comparison." + kind + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
    out.println(new Date());
    out.println("Num vars = " + numVars);
    out.println("Num edges = " + (int) (numVars * edgeFactor));
    out.println("Num cases = " + numCases);
    if (continuous) {
        out.println("Penalty discount = " + penaltyDiscount);
        out.println("Depth = " + maxIndegree);
    } else {
        out.println("Sample prior = " + 1);
        out.println("Structure prior = " + 1);
        out.println("Depth = " + 1);
    }
    out.println();

    for (int run = 0; run < numRuns; run++) {
        out.println("\n\n\n******************************** RUN " + (run + 1) + " ********************************\n\n");
        System.out.println("Making dag");
        out.println(new Date());
        Graph dag = makeDag(numVars, edgeFactor);
        System.out.println(new Date());
        System.out.println("Calculating pattern for DAG");
        Graph pattern = SearchGraphUtils.patternForDag(dag);
        List<Node> vars = dag.getNodes();
        // One tier per node, in node order.
        int[] tiers = new int[vars.size()];
        for (int i = 0; i < vars.size(); i++) {
            tiers[i] = i;
        }
        System.out.println("Graph done");
        long time1 = System.currentTimeMillis();
        out.println("Graph done");
        System.out.println(new Date());
        System.out.println("Starting simulation");
        Graph estPattern;
        long elapsed;
        if (continuous) {
            LargeScaleSimulation simulator = new LargeScaleSimulation(dag, vars, tiers);
            simulator.setVerbose(false);
            simulator.setOut(out);
            DataSet data = simulator.simulateDataFisher(numCases);
            System.out.println("Finishing simulation");
            System.out.println(new Date());
            long time2 = System.currentTimeMillis();
            out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
            System.out.println(new Date());
            System.out.println("Making covariance matrix");
            ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data, true);
            // Stamp after construction so the reported time actually covers building the matrix.
            long time3 = System.currentTimeMillis();
            System.out.println("Covariance matrix done");
            out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms\n");
            SemBicScore score = new SemBicScore(cov);
            score.setPenaltyDiscount(penaltyDiscount);
            System.out.println(new Date());
            System.out.println("\nStarting FGES");
            long timea = System.currentTimeMillis();
            Fges fges = new Fges(score);
            // fges.setVerbose(false);
            fges.setNumPatternsToStore(0);
            fges.setOut(System.out);
            fges.setFaithfulnessAssumed(faithfulness);
            fges.setCycleBound(-1);
            long timeb = System.currentTimeMillis();
            estPattern = fges.search();
            long timec = System.currentTimeMillis();
            out.println("Time for FGES constructor " + (timeb - timea) + " ms");
            // Search only; the constructor time is reported on its own line above.
            out.println("Time for FGES search " + (timec - timeb) + " ms");
            out.println();
            out.flush();
            elapsed = timec - timea;
        } else {
            BayesPm pm = new BayesPm(dag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            DataSet data = im.simulateData(numCases, false, tiers);
            System.out.println("Finishing simulation");
            long time2 = System.currentTimeMillis();
            out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
            long time3 = System.currentTimeMillis();
            BDeuScore score = new BDeuScore(data);
            score.setStructurePrior(1);
            score.setSamplePrior(1);
            System.out.println(new Date());
            System.out.println("\nStarting FGES");
            long timea = System.currentTimeMillis();
            Fges fges = new Fges(score);
            // fges.setVerbose(false);
            fges.setNumPatternsToStore(0);
            fges.setOut(System.out);
            fges.setFaithfulnessAssumed(faithfulness);
            fges.setCycleBound(-1);
            long timeb = System.currentTimeMillis();
            estPattern = fges.search();
            long timec = System.currentTimeMillis();
            out.println("Time consructing BDeu score " + (timea - time3) + " ms");
            out.println("Time for FGES constructor " + (timeb - timea) + " ms");
            // Search only; the constructor time is reported on its own line above.
            out.println("Time for FGES search " + (timec - timeb) + " ms");
            out.println();
            elapsed = timec - timea;
        }
        System.out.println("Done with FGES");
        System.out.println(new Date());
        double[] comparison = new double[4];
        System.out.println("Counting misclassifications.");
        // Re-key the estimate onto the true pattern's node objects before comparing.
        estPattern = GraphUtils.replaceNodes(estPattern, pattern.getNodes());
        // NOTE(review): the row/column indices below assume the layout of the table returned by
        // GraphUtils.edgeMisclassificationCounts, which is defined outside this file.
        int[][] counts = GraphUtils.edgeMisclassificationCounts(pattern, estPattern, false);
        allCounts.add(counts);
        System.out.println(new Date());
        int sumRow = counts[4][0] + counts[4][3] + counts[4][5];
        int sumCol = counts[0][3] + counts[4][3] + counts[5][3] + counts[7][3];
        int trueArrow = counts[4][3];
        int sumTrueAdjacencies = 0;
        for (int i = 0; i < 7; i++) {
            for (int j = 0; j < 5; j++) {
                sumTrueAdjacencies += counts[i][j];
            }
        }
        int falsePositiveAdjacencies = 0;
        for (int j = 0; j < 5; j++) {
            falsePositiveAdjacencies += counts[7][j];
        }
        int falseNegativeAdjacencies = 0;
        for (int i = 0; i < 5; i++) {
            falseNegativeAdjacencies += counts[i][5];
        }
        // Adjacency precision/recall, then arrowhead precision/recall (NaN when a denominator is 0).
        comparison[0] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falsePositiveAdjacencies);
        comparison[1] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falseNegativeAdjacencies);
        comparison[2] = trueArrow / (double) sumCol;
        comparison[3] = trueArrow / (double) sumRow;
        comparisons.add(comparison);
        out.println(GraphUtils.edgeMisclassifications(counts));
        out.println(precisionRecall(comparison));
        elapsedTimes.add(elapsed);
        out.println("\nElapsed: " + elapsed + " ms");
    }
    printAverageConfusion("Average", allCounts);
    printAveragePrecisionRecall(comparisons);
    printAverageStatistics(elapsedTimes, degrees);
    out.close();
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 23 with BayesPm

Use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.

The class BayesPmWrapper, method setBayesPm.

/**
 * Builds a {@code BayesPm} over {@code graph} and hands it to the single-argument
 * {@code setBayesPm} overload.
 *
 * @param graph      the graph the new PM is built on
 * @param lowerBound forwarded as the BayesPm constructor's lower bound (presumably the minimum
 *                   number of categories per variable — TODO confirm against BayesPm)
 * @param upperBound forwarded as the BayesPm constructor's upper bound
 */
private void setBayesPm(Graph graph, int lowerBound, int upperBound) {
    setBayesPm(new BayesPm(graph, lowerBound, upperBound));
}
Also used : BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 24 with BayesPm

Use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.

The class XdslXmlParser, method buildIM.

/**
 * Builds a Bayes IM from the {@code cpt} children of an XDSL element.
 *
 * <p>Every child of {@code element0} must be a {@code cpt} element whose first attribute is the
 * node id. The method creates one DAG node per cpt. It adds an edge to each cpt's node from every
 * id listed in its {@code parents} child, and sets each node's categories from its {@code state}
 * children. Finally it fills an {@code MlBayesIm} row by row from each cpt's whitespace-separated
 * {@code probabilities} text.
 *
 * @param element0     the element whose children are the cpt elements
 * @param displayNames map from XDSL id to the name used for the graph node, or {@code null} to
 *                     use the ids themselves
 * @return the populated Bayes IM
 * @throws IllegalArgumentException if a child is not a cpt element, or a probabilities list does
 *                                  not contain exactly one value per table cell
 */
private BayesIm buildIM(Element element0, Map<String, String> displayNames) {
    Elements elements = element0.getChildElements();
    for (int i = 0; i < elements.size(); i++) {
        if (!"cpt".equals(elements.get(i).getQualifiedName())) {
            throw new IllegalArgumentException("Expecting cpt element.");
        }
    }

    Dag dag = new Dag();

    // Get the nodes: exactly one per cpt, named by display name when a map is given.
    for (int i = 0; i < elements.size(); i++) {
        String id = elements.get(i).getAttribute(0).getValue();
        dag.addNode(new GraphNode(displayNames == null ? id : displayNames.get(id)));
    }

    // Get the edges. Every node already exists, so nothing is (re-)added to the DAG here.
    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        String childId = cpt.getAttribute(0).getValue();
        Node child = dag.getNode(displayNames == null ? childId : displayNames.get(childId));
        Elements cptElements = cpt.getChildElements();
        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (cptElement.getQualifiedName().equals("parents")) {
                // Split on whitespace runs so stray or repeated spaces do not yield empty names.
                String list = cptElement.getValue().trim();
                if (list.isEmpty()) {
                    continue;
                }
                for (String name : list.split("\\s+")) {
                    Node parent = dag.getNode(displayNames == null ? name : displayNames.get(name));
                    dag.addDirectedEdge(parent, child);
                }
            }
        }
    }

    // PM: categories come from the id attribute of each <state> child, in document order.
    BayesPm pm = new BayesPm(dag);
    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        String varName = cpt.getAttribute(0).getValue();
        Node node;
        if (displayNames == null) {
            node = dag.getNode(varName);
        } else {
            node = dag.getNode(displayNames.get(varName));
        }
        Elements cptElements = cpt.getChildElements();
        List<String> stateNames = new ArrayList<>();
        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (cptElement.getQualifiedName().equals("state")) {
                Attribute attribute = cptElement.getAttribute(0);
                String stateName = attribute.getValue();
                stateNames.add(stateName);
            }
        }
        pm.setCategories(node, stateNames);
    }

    // IM. The IM node index is taken to be the cpt's position, i.e. the order nodes were added
    // to the DAG above — NOTE(review): assumes MlBayesIm keeps that order; confirm.
    BayesIm im = new MlBayesIm(pm);
    for (int nodeIndex = 0; nodeIndex < elements.size(); nodeIndex++) {
        Element cpt = elements.get(nodeIndex);
        Elements cptElements = cpt.getChildElements();
        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (cptElement.getQualifiedName().equals("probabilities")) {
                String list = cptElement.getValue().trim();
                String[] probsStrings = list.isEmpty() ? new String[0] : list.split("\\s+");
                int numRows = im.getNumRows(nodeIndex);
                int numCols = im.getNumColumns(nodeIndex);
                // Values are consumed row-major; a size mismatch means the file and the
                // constructed table disagree, so fail with a descriptive message.
                if (probsStrings.length != numRows * numCols) {
                    throw new IllegalArgumentException("Expecting " + (numRows * numCols)
                            + " probabilities for cpt " + cpt.getAttribute(0).getValue()
                            + " but found " + probsStrings.length + ".");
                }
                int count = 0;
                for (int row = 0; row < numRows; row++) {
                    for (int col = 0; col < numCols; col++) {
                        im.setProbability(nodeIndex, row, col, Double.parseDouble(probsStrings[count++]));
                    }
                }
            }
        }
    }
    return im;
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) Attribute(nu.xom.Attribute) Element(nu.xom.Element) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) Dag(edu.cmu.tetrad.graph.Dag) Elements(nu.xom.Elements) BayesIm(edu.cmu.tetrad.bayes.BayesIm) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 25 with BayesPm

Use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.

The class FgesSearchEditor, method reportIfDiscrete.

/**
 * Builds a short text report of {@code BayesProperties} fit statistics for {@code dag} against
 * {@code dataSet}: the likelihood-ratio p-value, degrees of freedom, chi-square and BIC.
 *
 * @param dag     the graph whose fit is reported
 * @param dataSet the data the statistics are computed from; presumably discrete, given the
 *                method name — TODO confirm at the call site
 * @return the multi-line report text, beginning with a newline
 */
private String reportIfDiscrete(Graph dag, DataSet dataSet) {
    // The statistics depend only on dataSet and dag. No BayesPm is needed here, so none is
    // built (a PM with copied categories would be discarded unused).
    NumberFormat nf = NumberFormat.getInstance();
    nf.setMaximumFractionDigits(4);

    BayesProperties properties = new BayesProperties(dataSet);
    // getLikelihoodRatioP(dag) is called before the other getters. NOTE(review): chi-square,
    // BIC and dof presumably report values computed by that call, so keep this order.
    double p = properties.getLikelihoodRatioP(dag);
    double chisq = properties.getChisq();
    double bic = properties.getBic();
    double dof = properties.getDof();

    StringBuilder buf = new StringBuilder();
    buf.append("\nP  = ").append(p);
    buf.append("\nDOF = ").append(dof);
    buf.append("\nChiSq = ").append(nf.format(chisq));
    buf.append("\nBIC = ").append(nf.format(bic));
    buf.append("\n\nH0: Complete DAG.");
    return buf.toString();
}
Also used : BayesProperties(edu.cmu.tetrad.bayes.BayesProperties) List(java.util.List) BayesPm(edu.cmu.tetrad.bayes.BayesPm) NumberFormat(java.text.NumberFormat)

Aggregations

BayesPm (edu.cmu.tetrad.bayes.BayesPm): 38
MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm): 23
BayesIm (edu.cmu.tetrad.bayes.BayesIm): 18
Test (org.junit.Test): 17
Node (edu.cmu.tetrad.graph.Node): 14
Graph (edu.cmu.tetrad.graph.Graph): 10
DataSet (edu.cmu.tetrad.data.DataSet): 6
Dag (edu.cmu.tetrad.graph.Dag): 6
DisplayNode (edu.cmu.tetradapp.workbench.DisplayNode): 6
ArrayList (java.util.ArrayList): 6
List (java.util.List): 5
BayesProperties (edu.cmu.tetrad.bayes.BayesProperties): 4
LargeScaleSimulation (edu.cmu.tetrad.sem.LargeScaleSimulation): 4
Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm): 3
GraphNode (edu.cmu.tetrad.graph.GraphNode): 3
Parameters (edu.cmu.tetrad.util.Parameters): 3
GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest): 3
NumberFormat (java.text.NumberFormat): 3
RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph): 2
ChiSquare (edu.cmu.tetrad.algcomparison.independence.ChiSquare): 2