use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
the class BayesNetSimulation method simulate.
/**
 * Simulates one data set from a Bayes IM, building the IM first if none is held yet.
 *
 * <p>Three cases are distinguished:
 * <ul>
 *   <li>{@code this.im} is set: the {@code ims} list is replaced by a fresh list holding
 *       just that IM.</li>
 *   <li>only {@code this.pm} is set: a random IM is built from it, stored in
 *       {@code this.im}, and appended to {@code ims}.</li>
 *   <li>neither is set: a PM is built from {@code graph} using the "minCategories" and
 *       "maxCategories" parameters, a random IM is built from it and appended to
 *       {@code ims}; neither is stored on the instance.</li>
 * </ul>
 *
 * @param graph      the graph used to build a new PM when neither an IM nor a PM is held
 * @param parameters supplies "saveLatentVars", "sampleSize" and, when a PM must be built,
 *                   "minCategories" and "maxCategories"
 * @return the simulated data set
 * @throws IllegalArgumentException if anything fails while building the model or
 *                                  simulating; the original exception is attached as the cause
 */
private DataSet simulate(Graph graph, Parameters parameters) {
    boolean saveLatentVars = parameters.getBoolean("saveLatentVars");

    try {
        BayesIm im = this.im;

        if (im != null) {
            // A fixed IM is in use: start the IM list over with just that IM.
            ims = new ArrayList<>();
        } else if (this.pm != null) {
            im = new MlBayesIm(this.pm, MlBayesIm.RANDOM);
            this.im = im;
        } else {
            int minCategories = parameters.getInt("minCategories");
            int maxCategories = parameters.getInt("maxCategories");
            BayesPm pm = new BayesPm(graph, minCategories, maxCategories);
            im = new MlBayesIm(pm, MlBayesIm.RANDOM);
        }

        ims.add(im);
        return im.simulateData(parameters.getInt("sampleSize"), saveLatentVars);
    } catch (Exception e) {
        // Chain the cause instead of printing it, so the caller decides how to report it.
        throw new IllegalArgumentException("Sorry, I couldn't simulate from that Bayes IM; perhaps not all of\n"
                + "the parameters have been specified.", e);
    }
}
use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
the class ADTreeTest method main.
/**
 * Demo driver: simulates data from a random Bayes IM over 40 variables, copies it into a
 * {@code DataTableImpl}, builds an {@code ADTree} over that table and runs two count
 * queries, printing the counts and the elapsed times to standard output.
 *
 * @param args ignored
 * @throws Exception declared for any failure raised by the calls below
 */
public static void main(String[] args) throws Exception {
    int columns = 40;
    int numEdges = 40;
    int rows = 500;

    List<Node> variables = new ArrayList<>(columns);
    for (int i = 0; i < columns; i++) {
        variables.add(new ContinuousVariable("X" + (i + 1)));
    }

    Graph graph = GraphUtils.randomGraphRandomForwardEdges(variables, 0, numEdges, 30, 15, 15, false, true);
    BayesPm pm = new BayesPm(graph);
    BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
    DataSet data = im.simulateData(rows, false);

    // This implementation uses a DataTable to represent the data
    // The first type parameter is the type for the variables
    // The second type parameter is the type for the values of the variables
    DataTableImpl<Node, Short> dataTable = new DataTableImpl<>(variables);
    for (int i = 0; i < rows; i++) {
        ArrayList<Short> row = new ArrayList<>(columns);
        for (int j = 0; j < columns; j++) {
            row.add((short) data.getInt(i, j));
        }
        dataTable.addRow(row);
    }

    // create the tree
    long start = System.currentTimeMillis();
    ADTree<Node, Short> adTree = new ADTree<>(dataTable);
    System.out.println(String.format("Generated tree in %s millis", System.currentTimeMillis() - start));

    // the query is an arbitrary map of vars and their values
    TreeMap<Node, Short> query = new TreeMap<>();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    start = System.currentTimeMillis();
    System.out.println(String.format("Count is %d", adTree.count(query)));
    System.out.println(String.format("Query in %s ms", System.currentTimeMillis() - start));

    // second query: reuse the map with a larger set of constraints
    query.clear();
    query.put(node(pm, "X1"), (short) 1);
    query.put(node(pm, "X2"), (short) 1);
    query.put(node(pm, "X5"), (short) 0);
    query.put(node(pm, "X10"), (short) 1);
    start = System.currentTimeMillis();
    System.out.println(String.format("Count is %d", adTree.count(query)));
    System.out.println(String.format("Query in %s ms", System.currentTimeMillis() - start));
}
use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
the class EvidenceWizardMultipleObs method appendJoint.
/**
 * Appends to {@code marginalsArea} a tab-separated table with one row per combination of
 * categories of the selected nodes, giving the joint marginal reported by the current
 * Bayes updater and its log odds. If the updater reports that joint marginals are not
 * supported, a one-line notice is appended instead.
 *
 * @param selectedNodes the nodes whose joint distribution is tabulated
 * @param marginalsArea the text area that receives the output
 * @param manipulatedIm the IM supplying the node indices and, through its PM, the categories
 * @param nf            formatter applied to the probability and the log odds
 */
