
Example 16 with MlBayesIm

Use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

Source: class BayesImWrapper, method setBayesIm.

private void setBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, int manual) {
    bayesIms = new ArrayList<>();
    bayesIms.add(new MlBayesIm(bayesPm, oldBayesIm, manual));
}
Also used: MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm)

Example 17 with MlBayesIm

Use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

Source: class XdslXmlParser, method buildIM.

private BayesIm buildIM(Element element0, Map<String, String> displayNames) {
    Elements elements = element0.getChildElements();
    for (int i = 0; i < elements.size(); i++) {
        if (!"cpt".equals(elements.get(i).getQualifiedName())) {
            throw new IllegalArgumentException("Expecting cpt element.");
        }
    }
    Dag dag = new Dag();
    // Get the nodes.
    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        String name = cpt.getAttribute(0).getValue();
        if (displayNames == null) {
            dag.addNode(new GraphNode(name));
        } else {
            dag.addNode(new GraphNode(displayNames.get(name)));
        }
    }
    // Get the edges. Every node was already added above, so this loop only adds edges.
    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        String childId = cpt.getAttribute(0).getValue();
        Node child = dag.getNode(displayNames == null ? childId : displayNames.get(childId));
        Elements cptElements = cpt.getChildElements();
        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (cptElement.getQualifiedName().equals("parents")) {
                // Parent ids are whitespace-separated; trimming and splitting on
                // runs of whitespace avoids empty names from stray spaces or newlines.
                String list = cptElement.getValue().trim();
                if (list.isEmpty()) {
                    continue;
                }
                for (String parentId : list.split("\\s+")) {
                    Node parent = dag.getNode(displayNames == null ? parentId : displayNames.get(parentId));
                    dag.addDirectedEdge(parent, child);
                }
            }
        }
    }
    // PM
    BayesPm pm = new BayesPm(dag);
    for (int i = 0; i < elements.size(); i++) {
        Element cpt = elements.get(i);
        String varName = cpt.getAttribute(0).getValue();
        Node node;
        if (displayNames == null) {
            node = dag.getNode(varName);
        } else {
            node = dag.getNode(displayNames.get(varName));
        }
        Elements cptElements = cpt.getChildElements();
        List<String> stateNames = new ArrayList<>();
        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (cptElement.getQualifiedName().equals("state")) {
                Attribute attribute = cptElement.getAttribute(0);
                String stateName = attribute.getValue();
                stateNames.add(stateName);
            }
        }
        pm.setCategories(node, stateNames);
    }
    // IM
    BayesIm im = new MlBayesIm(pm);
    for (int nodeIndex = 0; nodeIndex < elements.size(); nodeIndex++) {
        Element cpt = elements.get(nodeIndex);
        Elements cptElements = cpt.getChildElements();
        for (int j = 0; j < cptElements.size(); j++) {
            Element cptElement = cptElements.get(j);
            if (cptElement.getQualifiedName().equals("probabilities")) {
                // Probabilities are whitespace-separated and listed row by row.
                String list = cptElement.getValue().trim();
                String[] probsStrings = list.split("\\s+");
                List<Double> probs = new ArrayList<>();
                for (String probString : probsStrings) {
                    probs.add(Double.parseDouble(probString));
                }
                int count = -1;
                for (int row = 0; row < im.getNumRows(nodeIndex); row++) {
                    for (int col = 0; col < im.getNumColumns(nodeIndex); col++) {
                        im.setProbability(nodeIndex, row, col, probs.get(++count));
                    }
                }
            }
        }
    }
    return im;
}
Also used: MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm), Attribute (nu.xom.Attribute), Element (nu.xom.Element), GraphNode (edu.cmu.tetrad.graph.GraphNode), Node (edu.cmu.tetrad.graph.Node), ArrayList (java.util.ArrayList), Dag (edu.cmu.tetrad.graph.Dag), Elements (nu.xom.Elements), BayesIm (edu.cmu.tetrad.bayes.BayesIm), BayesPm (edu.cmu.tetrad.bayes.BayesPm)
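The final loop of buildIM reads each probabilities element as one flat list and copies it into the node's table row by row: rows enumerate combinations of parent values and columns enumerate the child's states. The standalone sketch below shows that row-major layout in plain Java with no Tetrad or XOM dependency. The class and method names (CptFill, parseProbabilities, toTable) are hypothetical, and the length and row-sum checks are additions that the original method does not perform.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: fill a conditional probability table from a flat,
 * whitespace-separated probability string. Rows index parent configurations;
 * columns index the child's states.
 */
public class CptFill {

    /** Parses "p0 p1 p2 ...", tolerating newlines and repeated spaces. */
    static double[] parseProbabilities(String text) {
        String trimmed = text.trim();
        if (trimmed.isEmpty()) {
            return new double[0];
        }
        return Arrays.stream(trimmed.split("\\s+"))
                .mapToDouble(Double::parseDouble)
                .toArray();
    }

    /** Lays the flat list out row-major into a numRows x numCols table. */
    static double[][] toTable(double[] probs, int numRows, int numCols) {
        if (probs.length != numRows * numCols) {
            throw new IllegalArgumentException("Expected " + numRows * numCols
                    + " probabilities but found " + probs.length);
        }
        double[][] table = new double[numRows][numCols];
        int count = 0;
        for (int row = 0; row < numRows; row++) {
            double rowSum = 0.0;
            for (int col = 0; col < numCols; col++) {
                table[row][col] = probs[count++];
                rowSum += table[row][col];
            }
            // Each row is a distribution over the child's states.
            if (Math.abs(rowSum - 1.0) > 1e-6) {
                throw new IllegalArgumentException("Row " + row + " sums to " + rowSum);
            }
        }
        return table;
    }

    public static void main(String[] args) {
        // A binary child with one binary parent: 2 rows x 2 columns.
        double[] probs = parseProbabilities("  0.9 0.1\n  0.2   0.8 ");
        double[][] cpt = toTable(probs, 2, 2);
        System.out.println(Arrays.deepToString(cpt)); // [[0.9, 0.1], [0.2, 0.8]]
    }
}
```

Validating the count up front turns a malformed file into a clear error message instead of an IndexOutOfBoundsException partway through filling the table.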

Example 18 with MlBayesIm

Use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

Source: class TestEvidence, method sampleBayesIm2.

private static BayesIm sampleBayesIm2() {
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Node c = new GraphNode("c");
    Dag graph = new Dag();
    graph.addNode(a);
    graph.addNode(b);
    graph.addNode(c);
    graph.addDirectedEdge(a, b);
    graph.addDirectedEdge(a, c);
    graph.addDirectedEdge(b, c);
    BayesPm bayesPm = new BayesPm(graph);
    bayesPm.setNumCategories(b, 3);
    BayesIm bayesIm1 = new MlBayesIm(bayesPm);
    bayesIm1.setProbability(0, 0, 0, .3);
    bayesIm1.setProbability(0, 0, 1, .7);
    bayesIm1.setProbability(1, 0, 0, .3);
    bayesIm1.setProbability(1, 0, 1, .4);
    bayesIm1.setProbability(1, 0, 2, .3);
    bayesIm1.setProbability(1, 1, 0, .6);
    bayesIm1.setProbability(1, 1, 1, .1);
    bayesIm1.setProbability(1, 1, 2, .3);
    bayesIm1.setProbability(2, 0, 0, .9);
    bayesIm1.setProbability(2, 0, 1, .1);
    bayesIm1.setProbability(2, 1, 0, .1);
    bayesIm1.setProbability(2, 1, 1, .9);
    bayesIm1.setProbability(2, 2, 0, .5);
    bayesIm1.setProbability(2, 2, 1, .5);
    bayesIm1.setProbability(2, 3, 0, .2);
    bayesIm1.setProbability(2, 3, 1, .8);
    bayesIm1.setProbability(2, 4, 0, .6);
    bayesIm1.setProbability(2, 4, 1, .4);
    bayesIm1.setProbability(2, 5, 0, .7);
    bayesIm1.setProbability(2, 5, 1, .3);
    return bayesIm1;
}
Also used: MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm), BayesIm (edu.cmu.tetrad.bayes.BayesIm), GraphNode (edu.cmu.tetrad.graph.GraphNode), Node (edu.cmu.tetrad.graph.Node), Dag (edu.cmu.tetrad.graph.Dag), BayesPm (edu.cmu.tetrad.bayes.BayesPm)
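In sampleBayesIm2, node a has a single row of 2 states, node b has 2 rows (one per value of a) of 3 states, and node c has parents a and b, so its table has 2 × 3 = 6 rows of 2 states. The sketch below hard-codes those same probabilities and computes the marginal P(c = 0) by summing the chain-rule product P(a) P(b | a) P(c | a, b) over a and b. It assumes the rows of c are ordered with the first parent varying slowest (row = a * 3 + b); that ordering is an assumption about MlBayesIm, not something the test itself states. The class name ChainRuleMarginal is hypothetical.

```java
import java.util.Locale;

/**
 * Hypothetical sketch: marginal of c in the network a -> b, a -> c, b -> c,
 * using the probabilities set in sampleBayesIm2.
 */
public class ChainRuleMarginal {

    // P(a): one row (no parents), two states.
    static final double[] P_A = {0.3, 0.7};

    // P(b | a): one row per value of a, three states of b.
    static final double[][] P_B_GIVEN_A = {
            {0.3, 0.4, 0.3},
            {0.6, 0.1, 0.3}
    };

    // P(c | a, b): six rows, assumed ordered with a varying slowest.
    static final double[][] P_C_GIVEN_AB = {
            {0.9, 0.1}, {0.1, 0.9}, {0.5, 0.5},
            {0.2, 0.8}, {0.6, 0.4}, {0.7, 0.3}
    };

    /** Mixed-radix row index for parent values (a, b); b has 3 states. */
    static int rowIndex(int a, int b) {
        return a * 3 + b;
    }

    /** Chain rule: P(a, b, c) = P(a) P(b | a) P(c | a, b). */
    static double joint(int a, int b, int c) {
        return P_A[a] * P_B_GIVEN_A[a][b] * P_C_GIVEN_AB[rowIndex(a, b)][c];
    }

    /** Sums the joint over every value of a and b. */
    static double marginalC(int c) {
        double sum = 0.0;
        for (int a = 0; a < 2; a++) {
            for (int b = 0; b < 3; b++) {
                sum += joint(a, b, c);
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.printf(Locale.ROOT, "P(c=0) = %.3f%n", marginalC(0)); // P(c=0) = 0.411
        System.out.printf(Locale.ROOT, "P(c=1) = %.3f%n", marginalC(1)); // P(c=1) = 0.589
    }
}
```

Because every row of every table sums to 1, the two marginals of c also sum to 1, which is a quick sanity check on hand-entered parameters like those in the test.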

Example 19 with MlBayesIm

Use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

Source: class TestFges, method explore2.

@Test
public void explore2() {
    RandomUtil.getInstance().setSeed(1457220623122L);
    int numVars = 20;
    double edgeFactor = 1.0;
    int numCases = 1000;
    double structurePrior = 1;
    double samplePrior = 1;
    List<Node> vars = new ArrayList<>();
    for (int i = 0; i < numVars; i++) {
        vars.add(new ContinuousVariable("X" + i));
    }
    Graph dag = GraphUtils.randomGraphRandomForwardEdges(vars, 0, (int) (numVars * edgeFactor), 30, 15, 15, false, true);
    // printDegreeDistribution(dag, out);
    BayesPm pm = new BayesPm(dag, 2, 3);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(numCases, false);
    // out.println("Finishing simulation");
    BDeScore score = new BDeScore(data);
    score.setSamplePrior(samplePrior);
    score.setStructurePrior(structurePrior);
    Fges ges = new Fges(score);
    ges.setVerbose(false);
    ges.setNumPatternsToStore(0);
    ges.setFaithfulnessAssumed(false);
    Graph estPattern = ges.search();
    final Graph truePattern = SearchGraphUtils.patternForDag(dag);
    int[][] counts = SearchGraphUtils.graphComparison(estPattern, truePattern, null);
    int[][] expectedCounts = { { 2, 0, 0, 0, 0, 1 }, { 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0 }, { 2, 0, 0, 13, 0, 3 }, { 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0 }, { 0, 0, 0, 0, 0, 0 } };
    // for (int i = 0; i < counts.length; i++) {
    //     assertTrue(Arrays.equals(counts[i], expectedCounts[i]));
    // }
    // System.out.println(MatrixUtils.toString(expectedCounts));
    // System.out.println(MatrixUtils.toString(counts));
    // System.out.println(RandomUtil.getInstance().getSeed());
}
Also used: MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm), Fges (edu.cmu.tetrad.search.Fges), RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph), BayesIm (edu.cmu.tetrad.bayes.BayesIm), BayesPm (edu.cmu.tetrad.bayes.BayesPm), SemBicDTest (edu.cmu.tetrad.algcomparison.independence.SemBicDTest), SemBicTest (edu.cmu.tetrad.algcomparison.independence.SemBicTest), Test (org.junit.Test)
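explore2 draws a random DAG, parameterizes it at random (as the MlBayesIm.RANDOM flag suggests), and simulates 1000 cases with im.simulateData before scoring candidate graphs with BDe. The usual way to sample from a discrete Bayes net is forward (ancestral) sampling: visit the variables in topological order and draw each one from the table row selected by its parents' already-sampled values. The sketch below applies that technique to a hypothetical two-node net a -> b. It illustrates the idea and is not Tetrad's simulateData implementation; the class name and the probabilities are made up for the example.

```java
import java.util.Random;

/**
 * Hypothetical sketch of forward (ancestral) sampling from a discrete
 * Bayes net a -> b: roots first, then children conditioned on parents.
 */
public class AncestralSampling {

    static final double[] P_A = {0.3, 0.7};
    static final double[][] P_B_GIVEN_A = {
            {0.9, 0.1},
            {0.2, 0.8}
    };

    /** Draws a category index from a discrete distribution by inverse CDF. */
    static int draw(double[] dist, Random rng) {
        double u = rng.nextDouble();
        double cumulative = 0.0;
        for (int k = 0; k < dist.length; k++) {
            cumulative += dist[k];
            if (u < cumulative) {
                return k;
            }
        }
        return dist.length - 1; // guard against floating-point rounding
    }

    /** Returns numCases rows of {a, b}, reproducible for a given seed. */
    static int[][] simulate(int numCases, long seed) {
        Random rng = new Random(seed);
        int[][] data = new int[numCases][2];
        for (int i = 0; i < numCases; i++) {
            int a = draw(P_A, rng);            // root: no parents
            int b = draw(P_B_GIVEN_A[a], rng); // row chosen by the sampled parent a
            data[i][0] = a;
            data[i][1] = b;
        }
        return data;
    }

    public static void main(String[] args) {
        int[][] data = simulate(100_000, 1457220623122L);
        int aIsZero = 0;
        for (int[] row : data) {
            if (row[0] == 0) {
                aIsZero++;
            }
        }
        System.out.println("Empirical P(a=0): " + aIsZero / (double) data.length);
    }
}
```

Seeding the generator, as explore2 does with RandomUtil, makes the simulated data set, and therefore the search result, reproducible from run to run.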

Example 20 with MlBayesIm

Use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

Source: class TestGeneralBootstrapTest, method testFCId.

@Test
public void testFCId() {
    double structurePrior = 1, samplePrior = 1;
    int depth = -1;
    int maxPathLength = -1;
    int numVars = 20;
    int edgesPerNode = 2;
    int numLatentConfounders = 4;
    int numCases = 50;
    int numBootstrapSamples = 5;
    boolean verbose = true;
    long seed = 123;
    Graph dag = makeDiscreteDAG(numVars, numLatentConfounders, edgesPerNode);
    DagToPag dagToPag = new DagToPag(dag);
    Graph truePag = dagToPag.convert();
    System.out.println("Truth PAG_of_the_true_DAG Graph:");
    System.out.println(truePag.toString());
    BayesPm pm = new BayesPm(dag, 2, 3);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(numCases, seed, false);
    Parameters parameters = new Parameters();
    parameters.set("structurePrior", structurePrior);
    parameters.set("samplePrior", samplePrior);
    parameters.set("depth", depth);
    parameters.set("maxPathLength", maxPathLength);
    parameters.set("numPatternsToStore", 0);
    parameters.set("verbose", verbose);
    IndependenceWrapper test = new ChiSquare();
    Algorithm algorithm = new Fci(test);
    GeneralBootstrapTest bootstrapTest = new GeneralBootstrapTest(data, algorithm, numBootstrapSamples);
    bootstrapTest.setVerbose(verbose);
    bootstrapTest.setParameters(parameters);
    bootstrapTest.setEdgeEnsemble(BootstrapEdgeEnsemble.Highest);
    Graph resultGraph = bootstrapTest.search();
    System.out.println("Estimated Bootstrapped PAG_of_the_true_DAG Graph:");
    System.out.println(resultGraph.toString());
    // Adjacency Confusion Matrix
    int[][] adjAr = GeneralBootstrapTest.getAdjConfusionMatrix(truePag, resultGraph);
    printAdjConfusionMatrix(adjAr);
    // Edge Type Confusion Matrix
    int[][] edgeAr = GeneralBootstrapTest.getEdgeTypeConfusionMatrix(truePag, resultGraph);
    printEdgeTypeConfusionMatrix(edgeAr);
}
Also used: MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm), Parameters (edu.cmu.tetrad.util.Parameters), ChiSquare (edu.cmu.tetrad.algcomparison.independence.ChiSquare), GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest), DataSet (edu.cmu.tetrad.data.DataSet), Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm), Fci (edu.cmu.tetrad.algcomparison.algorithm.oracle.pag.Fci), IndependenceWrapper (edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper), Graph (edu.cmu.tetrad.graph.Graph), DagToPag (edu.cmu.tetrad.search.DagToPag), BayesIm (edu.cmu.tetrad.bayes.BayesIm), BayesPm (edu.cmu.tetrad.bayes.BayesPm), Test (org.junit.Test)
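testFCId runs FCI through GeneralBootstrapTest with 5 bootstrap samples and merges the resulting graphs using BootstrapEdgeEnsemble.Highest. The sketch below shows one plausible reading of that procedure in plain Java: resample row indices with replacement, then, for each node pair, keep the edge type that appears in the most bootstrap graphs, treating "no edge" as a competing outcome. The graph representation (a map from a pair key such as "X1,X2" to an edge-type string), the class and method names, and the first-seen tie-breaking rule are all assumptions; Tetrad's own ensemble may handle ties and absent edges differently.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Hypothetical sketch: bootstrap row resampling plus a "highest frequency"
 * edge ensemble. A graph is a map from a pair key to an edge type;
 * a missing key means there is no edge between that pair.
 */
public class BootstrapEnsemble {

    static final String NO_EDGE = "none";

    /** Draws numRows row indices uniformly with replacement. */
    static int[] resampleRows(int numRows, Random rng) {
        int[] rows = new int[numRows];
        for (int i = 0; i < numRows; i++) {
            rows[i] = rng.nextInt(numRows);
        }
        return rows;
    }

    /**
     * For every pair, keeps the edge type (possibly "no edge") found in the
     * most bootstrap graphs; ties go to the type encountered first.
     */
    static Map<String, String> highestEnsemble(List<Map<String, String>> graphs,
                                               List<String> pairs) {
        Map<String, String> result = new HashMap<>();
        for (String pair : pairs) {
            // Insertion order makes the tie-breaking deterministic.
            Map<String, Integer> counts = new LinkedHashMap<>();
            for (Map<String, String> g : graphs) {
                counts.merge(g.getOrDefault(pair, NO_EDGE), 1, Integer::sum);
            }
            String best = NO_EDGE;
            int bestCount = -1;
            for (Map.Entry<String, Integer> e : counts.entrySet()) {
                if (e.getValue() > bestCount) {
                    best = e.getKey();
                    bestCount = e.getValue();
                }
            }
            if (!best.equals(NO_EDGE)) {
                result.put(pair, best);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[] rows = resampleRows(50, new Random(123));
        System.out.println("First resampled row index: " + rows[0]);

        // Three hypothetical bootstrap results over two node pairs.
        List<Map<String, String>> graphs = new ArrayList<>();
        graphs.add(Map.of("X1,X2", "o->", "X2,X3", "-->"));
        graphs.add(Map.of("X1,X2", "o->"));
        graphs.add(Map.of("X1,X2", "o-o"));
        Map<String, String> ensemble = highestEnsemble(graphs, List.of("X1,X2", "X2,X3"));
        System.out.println(ensemble); // {X1,X2=o->}
    }
}
```

In the usage above, X2,X3 is dropped because "no edge" occurs in two of the three graphs, which is the main behavioral difference between a highest-frequency rule and a rule that keeps any edge seen at least once.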

Aggregations

MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm): 26
BayesPm (edu.cmu.tetrad.bayes.BayesPm): 23
BayesIm (edu.cmu.tetrad.bayes.BayesIm): 20
Test (org.junit.Test): 15
Graph (edu.cmu.tetrad.graph.Graph): 9
Node (edu.cmu.tetrad.graph.Node): 8
DataSet (edu.cmu.tetrad.data.DataSet): 6
Dag (edu.cmu.tetrad.graph.Dag): 6
ArrayList (java.util.ArrayList): 6
LargeScaleSimulation (edu.cmu.tetrad.sem.LargeScaleSimulation): 4
Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm): 3
GraphNode (edu.cmu.tetrad.graph.GraphNode): 3
Parameters (edu.cmu.tetrad.util.Parameters): 3
GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest): 3
RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph): 2
ChiSquare (edu.cmu.tetrad.algcomparison.independence.ChiSquare): 2
IndependenceWrapper (edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper): 2
BdeuScore (edu.cmu.tetrad.algcomparison.score.BdeuScore): 2
ScoreWrapper (edu.cmu.tetrad.algcomparison.score.ScoreWrapper): 2
ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 2