Use of edu.cmu.tetrad.bayes.MlBayesIm in the tetrad project (cmu-phil).
Example from class TestGFci, method testRandomDiscreteData.
/**
 * Simulates 1000 rows of data from a randomly parameterized Bayes IM over a complete
 * four-variable DAG, runs GFCI (chi-square test + BDeu score, faithfulness assumed) on it,
 * and checks that the resulting PAG is defined over the same variables as the PAG obtained
 * from the true DAG.
 *
 * <p>The IM parameters are random, so only the variable set is compared, not the edges.
 */
@Test
public void testRandomDiscreteData() {
    int sampleSize = 1000;

    Graph g = GraphConverter.convert("X1-->X2,X1-->X3,X1-->X4,X2-->X3,X2-->X4,X3-->X4");
    Dag dag = new Dag(g);
    BayesPm bayesPm = new BayesPm(dag);
    BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
    DataSet data = bayesIm.simulateData(sampleSize, false);

    IndependenceTest test = new IndTestChiSquare(data, 0.05);
    BDeuScore bDeuScore = new BDeuScore(data);
    bDeuScore.setSamplePrior(1.0);
    bDeuScore.setStructurePrior(1.0);

    GFci gFci = new GFci(test, bDeuScore);
    gFci.setFaithfulnessAssumed(true);

    long start = System.currentTimeMillis();
    Graph pag = gFci.search();
    long stop = System.currentTimeMillis();
    System.out.println("Elapsed " + (stop - start) + " ms");

    // Convert the true DAG so the search output has something to be checked against;
    // without this the test could only fail by throwing.
    DagToPag dagToPag = new DagToPag(g);
    dagToPag.setVerbose(false);
    Graph truePag = dagToPag.convert();

    // Nodes may be different objects (data variables vs. graph nodes), so compare by name.
    if (!pag.getNodeNames().containsAll(truePag.getNodeNames())
            || !truePag.getNodeNames().containsAll(pag.getNodeNames())) {
        throw new AssertionError("GFCI output variables " + pag.getNodeNames()
                + " differ from true PAG variables " + truePag.getNodeNames());
    }
}
Use of edu.cmu.tetrad.bayes.MlBayesIm in the tetrad project (cmu-phil).
Example from class BayesNetSimulation, method simulate.
/**
 * Simulates one data set from a Bayes IM chosen as follows.
 *
 * <ul>
 *   <li>If an IM is already held in {@code this.im}, it is reused and {@code ims} is reset to a
 *       new list containing only that IM.</li>
 *   <li>Otherwise, if a PM is held in {@code this.pm}, a random {@link MlBayesIm} is built over
 *       it, stored in {@code this.im}, and appended to {@code ims}.</li>
 *   <li>Otherwise a new PM is built over {@code graph} using the "minCategories" and
 *       "maxCategories" parameters, and a random IM over it is appended to {@code ims}.
 *       This IM is not stored in {@code this.im}.</li>
 * </ul>
 *
 * @param graph      the graph to build a PM over when neither a PM nor an IM is held
 * @param parameters supplies "saveLatentVars" and "sampleSize", plus "minCategories" and
 *                   "maxCategories" when a PM has to be built
 * @return the simulated data set
 * @throws IllegalArgumentException if building the model or simulating fails; the underlying
 *                                  failure is attached as the cause
 */
private DataSet simulate(Graph graph, Parameters parameters) {
    boolean saveLatentVars = parameters.getBoolean("saveLatentVars");

    try {
        BayesIm im = this.im;
        BayesPm pm = this.pm;

        if (im != null) {
            ims = new ArrayList<>();
            ims.add(im);
        } else if (pm != null) {
            im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            this.im = im;
            ims.add(im);
        } else {
            int minCategories = parameters.getInt("minCategories");
            int maxCategories = parameters.getInt("maxCategories");
            BayesPm freshPm = new BayesPm(graph, minCategories, maxCategories);
            im = new MlBayesIm(freshPm, MlBayesIm.RANDOM);
            ims.add(im);
        }

        return im.simulateData(parameters.getInt("sampleSize"), saveLatentVars);
    } catch (Exception e) {
        // Chain the cause instead of printing it and discarding it, so callers can see
        // which parameter or construction step actually failed.
        throw new IllegalArgumentException("Sorry, I couldn't simulate from that Bayes IM; perhaps not all of\n"
                + "the parameters have been specified.", e);
    }
}
Use of edu.cmu.tetrad.bayes.MlBayesIm in the tetrad project (cmu-phil).
Example from class ADTreeTest, method main.
/**
 * Builds an AD-tree over data simulated from a randomly parameterized Bayes IM.
 *
 * <p>The IM is defined over a random 40-variable graph with 40 edges. The method then prints
 * the tree build time, followed by the counts and timings of two conjunctive count queries.
 *
 * @param args ignored
 * @throws Exception if any step of graph generation, simulation, or tree construction fails
 */
