Example 6 with MlBayesIm

use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

the class TestGFci method testRandomDiscreteData.

@Test
public void testRandomDiscreteData() {
    int sampleSize = 1000;
    Graph g = GraphConverter.convert("X1-->X2,X1-->X3,X1-->X4,X2-->X3,X2-->X4,X3-->X4");
    Dag dag = new Dag(g);
    BayesPm bayesPm = new BayesPm(dag);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    DataSet data = bayesIm.simulateData(sampleSize, false);
    IndependenceTest test = new IndTestChiSquare(data, 0.05);
    BDeuScore bDeuScore = new BDeuScore(data);
    bDeuScore.setSamplePrior(1.0);
    bDeuScore.setStructurePrior(1.0);
    GFci gFci = new GFci(test, bDeuScore);
    gFci.setFaithfulnessAssumed(true);
    long start = System.currentTimeMillis();
    gFci.search();
    long stop = System.currentTimeMillis();
    System.out.println("Elapsed " + (stop - start) + " ms");
    DagToPag dagToPag = new DagToPag(g);
    dagToPag.setVerbose(false);
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesIm(edu.cmu.tetrad.bayes.BayesIm) BayesPm(edu.cmu.tetrad.bayes.BayesPm) Test(org.junit.Test)
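
For readers who want to see what "random parameters, then simulate" amounts to conceptually, here is a plain-JDK sketch of forward sampling on a two-node network X1 --> X2. It does not use Tetrad, and it is not a description of how MlBayesIm is implemented: the class ForwardSamplingSketch and its methods randomDistribution, sampleCategory and simulate are hypothetical names introduced only for this illustration.

```java
import java.util.Random;

// Hypothetical illustration only; not part of Tetrad.
class ForwardSamplingSketch {

    // Draws a random probability distribution over numCategories values.
    static double[] randomDistribution(int numCategories, Random rng) {
        double[] p = new double[numCategories];
        double sum = 0.0;
        for (int k = 0; k < numCategories; k++) {
            p[k] = rng.nextDouble() + 1e-6; // keep every weight strictly positive
            sum += p[k];
        }
        for (int k = 0; k < numCategories; k++) {
            p[k] /= sum; // normalize so the entries sum to 1
        }
        return p;
    }

    // Samples a category index from the distribution p by inverting its CDF.
    static int sampleCategory(double[] p, Random rng) {
        double u = rng.nextDouble();
        double cumulative = 0.0;
        for (int k = 0; k < p.length; k++) {
            cumulative += p[k];
            if (u < cumulative) {
                return k;
            }
        }
        return p.length - 1; // guard against floating-point rounding
    }

    // Simulates rows for X1 --> X2: the parent is sampled first, then the
    // child is sampled from the table row selected by the parent's value.
    static int[][] simulate(int sampleSize, int numCategories, long seed) {
        Random rng = new Random(seed);
        double[] pX1 = randomDistribution(numCategories, rng);
        double[][] pX2GivenX1 = new double[numCategories][];
        for (int v = 0; v < numCategories; v++) {
            pX2GivenX1[v] = randomDistribution(numCategories, rng);
        }
        int[][] data = new int[sampleSize][2];
        for (int i = 0; i < sampleSize; i++) {
            int x1 = sampleCategory(pX1, rng);
            data[i][0] = x1;
            data[i][1] = sampleCategory(pX2GivenX1[x1], rng);
        }
        return data;
    }
}
```

Calling ForwardSamplingSketch.simulate(1000, 3, 42L) returns 1000 rows of two category indices each, which mirrors the shape of the discrete data set the test above feeds to the independence test and the score.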

Example 7 with MlBayesIm

use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

the class BayesNetSimulation method simulate.

private DataSet simulate(Graph graph, Parameters parameters) {
    boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
    try {
        BayesIm im = this.im;
        if (im == null) {
            BayesPm pm = this.pm;
            if (pm == null) {
                int minCategories = parameters.getInt("minCategories");
                int maxCategories = parameters.getInt("maxCategories");
                pm = new BayesPm(graph, minCategories, maxCategories);
                im = new MlBayesIm(pm, MlBayesIm.RANDOM);
                ims.add(im);
                return im.simulateData(parameters.getInt("sampleSize"), saveLatentVars);
            } else {
                im = new MlBayesIm(pm, MlBayesIm.RANDOM);
                this.im = im;
                ims.add(im);
                return im.simulateData(parameters.getInt("sampleSize"), saveLatentVars);
            }
        } else {
            ims = new ArrayList<>();
            ims.add(im);
            return im.simulateData(parameters.getInt("sampleSize"), saveLatentVars);
        }
    } catch (Exception e) {
        // Keep the original exception as the cause so its stack trace is not lost.
        throw new IllegalArgumentException("Sorry, I couldn't simulate from that Bayes IM; perhaps not all of\nthe parameters have been specified.", e);
    }
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesIm(edu.cmu.tetrad.bayes.BayesIm) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 8 with MlBayesIm

use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

the class ADTreeTest method main.

public static void main(String[] args) throws Exception {
    int columns = 40;
    int numEdges = 40;
    int rows = 500;
    List<Node> variables = new ArrayList<>();
    List<String> varNames = new ArrayList<>();
    for (int i = 0; i < columns; i++) {
        final String name = "X" + (i + 1);
        varNames.add(name);
        variables.add(new ContinuousVariable(name));
    }
    Graph graph = GraphUtils.randomGraphRandomForwardEdges(variables, 0, numEdges, 30, 15, 15, false, true);
    BayesPm pm = new BayesPm(graph);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(rows, false);
    // This implementation uses a DataTable to represent the data
    // The first type parameter is the type for the variables
    // The second type parameter is the type for the values of the variables
    DataTableImpl<Node, Short> dataTable = new DataTableImpl<>(variables);
    for (int i = 0; i < rows; i++) {
        ArrayList<Short> intArray = new ArrayList<>();
        for (int j = 0; j < columns; j++) {
            intArray.add((short) data.getInt(i, j));
        }
        dataTable.addRow(intArray);
    }
    // create the tree
    long start = System.currentTimeMillis();
    ADTree<Node, Short> adTree = new ADTree<>(dataTable);
    System.out.println(String.format("Generated tree in %s millis", System.currentTimeMillis() - start));
    // the query is an arbitrary map of vars and their values
    TreeMap<Node, Short> query = new TreeMap<>();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    start = System.currentTimeMillis();
    System.out.println(String.format("Count is %d", adTree.count(query)));
    System.out.println(String.format("Query in %s ms", System.currentTimeMillis() - start));
    query.clear();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X2"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    query.put(node(pm, "X10"), (short) 1);
    start = System.currentTimeMillis();
    System.out.println(String.format("Count is %d", adTree.count(query)));
    System.out.println(String.format("Query in %s ms", System.currentTimeMillis() - start));
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) TreeMap(java.util.TreeMap) ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Graph(edu.cmu.tetrad.graph.Graph) BayesIm(edu.cmu.tetrad.bayes.BayesIm) BayesPm(edu.cmu.tetrad.bayes.BayesPm)
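
The count queries in main ask how many rows match every variable/value pair in the query map. As a reference for that semantics only, the sketch below answers the same question by brute force over the rows. It is not an AD-tree and does not use the ADTree or DataTableImpl classes from the example; CountQuerySketch and its count method are hypothetical names, and columns are addressed by integer index rather than by Node as a simplifying assumption.

```java
import java.util.List;
import java.util.Map;

// Hypothetical illustration only; a linear scan, not an AD-tree.
class CountQuerySketch {

    // Counts the rows in which, for every (column, value) entry of the query,
    // the row holds exactly that value in that column.
    static int count(List<short[]> rows, Map<Integer, Short> query) {
        int matches = 0;
        for (short[] row : rows) {
            boolean ok = true;
            for (Map.Entry<Integer, Short> e : query.entrySet()) {
                if (row[e.getKey()] != e.getValue()) {
                    ok = false;
                    break;
                }
            }
            if (ok) {
                matches++;
            }
        }
        return matches;
    }
}
```

A scan like this costs time proportional to the number of rows for every query, which is the cost that precomputing counts in a tree structure is meant to avoid. An empty query matches every row.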

Example 9 with MlBayesIm

use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

the class Comparison method compare.

/**
 * Simulates data from a model parameterizing the given DAG, and runs the algorithm on that data,
 * printing out error statistics.
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet;
    Graph trueDag;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);
    if (params.getDataFile() != null) {
        dataSet = loadDataFile(params.getDataFile());
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        }
        trueDag = loadGraphFile(params.getGraphFile());
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }
        if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new ContinuousVariable("X" + (i + 1)));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            if (params.getDataType() == null) {
                throw new IllegalArgumentException("Data type not set or inferred.");
            }
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            dataSet = sim.simulateDataFisher(params.getSampleSize());
        } else if (params.getDataType() == ComparisonParameters.DataType.Discrete) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new DiscreteVariable("X" + (i + 1), 3));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            if (params.getDataType() == null) {
                throw new IllegalArgumentException("Data type not set or inferred.");
            }
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < nodes.size(); i++) {
                tiers[i] = i;
            }
            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        } else {
            throw new IllegalArgumentException("Unrecognized data type.");
        }
        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
    }
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        score = new BDeuScore(dataSet);
        ((BDeuScore) score).setSamplePrior(params.getSamplePrior());
        ((BDeuScore) score).setStructurePrior(params.getStructurePrior());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }
    if (params.getAlgorithm() == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }
    long time1 = System.currentTimeMillis();
    if (params.getAlgorithm() == ComparisonParameters.Algorithm.PC) {
        if (test == null)
            throw new IllegalArgumentException("Test not set.");
        Pc search = new Pc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) {
        if (test == null)
            throw new IllegalArgumentException("Test not set.");
        Cpc search = new Cpc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) {
        if (test == null)
            throw new IllegalArgumentException("Test not set.");
        PcLocal search = new PcLocal(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCStableMax) {
        if (test == null)
            throw new IllegalArgumentException("Test not set.");
        PcStableMax search = new PcStableMax(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) {
        if (score == null)
            throw new IllegalArgumentException("Score not set.");
        Fges search = new Fges(score);
        search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES2) {
        if (score == null)
            throw new IllegalArgumentException("Score not set.");
        Fges search = new Fges(score);
        search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) {
        if (test == null)
            throw new IllegalArgumentException("Test not set.");
        Fci search = new Fci(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.GFCI) {
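        if (test == null)
            throw new IllegalArgumentException("Test not set.");
        // GFci is constructed from both a test and a score, so reject a missing score too.
        if (score == null)
            throw new IllegalArgumentException("Score not set.");
        GFci search = new GFci(test, score);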
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }
    long time2 = System.currentTimeMillis();
    long elapsed = time2 - time1;
    result.setElapsed(elapsed);
    result.setTrueDag(trueDag);
    return result;
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) List(java.util.List) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) BayesPm(edu.cmu.tetrad.bayes.BayesPm)
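
The Javadoc of compare mentions error statistics, but the listing does not show which ones are computed from the returned ComparisonResult. As one possibility, and purely as an assumption, the sketch below computes adjacency precision and recall between a true graph and an estimated graph, each given as a set of undirected adjacencies. AdjacencyStatsSketch and its methods are hypothetical names and do not correspond to any Tetrad class.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical illustration only; not part of Tetrad.
class AdjacencyStatsSketch {

    // Encodes an undirected adjacency so that (a, b) and (b, a) compare equal.
    static String adjacency(String a, String b) {
        return a.compareTo(b) <= 0 ? a + "-" + b : b + "-" + a;
    }

    // Fraction of estimated adjacencies that also appear in the true graph.
    static double precision(Set<String> truth, Set<String> estimated) {
        if (estimated.isEmpty()) {
            return Double.NaN; // undefined when nothing was estimated
        }
        Set<String> hits = new HashSet<>(estimated);
        hits.retainAll(truth);
        return (double) hits.size() / estimated.size();
    }

    // Fraction of true adjacencies that were recovered in the estimate.
    static double recall(Set<String> truth, Set<String> estimated) {
        if (truth.isEmpty()) {
            return Double.NaN; // undefined when the true graph has no edges
        }
        Set<String> hits = new HashSet<>(truth);
        hits.retainAll(estimated);
        return (double) hits.size() / truth.size();
    }
}
```

Both methods ignore edge orientation, so they measure only whether the right pairs of variables were connected.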

Example 10 with MlBayesIm

use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.

the class TestEvidence method testUpdate1.

/**
 * Richard's 2-variable example worked by hand.
 */
@Test
public void testUpdate1() {
    BayesIm bayesIm = sampleBayesIm2();
    Evidence evidence = Evidence.tautology(bayesIm);
    evidence.getProposition().removeCategory(0, 1);
    evidence.getProposition().setVariable(1, false);
    evidence.setManipulated(0, true);
    Evidence evidence2 = new Evidence(evidence, bayesIm);
    assertEquals(evidence2, evidence);
    assertEquals(evidence, new Evidence(evidence));
    BayesIm bayesIm2 = new MlBayesIm(bayesIm);
    Evidence evidence3 = new Evidence(evidence, bayesIm2);
    assertFalse(evidence3.equals(evidence2));
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) BayesIm(edu.cmu.tetrad.bayes.BayesIm) Evidence(edu.cmu.tetrad.bayes.Evidence) Test(org.junit.Test)

Aggregations

MlBayesIm (edu.cmu.tetrad.bayes.MlBayesIm): 26
BayesPm (edu.cmu.tetrad.bayes.BayesPm): 23
BayesIm (edu.cmu.tetrad.bayes.BayesIm): 20
Test (org.junit.Test): 15
Graph (edu.cmu.tetrad.graph.Graph): 9
Node (edu.cmu.tetrad.graph.Node): 8
DataSet (edu.cmu.tetrad.data.DataSet): 6
Dag (edu.cmu.tetrad.graph.Dag): 6
ArrayList (java.util.ArrayList): 6
LargeScaleSimulation (edu.cmu.tetrad.sem.LargeScaleSimulation): 4
Algorithm (edu.cmu.tetrad.algcomparison.algorithm.Algorithm): 3
GraphNode (edu.cmu.tetrad.graph.GraphNode): 3
Parameters (edu.cmu.tetrad.util.Parameters): 3
GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest): 3
RandomGraph (edu.cmu.tetrad.algcomparison.graph.RandomGraph): 2
ChiSquare (edu.cmu.tetrad.algcomparison.independence.ChiSquare): 2
IndependenceWrapper (edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper): 2
BdeuScore (edu.cmu.tetrad.algcomparison.score.BdeuScore): 2
ScoreWrapper (edu.cmu.tetrad.algcomparison.score.ScoreWrapper): 2
ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 2