private void appendJoint(List<Node> selectedNodes, JTextArea marginalsArea, BayesIm manipulatedIm, NumberFormat nf) {
    boolean jointSupported = getUpdaterWrapper().getBayesUpdater().isJointMarginalSupported();
    if (!jointSupported) {
        marginalsArea.append("\n\n(Calculation of joint not supported for this updater.)");
        return;
    }

    final BayesPm pm = manipulatedIm.getBayesPm();
    final int count = selectedNodes.size();
    final int[] categoryCounts = new int[count];
    final int[] nodeIndices = new int[count];

    // Total number of table rows is the product of the category counts.
    int combinations = 1;
    for (int k = 0; k < count; k++) {
        Node selected = selectedNodes.get(k);
        int categories = pm.getNumCategories(selected);
        nodeIndices[k] = manipulatedIm.getNodeIndex(selected);
        categoryCounts[k] = categories;
        combinations *= categories;
    }

    // Header line: one column per selected node, then the two value columns.
    marginalsArea.append("\n\nJOINT OVER SELECTED VARIABLES:\n\n");
    for (Node selected : selectedNodes) {
        marginalsArea.append(selected + "\t");
    }
    marginalsArea.append("Joint\tLog odds\n");

    for (int row = 0; row < combinations; row++) {
        int[] combo = getCategories(row, categoryCounts);
        double jointProb = getUpdaterWrapper().getBayesUpdater().getJointMarginal(nodeIndices, combo);

        marginalsArea.append("\n");
        for (int k = 0; k < count; k++) {
            marginalsArea.append(pm.getCategory(selectedNodes.get(k), combo[k]));
            marginalsArea.append("\t");
        }

        // identifiability returns -1 if the requested prob is unidentifiable
        if (jointProb < 0.0) {
            marginalsArea.append("Unidentifiable\t");
            marginalsArea.append("*");
            continue;
        }

        double logOdds = Math.log(jointProb / (1. - jointProb));
        marginalsArea.append(nf.format(jointProb) + "\t");
        marginalsArea.append(nf.format(logOdds));
    }
}
use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
the class Comparison method compare.
/**
 * Obtains a data set and its true DAG (loaded from files, or simulated from a model
 * parameterizing a random DAG), runs the configured search algorithm on the data, and
 * returns the learned graph, the expected graph, the true DAG and the elapsed search time.
 *
 * <p>As a side effect, {@code params}' data type is set to match whichever independence
 * test or score is selected.
 *
 * @param params the comparison settings: data source or simulation sizes, data type,
 *               independence test, score, and algorithm
 * @return the populated comparison result
 * @throws IllegalArgumentException if a required setting is missing, settings conflict
 *                                  (e.g. a continuous test with discrete data), or the
 *                                  data type or algorithm is not recognized
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet;
    Graph trueDag;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);

    // --- Step 1: obtain the data and the true DAG. ---
    if (params.getDataFile() != null) {
        dataSet = loadDataFile(params.getDataFile());
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        }
        trueDag = loadGraphFile(params.getGraphFile());
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }

        if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new ContinuousVariable("X" + (i + 1)));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            dataSet = sim.simulateDataFisher(params.getSampleSize());
        } else if (params.getDataType() == ComparisonParameters.DataType.Discrete) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new DiscreteVariable("X" + (i + 1), 3));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            // Tier i holds index i, i.e. the identity ordering over the nodes.
            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < nodes.size(); i++) {
                tiers[i] = i;
            }
            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        } else {
            throw new IllegalArgumentException("Unrecognized data type.");
        }

        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
    }

    // --- Step 2: build the independence test, if one is selected. ---
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    // --- Step 3: build the score, if one is selected. ---
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        BDeuScore bdeuScore = new BDeuScore(dataSet);
        bdeuScore.setSamplePrior(params.getSamplePrior());
        bdeuScore.setStructurePrior(params.getStructurePrior());
        score = bdeuScore;
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    // --- Step 4: run the selected algorithm and record the expected graph. ---
    ComparisonParameters.Algorithm algorithm = params.getAlgorithm();
    if (algorithm == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }

    long time1 = System.currentTimeMillis();

    if (algorithm == ComparisonParameters.Algorithm.PC) {
        Pc search = new Pc(requireTest(test));
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (algorithm == ComparisonParameters.Algorithm.CPC) {
        Cpc search = new Cpc(requireTest(test));
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (algorithm == ComparisonParameters.Algorithm.PCLocal) {
        PcLocal search = new PcLocal(requireTest(test));
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (algorithm == ComparisonParameters.Algorithm.PCStableMax) {
        PcStableMax search = new PcStableMax(requireTest(test));
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (algorithm == ComparisonParameters.Algorithm.FGES
            || algorithm == ComparisonParameters.Algorithm.FGES2) {
        // FGES and FGES2 are configured and run identically here.
        Fges search = new Fges(requireScore(score));
        search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (algorithm == ComparisonParameters.Algorithm.FCI) {
        Fci search = new Fci(requireTest(test));
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (algorithm == ComparisonParameters.Algorithm.GFCI) {
        // GFCI is given both a test and a score, so both are validated up front.
        GFci search = new GFci(requireTest(test), requireScore(score));
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }

    long time2 = System.currentTimeMillis();
    result.setElapsed(time2 - time1);
    result.setTrueDag(trueDag);
    return result;
}

/** Returns {@code test}, or throws {@link IllegalArgumentException} if no test was configured. */
private static IndependenceTest requireTest(IndependenceTest test) {
    if (test == null) {
        throw new IllegalArgumentException("Test not set.");
    }
    return test;
}