public static void main(String[] args) throws Exception {
    int columns = 40;
    int numEdges = 40;
    int rows = 500;

    List<Node> variables = new ArrayList<>(columns);
    for (int i = 0; i < columns; i++) {
        variables.add(new ContinuousVariable("X" + (i + 1)));
    }

    Graph graph = GraphUtils.randomGraphRandomForwardEdges(variables, 0, numEdges, 30, 15, 15, false, true);
    BayesPm pm = new BayesPm(graph);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(rows, false);

    // The DataTable's first type parameter is the variable type; the second is the value type.
    DataTableImpl<Node, Short> dataTable = new DataTableImpl<>(variables);
    for (int i = 0; i < rows; i++) {
        ArrayList<Short> row = new ArrayList<>(columns);
        for (int j = 0; j < columns; j++) {
            // NOTE(review): assumes data column j corresponds to variables.get(j) — confirm
            // that simulateData preserves the graph's variable order.
            row.add((short) data.getInt(i, j));
        }
        dataTable.addRow(row);
    }

    long start = System.currentTimeMillis();
    ADTree<Node, Short> adTree = new ADTree<>(dataTable);
    System.out.println(String.format("Generated tree in %s millis", System.currentTimeMillis() - start));

    // A query is an arbitrary map from variables to the values they must take.
    TreeMap<Node, Short> query = new TreeMap<>();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    printTimedQueryCount(adTree, query);

    query.clear();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X2"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    query.put(node(pm, "X10"), (short) 1);
    printTimedQueryCount(adTree, query);
}

