Search in sources:

Example 11 with BayesPm

use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.

From class BayesNetSimulation, method simulate.

/**
 * Simulates a data set from a Bayes IM, choosing or building the IM from this simulation's state.
 *
 * <ul>
 *   <li>If an IM was supplied ({@code this.im != null}), {@code ims} is reset to contain just
 *       that IM and data are drawn from it.</li>
 *   <li>Otherwise, if a PM was supplied ({@code this.pm != null}), a new {@link MlBayesIm} with
 *       random parameters is built from it, stored in {@code this.im} (so later calls reuse it)
 *       and appended to {@code ims}.</li>
 *   <li>Otherwise a new {@link BayesPm} is built over {@code graph} using the
 *       {@code minCategories}/{@code maxCategories} parameters, a random IM is built from it and
 *       appended to {@code ims}. Neither is stored, so each such call uses a new random model.</li>
 * </ul>
 *
 * @param graph      graph to build a PM over when neither an IM nor a PM was supplied
 * @param parameters supplies {@code saveLatentVars} and {@code sampleSize}, and, when a PM has to
 *                   be built, {@code minCategories} and {@code maxCategories}
 * @return the simulated data set
 * @throws IllegalArgumentException if building the model or simulating fails; the underlying
 *                                  failure is attached as the cause
 */