/** Returns {@code score}, or throws {@link IllegalArgumentException} if no score was configured. */
private static Score requireScore(Score score) {
    if (score == null) {
        throw new IllegalArgumentException("Score not set.");
    }
    return score;
}
use of edu.cmu.tetrad.bayes.BayesPm in project tetrad by cmu-phil.
the class TestProposition method sampleBayesIm2.
/**
 * Builds a fixed three-node Bayes IM over the DAG a -> b, a -> c, b -> c, where b is given
 * three categories, and fills in every conditional probability by hand.
 *
 * @return the fully specified Bayes IM
 */
private BayesIm sampleBayesIm2() {
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Node c = new GraphNode("c");

    Dag dag = new Dag();
    dag.addNode(a);
    dag.addNode(b);
    dag.addNode(c);
    dag.addDirectedEdge(a, b);
    dag.addDirectedEdge(a, c);
    dag.addDirectedEdge(b, c);

    BayesPm pm = new BayesPm(dag);
    pm.setNumCategories(b, 3);

    BayesIm im = new MlBayesIm(pm);

    // Indexed as probabilities[nodeIndex][rowIndex][columnIndex].
    double[][][] probabilities = {
        {
            {.3, .7},
        },
        {
            {.3, .4, .3},
            {.6, .1, .3},
        },
        {
            {.9, .1},
            {.1, .9},
            {.5, .5},
            {.2, .8},
            {.6, .4},
            {.7, .3},
        },
    };

    for (int nodeIndex = 0; nodeIndex < probabilities.length; nodeIndex++) {
        double[][] rows = probabilities[nodeIndex];
        for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) {
            double[] columns = rows[rowIndex];
            for (int colIndex = 0; colIndex < columns.length; colIndex++) {
                im.setProbability(nodeIndex, rowIndex, colIndex, columns[colIndex]);
            }
        }
    }

    return im;
}
Aggregations