/** Runs one count query against the tree, printing the count and the query time in ms. */
private static void printTimedQueryCount(ADTree<Node, Short> adTree, TreeMap<Node, Short> query) {
    long start = System.currentTimeMillis();
    System.out.println(String.format("Count is %d", adTree.count(query)));
    System.out.println(String.format("Query in %s ms", System.currentTimeMillis() - start));
}
Use of edu.cmu.tetrad.bayes.MlBayesIm in the tetrad project (cmu-phil).
Example from class Comparison, method compare.
/**
 * Obtains data and a true DAG, runs the configured search algorithm on the data, and packages
 * the estimated graph with the correct graph for comparison.
 *
 * <p>If a data file is configured, the data and the true graph are loaded from files.
 * Otherwise a random DAG is generated and data are simulated from it:
 * <ul>
 *   <li>a linear (Fisher) simulation for continuous data;</li>
 *   <li>a random 3-category Bayes IM for discrete data.</li>
 * </ul>
 *
 * @param params the comparison settings. The data type stored in {@code params} is updated
 *               to match the chosen independence test or score.
 * @return the result. It holds the estimated graph and the correct graph: the pattern of the
 *         true DAG for PC-style and FGES searches, or its PAG for FCI/GFCI. It also holds the
 *         true DAG and the search time in milliseconds.
 * @throws IllegalArgumentException if a required setting is missing, or if settings disagree
 *                                  about the data type
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet;
    Graph trueDag;
    IndependenceTest test = null;
    Score score = null;

    ComparisonResult result = new ComparisonResult(params);

    if (params.getDataFile() != null) {
        dataSet = loadDataFile(params.getDataFile());

        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        }

        trueDag = loadGraphFile(params.getGraphFile());
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }

        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }

        // Checked before branching on the type. Inside the typed branches a null type could
        // never be seen, so it was misreported as "Unrecognized data type.".
        if (params.getDataType() == null) {
            throw new IllegalArgumentException("Data type not set or inferred.");
        }

        if (params.getSampleSize() == -1) {
            throw new IllegalArgumentException("Sample size not set.");
        }

        if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
            List<Node> nodes = new ArrayList<>(params.getNumVars());

            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new ContinuousVariable("X" + (i + 1)));
            }

            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);

            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            dataSet = sim.simulateDataFisher(params.getSampleSize());
        } else if (params.getDataType() == ComparisonParameters.DataType.Discrete) {
            List<Node> nodes = new ArrayList<>(params.getNumVars());

            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new DiscreteVariable("X" + (i + 1), 3));
            }

            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);

            // One tier per node, in node order.
            int[] tiers = new int[nodes.size()];

            for (int i = 0; i < nodes.size(); i++) {
                tiers[i] = i;
            }

            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        } else {
            throw new IllegalArgumentException("Unrecognized data type.");
        }
    }

    // Applies to loaded and simulated data alike.
    if (dataSet == null) {
        throw new IllegalArgumentException("No data set.");
    }

    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }

        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }

        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }

        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }

        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }

        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }

        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }

        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }

        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }

        BDeuScore bDeuScore = new BDeuScore(dataSet);
        bDeuScore.setSamplePrior(params.getSamplePrior());
        bDeuScore.setStructurePrior(params.getStructurePrior());
        score = bDeuScore;
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    if (params.getAlgorithm() == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }

    long time1 = System.currentTimeMillis();

    if (params.getAlgorithm() == ComparisonParameters.Algorithm.PC) {
        if (test == null) throw new IllegalArgumentException("Test not set.");
        Pc search = new Pc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) {
        if (test == null) throw new IllegalArgumentException("Test not set.");
        Cpc search = new Cpc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) {
        if (test == null) throw new IllegalArgumentException("Test not set.");
        PcLocal search = new PcLocal(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCStableMax) {
        if (test == null) throw new IllegalArgumentException("Test not set.");
        PcStableMax search = new PcStableMax(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) {
        if (score == null) throw new IllegalArgumentException("Score not set.");
        Fges search = new Fges(score);
        search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES2) {
        // FGES2 currently runs the same Fges search as FGES.
        if (score == null) throw new IllegalArgumentException("Score not set.");
        Fges search = new Fges(score);
        search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) {
        if (test == null) throw new IllegalArgumentException("Test not set.");
        Fci search = new Fci(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.GFCI) {
        // GFCI needs both a test and a score; previously a missing score reached GFci as null.
        if (test == null) throw new IllegalArgumentException("Test not set.");
        if (score == null) throw new IllegalArgumentException("Score not set.");
        GFci search = new GFci(test, score);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }

    long time2 = System.currentTimeMillis();
    long elapsed = time2 - time1;

    result.setElapsed(elapsed);
    result.setTrueDag(trueDag);

    return result;
}
Use of edu.cmu.tetrad.bayes.MlBayesIm in the tetrad project (cmu-phil).
Example from class TestEvidence, method testUpdate1.
/**
 * Richard's hand-worked two-variable example.
 *
 * <p>Builds a manipulated evidence object over the sample IM. It then checks three things:
 * <ul>
 *   <li>re-creating the evidence against the same IM yields an equal object;</li>
 *   <li>the single-argument copy constructor yields an equal object;</li>
 *   <li>re-creating it against a new {@code MlBayesIm} constructed from the IM does not
 *       compare equal.</li>
 * </ul>
 */
@Test
public void testUpdate1() {
    BayesIm sourceIm = sampleBayesIm2();

    Evidence base = Evidence.tautology(sourceIm);
    base.getProposition().removeCategory(0, 1);
    base.getProposition().setVariable(1, false);
    base.setManipulated(0, true);

    Evidence rebuiltOnSameIm = new Evidence(base, sourceIm);
    assertEquals(rebuiltOnSameIm, base);

    Evidence plainCopy = new Evidence(base);
    assertEquals(base, plainCopy);

    BayesIm copiedIm = new MlBayesIm(sourceIm);
    Evidence rebuiltOnCopiedIm = new Evidence(base, copiedIm);
    boolean sameAsOriginalIm = rebuiltOnCopiedIm.equals(rebuiltOnSameIm);
    assertTrue(!sameAsOriginalIm);
}
Aggregations