private DataSet simulate(Graph graph, Parameters parameters) {
    boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
    try {
        BayesIm im = this.im;
        if (im != null) {
            // A supplied IM is the only model this simulation reports.
            ims = new ArrayList<>();
            ims.add(im);
        } else {
            BayesPm pm = this.pm;
            if (pm != null) {
                // A supplied PM is parameterized once; the IM is kept for subsequent calls.
                im = new MlBayesIm(pm, MlBayesIm.RANDOM);
                this.im = im;
            } else {
                int minCategories = parameters.getInt("minCategories");
                int maxCategories = parameters.getInt("maxCategories");
                pm = new BayesPm(graph, minCategories, maxCategories);
                im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            }
            ims.add(im);
        }
        return im.simulateData(parameters.getInt("sampleSize"), saveLatentVars);
    } catch (Exception e) {
        // Chain the real failure so callers can see why simulation failed.
        throw new IllegalArgumentException("Sorry, I couldn't simulate from that Bayes IM; perhaps not all of\n" + "the parameters have been specified.", e);
    }
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesIm(edu.cmu.tetrad.bayes.BayesIm) MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 12 with BayesPm

use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.

From class ADTreeTest, method main.

/**
 * Benchmarks {@code ADTree} construction and count queries on data simulated from a random Bayes
 * net.
 *
 * <p>Builds a random forward-edge graph over 40 variables and simulates 500 rows from a randomly
 * parameterized {@code MlBayesIm}. It copies the data into a {@code DataTableImpl}, narrowing
 * values to {@code short}, then builds the tree. Finally it prints the build time, and the count
 * and lookup time for two example queries.
 *
 * @param args ignored
 * @throws Exception propagated from model construction or simulation
 */
public static void main(String[] args) throws Exception {
    int columns = 40;
    int numEdges = 40;
    int rows = 500;
    List<Node> variables = new ArrayList<>(columns);
    for (int i = 0; i < columns; i++) {
        variables.add(new ContinuousVariable("X" + (i + 1)));
    }
    Graph graph = GraphUtils.randomGraphRandomForwardEdges(variables, 0, numEdges, 30, 15, 15, false, true);
    BayesPm pm = new BayesPm(graph);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(rows, false);
    // This implementation uses a DataTable to represent the data.
    // The first type parameter is the type for the variables;
    // the second type parameter is the type for the values of the variables.
    // NOTE(review): assumes column j of the simulated data corresponds to variables.get(j).
    DataTableImpl<Node, Short> dataTable = new DataTableImpl<>(variables);
    for (int i = 0; i < rows; i++) {
        ArrayList<Short> row = new ArrayList<>(columns);
        for (int j = 0; j < columns; j++) {
            row.add((short) data.getInt(i, j));
        }
        dataTable.addRow(row);
    }
    // Create the tree. nanoTime is monotonic (unlike currentTimeMillis), so it is the right
    // clock for elapsed-time measurement; results are still reported in milliseconds.
    long start = System.nanoTime();
    ADTree<Node, Short> adTree = new ADTree<>(dataTable);
    System.out.println(String.format("Generated tree in %s millis", (System.nanoTime() - start) / 1_000_000L));
    // The query is an arbitrary map of vars and their values.
    TreeMap<Node, Short> query = new TreeMap<>();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    printTimedCount(adTree, query);
    query.clear();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X2"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    query.put(node(pm, "X10"), (short) 1);
    printTimedCount(adTree, query);
}

/**
 * Runs {@code adTree.count(query)} and prints the count followed by the time it took in
 * milliseconds.
 */
private static void printTimedCount(ADTree<Node, Short> adTree, TreeMap<Node, Short> query) {
    long start = System.nanoTime();
    System.out.println(String.format("Count is %d", adTree.count(query)));
    System.out.println(String.format("Query in %s ms", (System.nanoTime() - start) / 1_000_000L));
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) TreeMap(java.util.TreeMap) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Graph(edu.cmu.tetrad.graph.Graph) BayesIm(edu.cmu.tetrad.bayes.BayesIm) MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 13 with BayesPm

use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.

From class EvidenceWizardMultipleObs, method appendJoint.

/**
 * Appends a table of the joint distribution over {@code selectedNodes} to {@code marginalsArea}.
 *
 * <p>If the current updater does not support joint marginals, only a short note saying so is
 * appended. Otherwise the method writes a header row of node names followed by one row per
 * combination of category values. The combination for each row index comes from
 * {@code getCategories(row, dims)}. Each row lists the category name of every node, the joint
 * probability and its log odds. A negative probability from the updater is written as
 * "Unidentifiable", with "*" in place of the log odds.
 *
 * @param selectedNodes nodes included in the joint, in column order
 * @param marginalsArea text area the table is appended to
 * @param manipulatedIm IM supplying node indices and, via its PM, category counts and names
 * @param nf            format used for probabilities and log odds
 */
private void appendJoint(List<Node> selectedNodes, JTextArea marginalsArea, BayesIm manipulatedIm, NumberFormat nf) {
    if (!getUpdaterWrapper().getBayesUpdater().isJointMarginalSupported()) {
        marginalsArea.append("\n\n(Calculation of joint not supported for this updater.)");
        return;
    }

    BayesPm pm = manipulatedIm.getBayesPm();
    int count = selectedNodes.size();
    int[] categoryCounts = new int[count];
    int[] nodeIndices = new int[count];

    // The number of table rows is the product of the nodes' category counts.
    int combinations = 1;
    int k = 0;
    for (Node n : selectedNodes) {
        int cats = pm.getNumCategories(n);
        nodeIndices[k] = manipulatedIm.getNodeIndex(n);
        categoryCounts[k] = cats;
        combinations *= cats;
        k++;
    }

    marginalsArea.append("\n\nJOINT OVER SELECTED VARIABLES:\n\n");
    for (Node n : selectedNodes) {
        marginalsArea.append(n + "\t");
    }
    marginalsArea.append("Joint\tLog odds\n");

    for (int rowIndex = 0; rowIndex < combinations; rowIndex++) {
        int[] categoryValues = getCategories(rowIndex, categoryCounts);
        double jointProb = getUpdaterWrapper().getBayesUpdater().getJointMarginal(nodeIndices, categoryValues);
        marginalsArea.append("\n");
        for (int col = 0; col < count; col++) {
            marginalsArea.append(pm.getCategory(selectedNodes.get(col), categoryValues[col]));
            marginalsArea.append("\t");
        }

        // The updater reports -1 when the requested probability is unidentifiable.
        String probCell;
        String oddsCell;
        if (jointProb < 0.0) {
            probCell = "Unidentifiable";
            oddsCell = "*";
        } else {
            probCell = nf.format(jointProb);
            oddsCell = nf.format(Math.log(jointProb / (1. - jointProb)));
        }
        marginalsArea.append(probCell + "\t");
        marginalsArea.append(oddsCell);
    }
}
Also used : Node(edu.cmu.tetrad.graph.Node) DisplayNode(edu.cmu.tetradapp.workbench.DisplayNode) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 14 with BayesPm

use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.

From class Comparison, method compare.

/**
 * Obtains data and a true DAG, runs the configured search on the data, and records the result.
 *
 * <p>If {@code params.getDataFile()} is set, data are loaded from it and the true graph from
 * {@code params.getGraphFile()}. Otherwise a random forward-edge DAG is generated over
 * {@code params.getNumVars()} variables with {@code params.getNumEdges()} edges, and data of the
 * configured data type are simulated from it:
 * <ul>
 *   <li>continuous data via {@code LargeScaleSimulation.simulateDataFisher};</li>
 *   <li>discrete data via a randomly parameterized {@code MlBayesIm} with 3 categories per
 *       variable.</li>
 * </ul>
 *
 * <p>The configured independence test and/or score are then built. Each one rejects a
 * conflicting data type already recorded in {@code params}, and records the data type it implies.
 *
 * @param params comparison configuration; its data type may be updated as a side effect
 * @return the result graph, the correct (expected) graph, the true DAG and the elapsed search
 *         time in milliseconds
 * @throws IllegalArgumentException if a required setting is missing, inconsistent or unrecognized
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet;
    Graph trueDag;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);
    if (params.getDataFile() != null) {
        dataSet = loadDataFile(params.getDataFile());
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        }
        trueDag = loadGraphFile(params.getGraphFile());
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }
        // Validate the data type up front. Previously the null check lived inside branches that
        // had already matched a non-null type, so it could never fire.
        ComparisonParameters.DataType dataType = params.getDataType();
        if (dataType == null) {
            throw new IllegalArgumentException("Data type not set or inferred.");
        }
        if (dataType != ComparisonParameters.DataType.Continuous && dataType != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Unrecognized data type.");
        }
        if (params.getSampleSize() == -1) {
            throw new IllegalArgumentException("Sample size not set.");
        }
        boolean continuous = dataType == ComparisonParameters.DataType.Continuous;
        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < params.getNumVars(); i++) {
            String name = "X" + (i + 1);
            if (continuous) {
                nodes.add(new ContinuousVariable(name));
            } else {
                nodes.add(new DiscreteVariable(name, 3));
            }
        }
        trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
        if (continuous) {
            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            dataSet = sim.simulateDataFisher(params.getSampleSize());
        } else {
            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < nodes.size(); i++) {
                tiers[i] = i;
            }
            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        }
        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
    }
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        BDeuScore bdeuScore = new BDeuScore(dataSet);
        bdeuScore.setSamplePrior(params.getSamplePrior());
        bdeuScore.setStructurePrior(params.getStructurePrior());
        score = bdeuScore;
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }
    if (params.getAlgorithm() == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }
    long time1 = System.currentTimeMillis();
    if (params.getAlgorithm() == ComparisonParameters.Algorithm.PC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Pc search = new Pc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Cpc search = new Cpc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcLocal search = new PcLocal(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCStableMax) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcStableMax search = new PcStableMax(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) {
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        Fges search = new Fges(score);
        search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES2) {
        // NOTE(review): FGES2 currently runs exactly the same Fges search as FGES.
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        Fges search = new Fges(score);
        search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Fci search = new Fci(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.GFCI) {
        // GFci takes both a test and a score, so both must be configured.
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        GFci search = new GFci(test, score);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }
    long time2 = System.currentTimeMillis();
    long elapsed = time2 - time1;
    result.setElapsed(elapsed);
    result.setTrueDag(trueDag);
    return result;
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) ArrayList(java.util.ArrayList) List(java.util.List) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 15 with BayesPm

use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.

From class TestProposition, method sampleBayesIm2.

/**
 * Builds a small Bayes IM with fixed conditional probabilities for tests.
 *
 * <p>The DAG is a -> b, a -> c, b -> c. Node b has 3 categories. Nodes a and c keep the
 * {@code BayesPm} default category count, which the table below assumes is 2. Probabilities are
 * set from {@code cpts[node][row][category]}.
 * NOTE(review): this assumes IM node indices 0, 1, 2 are a, b, c (the order they are added to the
 * DAG), and that rows of c enumerate the combinations of its parents a and b (2 x 3 = 6 rows).
 * Confirm against MlBayesIm.
 *
 * @return the parameterized IM
 */
private BayesIm sampleBayesIm2() {
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Node c = new GraphNode("c");

    Dag dag = new Dag();
    dag.addNode(a);
    dag.addNode(b);
    dag.addNode(c);
    dag.addDirectedEdge(a, b);
    dag.addDirectedEdge(a, c);
    dag.addDirectedEdge(b, c);

    BayesPm pm = new BayesPm(dag);
    pm.setNumCategories(b, 3);

    double[][][] cpts = {
        // node 0: a single row with two categories
        {{.3, .7}},
        // node 1: two rows with three categories
        {{.3, .4, .3}, {.6, .1, .3}},
        // node 2: six rows with two categories
        {{.9, .1}, {.1, .9}, {.5, .5}, {.2, .8}, {.6, .4}, {.7, .3}}
    };

    BayesIm im = new MlBayesIm(pm);
    for (int nodeIndex = 0; nodeIndex < cpts.length; nodeIndex++) {
        double[][] table = cpts[nodeIndex];
        for (int row = 0; row < table.length; row++) {
            double[] dist = table[row];
            for (int cat = 0; cat < dist.length; cat++) {
                im.setProbability(nodeIndex, row, cat, dist[cat]);
            }
        }
    }
    return im;
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesIm(edu.cmu.tetrad.bayes.BayesIm) MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) GraphNode(edu.cmu.tetrad.graph.GraphNode) Node(edu.cmu.tetrad.graph.Node) GraphNode(edu.cmu.tetrad.graph.GraphNode) Dag(edu.cmu.tetrad.graph.Dag) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Aggregations

BayesPm (edu.cmu.tetrad.bayes.BayesPm)38 MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm)23 BayesIm (edu.cmu.tetrad.bayes.BayesIm)18 Test (org.junit.Test)17 Node (edu.cmu.tetrad.graph.Node)14 Graph (edu.cmu.tetrad.graph.Graph)10 DataSet (edu.cmu.tetrad.data.DataSet)6 Dag (edu.cmu.tetrad.graph.Dag)6 DisplayNode (edu.cmu.tetradapp.workbench.DisplayNode)6 ArrayList (java.util.ArrayList)6 List (java.util.List)5 BayesProperties (edu.cmu.tetrad.bayes.BayesProperties)4 LargeScaleSimulation (edu.cmu.tetrad.sem.LargeScaleSimulation)4 Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm)3 GraphNode (edu.cmu.tetrad.graph.GraphNode)3 Parameters (edu.cmu.tetrad.util.Parameters)3 GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest)3 NumberFormat (java.text.NumberFormat)3 RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph)2 ChiSquare (edu.cmu.tetrad.algcomparison.independence.ChiSquare)2