use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
the class TestProposition method sampleBayesIm2.
/**
 * Builds a fixed, hand-parameterised Bayes IM over the three-node DAG
 * a -> b, a -> c, b -> c, in which b is given three categories.
 *
 * @return the Bayes IM with every conditional probability set explicitly.
 */
private BayesIm sampleBayesIm2() {
    Node a = new GraphNode("a");
    Node b = new GraphNode("b");
    Node c = new GraphNode("c");

    Dag graph = new Dag();
    for (Node node : new Node[]{a, b, c}) {
        graph.addNode(node);
    }
    graph.addDirectedEdge(a, b);
    graph.addDirectedEdge(a, c);
    graph.addDirectedEdge(b, c);

    BayesPm bayesPm = new BayesPm(graph);
    bayesPm.setNumCategories(b, 3);

    BayesIm im = new MlBayesIm(bayesPm);

    // Indexed as tables[nodeIndex][rowIndex][columnIndex]; each row sums to 1.
    // NOTE(review): node indices 0, 1, 2 presumably correspond to a, b, c — confirm against MlBayesIm.
    double[][][] tables = {
            {{.3, .7}},
            {{.3, .4, .3}, {.6, .1, .3}},
            {{.9, .1}, {.1, .9}, {.5, .5}, {.2, .8}, {.6, .4}, {.7, .3}},
    };

    for (int nodeIndex = 0; nodeIndex < tables.length; nodeIndex++) {
        double[][] rows = tables[nodeIndex];
        for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) {
            double[] row = rows[rowIndex];
            for (int colIndex = 0; colIndex < row.length; colIndex++) {
                im.setProbability(nodeIndex, rowIndex, colIndex, row[colIndex]);
            }
        }
    }

    return im;
}
use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
the class ConditionalGaussianSimulation method simulate.
/**
 * Simulates one mixed (discrete + continuous) data set for the graph G.
 * <p>
 * Steps, as implemented below: (1) assign each node name a category count
 * (or 0) and rebuild G via makeMixedGraph; (2) split the nodes into continuous
 * and non-continuous groups and build a random Bayes IM over the latter and a
 * SEM PM over the former; (3) walk the variables in causal order and fill in
 * every row of the data set for that variable.
 *
 * @param G          the graph to simulate from; reassigned to the result of makeMixedGraph.
 * @param parameters read for "percentDiscrete", "minCategories", "maxCategories",
 *                   "sampleSize" and "saveLatentVars".
 * @return the simulated data, restricted to measured variables unless "saveLatentVars" is true.
 */
private DataSet simulate(Graph G, Parameters parameters) {
// Node name -> value handed to makeMixedGraph: a picked category count, or 0.
// NOTE(review): 0 presumably marks the node as continuous — confirm in makeMixedGraph.
HashMap<String, Integer> nd = new HashMap<>();
List<Node> nodes = G.getNodes();
Collections.shuffle(nodes);
// The shuffled order is created once and then reused by later calls on this instance.
if (this.shuffledOrder == null) {
List<Node> shuffledNodes = new ArrayList<>(nodes);
Collections.shuffle(shuffledNodes);
this.shuffledOrder = shuffledNodes;
}
// The first percentDiscrete percent of positions in shuffledOrder get a category count; the rest get 0.
for (int i = 0; i < nodes.size(); i++) {
if (i < nodes.size() * parameters.getDouble("percentDiscrete") * 0.01) {
final int minNumCategories = parameters.getInt("minCategories");
final int maxNumCategories = parameters.getInt("maxCategories");
final int value = pickNumCategories(minNumCategories, maxNumCategories);
nd.put(shuffledOrder.get(i).getName(), value);
} else {
nd.put(shuffledOrder.get(i).getName(), 0);
}
}
G = makeMixedGraph(G, nd);
nodes = G.getNodes();
// One row per sample, one column per node of the rebuilt graph.
DataSet mixedData = new BoxDataSet(new MixedDataBox(nodes, parameters.getInt("sampleSize")), nodes);
// X collects the ContinuousVariable nodes, A collects everything else.
List<Node> X = new ArrayList<>();
List<Node> A = new ArrayList<>();
for (Node node : G.getNodes()) {
if (node instanceof ContinuousVariable) {
X.add(node);
} else {
A.add(node);
}
}
Graph AG = G.subgraph(A);
Graph XG = G.subgraph(X);
// Each continuous parent of a node in A gets one discrete stand-in ("Ersatz_<name>") in AG.
// The forward map avoids creating the stand-in twice; the reverse map (keyed by stand-in name)
// is used during sampling to find the original continuous column.
Map<ContinuousVariable, DiscreteVariable> erstatzNodes = new HashMap<>();
Map<String, ContinuousVariable> erstatzNodesReverse = new HashMap<>();
for (Node y : A) {
for (Node x : G.getParents(y)) {
if (x instanceof ContinuousVariable) {
DiscreteVariable ersatz = erstatzNodes.get(x);
if (ersatz == null) {
// Category count is nextInt(3) + 2 (presumably 2..4 — confirm RandomUtil.nextInt bounds).
ersatz = new DiscreteVariable("Ersatz_" + x.getName(), RandomUtil.getInstance().nextInt(3) + 2);
erstatzNodes.put((ContinuousVariable) x, ersatz);
erstatzNodesReverse.put(ersatz.getName(), (ContinuousVariable) x);
AG.addNode(ersatz);
}
AG.addDirectedEdge(ersatz, y);
}
}
}
BayesPm bayesPm = new BayesPm(AG);
BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
SemPm semPm = new SemPm(XG);
// Cache passed to getParamValue for every Combination lookup below.
Map<Combination, Double> paramValues = new HashMap<>();
// tiers[t] is the data-set column of the t-th node in G's causal ordering.
List<Node> tierOrdering = G.getCausalOrdering();
int[] tiers = new int[tierOrdering.size()];
for (int t = 0; t < tierOrdering.size(); t++) {
tiers[t] = nodes.indexOf(tierOrdering.get(t));
}
// Breakpoints per continuous parent column, computed on first use and then reused.
Map<Integer, double[]> breakpointsMap = new HashMap<>();
// Outer loop: variables in causal order. Inner loop: every row for that variable, so a
// variable's column is completely written before any later variable reads it.
for (int mixedIndex : tiers) {
for (int i = 0; i < parameters.getInt("sampleSize"); i++) {
if (nodes.get(mixedIndex) instanceof DiscreteVariable) {
// Discrete variable: look up its parents' values for row i, then sample from the matching IM row.
int bayesIndex = bayesIm.getNodeIndex(nodes.get(mixedIndex));
int[] bayesParents = bayesIm.getParents(bayesIndex);
int[] parentValues = new int[bayesParents.length];
for (int k = 0; k < parentValues.length; k++) {
int bayesParentColumn = bayesParents[k];
Node bayesParent = bayesIm.getVariables().get(bayesParentColumn);
DiscreteVariable _parent = (DiscreteVariable) bayesParent;
int value;
// Non-null only when this parent is an ersatz stand-in for a continuous variable.
ContinuousVariable orig = erstatzNodesReverse.get(_parent.getName());
if (orig != null) {
int mixedParentColumn = mixedData.getColumn(orig);
double d = mixedData.getDouble(i, mixedParentColumn);
double[] breakpoints = breakpointsMap.get(mixedParentColumn);
if (breakpoints == null) {
breakpoints = getBreakpoints(mixedData, _parent, mixedParentColumn);
breakpointsMap.put(mixedParentColumn, breakpoints);
}
// Category = index of the first breakpoint greater than d, or breakpoints.length if there is none.
value = breakpoints.length;
for (int j = 0; j < breakpoints.length; j++) {
if (d < breakpoints[j]) {
value = j;
break;
}
}
} else {
int mixedColumn = mixedData.getColumn(bayesParent);
value = mixedData.getInt(i, mixedColumn);
}
parentValues[k] = value;
}
int rowIndex = bayesIm.getRowIndex(bayesIndex, parentValues);
// Draw r, then pick the first category whose cumulative probability reaches r.
// Category 0 is written first and stays if no cumulative sum reaches r.
double sum = 0.0;
double r = RandomUtil.getInstance().nextDouble();
mixedData.setInt(i, mixedIndex, 0);
for (int k = 0; k < bayesIm.getNumColumns(bayesIndex); k++) {
double probability = bayesIm.getProbability(bayesIndex, rowIndex, k);
sum += probability;
if (sum >= r) {
mixedData.setInt(i, mixedIndex, k);
break;
}
}
} else {
// Non-discrete variable: value = noise + sum(parentValue * coefficient) + mean, where each
// parameter value is looked up by a Combination that also records the discrete parents' values in row i.
Node y = nodes.get(mixedIndex);
Set<DiscreteVariable> discreteParents = new HashSet<>();
Set<ContinuousVariable> continuousParents = new HashSet<>();
for (Node node : G.getParents(y)) {
if (node instanceof DiscreteVariable) {
discreteParents.add((DiscreteVariable) node);
} else {
continuousParents.add((ContinuousVariable) node);
}
}
Parameter varParam = semPm.getParameter(y, y);
Parameter muParam = semPm.getMeanParameter(y);
Combination varComb = new Combination(varParam);
Combination muComb = new Combination(muParam);
for (DiscreteVariable v : discreteParents) {
varComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v)));
muComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v)));
}
// NOTE(review): the second nextNormal argument comes from the variance parameter's combination;
// whether nextNormal expects a variance or a standard deviation is not visible here — confirm.
double value = RandomUtil.getInstance().nextNormal(0, getParamValue(varComb, paramValues));
for (Node x : continuousParents) {
Parameter coefParam = semPm.getParameter(x, y);
Combination coefComb = new Combination(coefParam);
for (DiscreteVariable v : discreteParents) {
coefComb.addParamValue(v, mixedData.getInt(i, mixedData.getColumn(v)));
}
int parent = nodes.indexOf(x);
double parentValue = mixedData.getDouble(i, parent);
double parentCoef = getParamValue(coefComb, paramValues);
value += parentValue * parentCoef;
}
value += getParamValue(muComb, paramValues);
mixedData.setDouble(i, mixedIndex, value);
}
}
}
boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
return saveLatentVars ? mixedData : DataUtils.restrictToMeasured(mixedData);
}
use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
the class Comparison2 method compare.
/**
 * Obtains a true graph and (unless running in "no data" mode) a data set, runs the
 * algorithm selected in {@code params}, and returns the search result together with
 * the expected result and the elapsed search time.
 * <p>
 * The graph and data come from one of three places:
 * <ul>
 * <li>{@code params.isNoData()}: a random DAG is generated and searched directly through a
 * d-separation test and a graph score; no data are simulated.</li>
 * <li>{@code params.getDataFile() != null}: the graph and data set by the file-loading step
 * (run when {@code params.isDataFromFile()}) are used.</li>
 * <li>otherwise: a random DAG is generated and continuous or discrete data are simulated from it.</li>
 * </ul>
 *
 * @param params the comparison settings; its data type, graph file and data file may be updated.
 * @return the populated comparison result.
 * @throws IllegalArgumentException if a required setting is missing or inconsistent.
 * @throws IllegalStateException    if the selected data file cannot be read.
 * @throws NullPointerException     if the hard-coded data directory cannot be listed.
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet = null;
    Graph trueDag = null;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);

    if (params.isDataFromFile()) {
        // Directory that holds the "graph*.g.txt" graph files and "graph*.dat.txt" data files.
        String path = "/Users/dmalinsky/Documents/research/data/danexamples";
        File[] files = new File(path).listFiles();

        if (files == null) {
            throw new NullPointerException("No files in " + path);
        }

        String graphNum = String.valueOf(params.getGraphNum());

        for (File file : files) {
            String name = file.getName();
            if (name.startsWith("graph") && name.contains(graphNum) && name.endsWith(".g.txt")) {
                params.setGraphFile(name);
                trueDag = GraphUtils.loadGraphTxt(file);
                break;
            }
        }

        String trialGraph = graphNum + "-" + String.valueOf(params.getTrial()) + ".dat.txt";

        for (File file : files) {
            String name = file.getName();
            if (name.startsWith("graph") && name.endsWith(trialGraph)) {
                try {
                    TabularDataReader dataReader;
                    if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
                        dataReader = new ContinuousTabularDataFileReader(file, Delimiter.TAB);
                    } else {
                        dataReader = new VerticalDiscreteTabularDataReader(file, Delimiter.TAB);
                    }
                    dataSet = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData());
                } catch (IOException e) {
                    // Previously printed and ignored, which left dataSet null and failed later
                    // with an unrelated error. Fail here and keep the cause.
                    throw new IllegalStateException("Could not read data file " + file, e);
                }
                params.setDataFile(name);
                break;
            }
        }

        System.out.println("current graph file = " + params.getGraphFile());
        System.out.println("current data set file = " + params.getDataFile());
    }

    if (params.isNoData()) {
        trueDag = randomContinuousDag(params);
        test = new IndTestDSep(trueDag);
        score = new GraphScore(trueDag);
        return runSelectedAlgorithm(params, result, test, score, trueDag, true);
    }

    if (params.getDataFile() != null) {
        System.out.println("Using data from file... ");
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        }
        System.out.println("Using graph from file... ");
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }

        if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
            trueDag = randomContinuousDag(params);

            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }

            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);

            if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
                sim.setCoefRange(0.20, 0.50);
                dataSet = simulateStableLagData(sim, params);
            } else {
                // The original only simulated data inside the TsFCI block, so every other
                // algorithm reached the "No data set." check below with dataSet == null.
                dataSet = sim.simulateDataFisher(params.getSampleSize());
            }
        } else if (params.getDataType() == ComparisonParameters.DataType.Discrete) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new DiscreteVariable("X" + (i + 1), 3));
            }

            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);

            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }

            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < nodes.size(); i++) {
                tiers[i] = i;
            }

            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        } else {
            throw new IllegalArgumentException("Unrecognized data type.");
        }

        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
    }

    // Independence test. Choosing one also pins the data type recorded in params.
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    // Score. Choosing one also pins the data type recorded in params.
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        BDeuScore bdeuScore = new BDeuScore(dataSet);
        bdeuScore.setSamplePrior(params.getSamplePrior());
        bdeuScore.setStructurePrior(params.getStructurePrior());
        score = bdeuScore;
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    return runSelectedAlgorithm(params, result, test, score, trueDag, false);
}

/**
 * Generates a random DAG over continuous variables X1..Xn; for TsFCI the DAG is regenerated
 * and expanded into a two-lag time-lag graph.
 */
private static Graph randomContinuousDag(ComparisonParameters params) {
    List<Node> nodes = new ArrayList<>();
    for (int i = 0; i < params.getNumVars(); i++) {
        nodes.add(new ContinuousVariable("X" + (i + 1)));
    }

    Graph dag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);

    if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
        // NOTE(review): this draws a second random graph and discards the first. It is kept
        // as-is so the sequence of random draws is unchanged.
        dag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
        dag = TimeSeriesUtils.graphToLagGraph(dag, 2);
        System.out.println("Creating Time Lag Graph : " + dag);
    }

    return dag;
}

/**
 * Simulates time-series data, retrying up to five times until the lag-1 transition matrix
 * passes the stability check; if it never does, narrows the coefficient range and simulates once more.
 */
private static DataSet simulateStableLagData(LargeScaleSimulation sim, ComparisonParameters params) {
    int tierSize = params.getNumVars();
    int[] sub = new int[tierSize];
    int[] sub2 = new int[tierSize];
    for (int i = 0; i < tierSize; i++) {
        sub[i] = i;
        sub2[i] = tierSize + i;
    }

    DataSet dataSet;
    boolean isStableTetradMatrix;
    int attempt = 1;

    do {
        dataSet = sim.simulateDataFisher(params.getSampleSize());

        // A1 = (I - B)^-1 * Gamma1, built from two sub-blocks of the coefficient matrix.
        TetradMatrix coefMat = new TetradMatrix(sim.getCoefficientMatrix());
        TetradMatrix B = coefMat.getSelection(sub, sub);
        TetradMatrix Gamma1 = coefMat.getSelection(sub2, sub);
        TetradMatrix Gamma0 = TetradMatrix.identity(tierSize).minus(B);
        TetradMatrix A1 = Gamma0.inverse().times(Gamma1);

        isStableTetradMatrix = TimeSeriesUtils.allEigenvaluesAreSmallerThanOneInModulus(A1);
        System.out.println("isStableTetradMatrix? : " + isStableTetradMatrix);
        attempt++;
    } while ((!isStableTetradMatrix) && attempt <= 5);

    if (!isStableTetradMatrix) {
        System.out.println("%%%%%%%%%% WARNING %%%%%%%% not a stable coefficient matrix, forcing coefs to [0.15,0.3]");
        System.out.println("Made " + (attempt - 1) + " attempts to get stable matrix.");
        sim.setCoefRange(0.15, 0.3);
        dataSet = sim.simulateDataFisher(params.getSampleSize());
    } else {
        System.out.println("Coefficient matrix is stable.");
    }

    return dataSet;
}

/**
 * Runs the algorithm named in {@code params}, stores the search result, the expected result,
 * the elapsed search time and the true DAG in {@code result}, and returns it.
 *
 * @param printTsFciGraphs when true, the TsFCI branch also prints the expected and found graphs.
 * @throws IllegalArgumentException if the algorithm is unset or unrecognized, or the test or
 *                                  score it needs is null.
 */
private static ComparisonResult runSelectedAlgorithm(ComparisonParameters params, ComparisonResult result,
                                                     IndependenceTest test, Score score, Graph trueDag,
                                                     boolean printTsFciGraphs) {
    if (params.getAlgorithm() == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }

    long time1 = System.currentTimeMillis();

    if (params.getAlgorithm() == ComparisonParameters.Algorithm.PC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Pc search = new Pc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Cpc search = new Cpc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcLocal search = new PcLocal(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCStableMax) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcStableMax search = new PcStableMax(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) {
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        Fges search = new Fges(score);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Fci search = new Fci(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.GFCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        // GFCI takes both a test and a score; the score was previously passed on unchecked.
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        GFci search = new GFci(test, score);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        TsFci search = new TsFci(test);
        IKnowledge knowledge = getKnowledge(trueDag);
        search.setKnowledge(knowledge);
        result.setResultGraph(search.search());
        result.setCorrectResult(new TsDagToPag(trueDag).convert());
        if (printTsFciGraphs) {
            System.out.println("Correct result for trial = " + result.getCorrectResult());
            System.out.println("Search result for trial = " + result.getResultGraph());
        }
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }

    long time2 = System.currentTimeMillis();
    result.setElapsed(time2 - time1);
    result.setTrueDag(trueDag);
    return result;
}
use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
the class PerformanceTests method testFgesMb.
/**
 * Benchmarks FGES-MB: builds one random DAG, simulates one continuous or discrete data set
 * from it, then for {@code numRuns} randomly chosen targets runs the search and compares the
 * result with the Markov-blanket subgraph of the true pattern. Timings, misclassification
 * counts and precision/recall figures are written to {@code out}.
 *
 * @param numVars    number of variables in the random DAG.
 * @param edgeFactor edges per variable; the reported edge count is {@code (int) (numVars * edgeFactor)}.
 * @param numCases   number of simulated cases.
 * @param numRuns    number of targets to search for.
 * @param continuous true to simulate continuous data scored with SEM BIC, false for discrete data scored with BDeu.
 */
private void testFgesMb(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) {
    double penaltyDiscount = 4.0;
    int structurePrior = 10;
    int samplePrior = 10;
    int maxIndegree = -1;

    List<int[][]> allCounts = new ArrayList<>();
    List<double[]> comparisons = new ArrayList<>();
    List<Double> degrees = new ArrayList<>();
    List<Long> elapsedTimes = new ArrayList<>();

    System.out.println("Making dag");
    Graph dag = makeDag(numVars, edgeFactor);
    System.out.println(new Date());
    System.out.println("Calculating pattern for DAG");
    Graph pattern = SearchGraphUtils.patternForDag(dag);

    int[] tiers = new int[dag.getNumNodes()];
    for (int i = 0; i < dag.getNumNodes(); i++) {
        tiers[i] = i;
    }

    System.out.println("Graph done");
    long time1 = System.currentTimeMillis();
    // NOTE(review): this writes to out before init(...) is called below — confirm out is already usable here.
    out.println("Graph done");
    System.out.println(new Date());
    System.out.println("Starting simulation");

    FgesMb fges;
    List<Node> vars;

    if (continuous) {
        init(new File("FgesMb.comparison.continuous" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Penalty discount = " + penaltyDiscount);
        out.println("Depth = " + maxIndegree);
        out.println();
        out.println(new Date());

        vars = dag.getNodes();
        LargeScaleSimulation simulator = new LargeScaleSimulation(dag, vars, tiers);
        simulator.setVerbose(false);
        simulator.setOut(out);
        DataSet data = simulator.simulateDataFisher(numCases);
        System.out.println("Finishing simulation");
        System.out.println(new Date());
        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

        System.out.println(new Date());
        System.out.println("Making covariance matrix");
        // Time the construction itself. The end timestamp used to be taken before the
        // matrix was built, so the reported figure did not cover the covariance step.
        long covStart = System.currentTimeMillis();
        ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data, true);
        long covEnd = System.currentTimeMillis();
        System.out.println("Covariance matrix done");
        out.println("Elapsed (calculating cov): " + (covEnd - covStart) + " ms\n");

        SemBicScore score = new SemBicScore(cov);
        score.setPenaltyDiscount(penaltyDiscount);
        System.out.println(new Date());
        System.out.println("\nStarting FGES-MB");

        fges = new FgesMb(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setOut(System.out);
        fges.setMaxIndegree(maxIndegree);
        fges.setCycleBound(-1);
    } else {
        init(new File("FgesMb.comparison.discrete" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
        out.println("Num vars = " + numVars);
        out.println("Num edges = " + (int) (numVars * edgeFactor));
        out.println("Num cases = " + numCases);
        out.println("Sample prior = " + samplePrior);
        out.println("Structure prior = " + structurePrior);
        out.println("Depth = " + maxIndegree);
        out.println();
        out.println(new Date());

        BayesPm pm = new BayesPm(dag, 3, 3);
        MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
        DataSet data = im.simulateData(numCases, false, tiers);

        // Use the data set's own variables and re-express the true pattern over them,
        // so the target and the true graph refer to the same node objects.
        vars = data.getVariables();
        pattern = GraphUtils.replaceNodes(pattern, vars);
        System.out.println("Finishing simulation");
        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

        long time3 = System.currentTimeMillis();
        BDeuScore score = new BDeuScore(data);
        score.setStructurePrior(structurePrior);
        score.setSamplePrior(samplePrior);
        System.out.println(new Date());
        System.out.println("\nStarting FGES");
        long time4 = System.currentTimeMillis();

        fges = new FgesMb(score);
        fges.setVerbose(false);
        fges.setNumPatternsToStore(0);
        fges.setOut(System.out);
        fges.setMaxIndegree(maxIndegree);
        fges.setCycleBound(-1);
        long timeb = System.currentTimeMillis();

        out.println("Time consructing BDeu score " + (time4 - time3) + " ms");
        out.println("Time for FGES-MB constructor " + (timeb - time4) + " ms");
        out.println();
    }

    // The code that skipped runs with undefined accuracies is disabled, so this stays 0.
    int numSkipped = 0;

    for (int run = 0; run < numRuns; run++) {
        out.println("\n\n\n******************************** RUN " + (run + 1) + " ********************************\n\n");

        Node target = vars.get(RandomUtil.getInstance().nextInt(vars.size()));
        System.out.println("Target = " + target);

        long searchStart = System.currentTimeMillis();
        Graph estPattern = fges.search(target);
        long elapsed = System.currentTimeMillis() - searchStart;

        // True Markov blanket in the pattern: the target, its adjacent nodes,
        // and the parents of its children.
        Set<Node> mb = new HashSet<>();
        mb.add(target);
        mb.addAll(pattern.getAdjacentNodes(target));
        for (Node child : pattern.getChildren(target)) {
            mb.addAll(pattern.getParents(child));
        }
        Graph trueMbGraph = pattern.subgraph(new ArrayList<>(mb));

        // Report the search time only. This used to be measured after the true
        // Markov-blanket subgraph was built and so included that work.
        out.println("Time for FGES-MB search " + elapsed + " ms");
        out.println();
        System.out.println("Done with FGES");
        System.out.println(new Date());

        System.out.println("Counting misclassifications.");
        int[][] counts = GraphUtils.edgeMisclassificationCounts(trueMbGraph, estPattern, false);
        allCounts.add(counts);
        System.out.println(new Date());

        // NOTE(review): the row/column meanings of counts come from
        // GraphUtils.edgeMisclassificationCounts and are not visible here; indices are kept as they were.
        int sumRow = counts[4][0] + counts[4][3] + counts[4][5];
        int sumCol = counts[0][3] + counts[4][3] + counts[5][3] + counts[7][3];
        int trueArrow = counts[4][3];

        int sumTrueAdjacencies = 0;
        for (int i = 0; i < 7; i++) {
            for (int j = 0; j < 5; j++) {
                sumTrueAdjacencies += counts[i][j];
            }
        }

        int falsePositiveAdjacencies = 0;
        for (int j = 0; j < 5; j++) {
            falsePositiveAdjacencies += counts[7][j];
        }

        int falseNegativeAdjacencies = 0;
        for (int i = 0; i < 5; i++) {
            falseNegativeAdjacencies += counts[i][5];
        }

        // Adjacency precision, adjacency recall, arrow precision, arrow recall.
        // A zero denominator yields NaN rather than an exception because the division is in double.
        double[] comparison = new double[4];
        comparison[0] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falsePositiveAdjacencies);
        comparison[1] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falseNegativeAdjacencies);
        comparison[2] = trueArrow / (double) sumCol;
        comparison[3] = trueArrow / (double) sumRow;
        comparisons.add(comparison);

        out.println(GraphUtils.edgeMisclassifications(counts));
        out.println(precisionRecall(comparison));

        elapsedTimes.add(elapsed);
        out.println("\nElapsed: " + elapsed + " ms");
    }

    printAverageConfusion("Average", allCounts, new DecimalFormat("0.0"));
    printAveragePrecisionRecall(comparisons);
    out.println("Number of runs skipped because of undefined accuracies: " + numSkipped);
    printAverageStatistics(elapsedTimes, degrees);
    out.close();
}
use of edu.cmu.tetrad.bayes.MlBayesIm in project tetrad by cmu-phil.
the class PerformanceTests method testFges.
private void testFges(int numVars, double edgeFactor, int numCases, int numRuns, boolean continuous) {
out.println(new Date());
// RandomUtil.getInstance().setSeed(4828384343999L);
double penaltyDiscount = 4.0;
int maxIndegree = 5;
boolean faithfulness = true;
// RandomUtil.getInstance().setSeed(50304050454L);
List<int[][]> allCounts = new ArrayList<>();
List<double[]> comparisons = new ArrayList<>();
List<Double> degrees = new ArrayList<>();
List<Long> elapsedTimes = new ArrayList<>();
if (continuous) {
init(new File("fges.comparison.continuous" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
out.println("Num vars = " + numVars);
out.println("Num edges = " + (int) (numVars * edgeFactor));
out.println("Num cases = " + numCases);
out.println("Penalty discount = " + penaltyDiscount);
out.println("Depth = " + maxIndegree);
out.println();
} else {
init(new File("fges.comparison.discrete" + numVars + "." + (int) (edgeFactor * numVars) + "." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
out.println("Num vars = " + numVars);
out.println("Num edges = " + (int) (numVars * edgeFactor));
out.println("Num cases = " + numCases);
out.println("Sample prior = " + 1);
out.println("Structure prior = " + 1);
out.println("Depth = " + 1);
out.println();
}
for (int run = 0; run < numRuns; run++) {
out.println("\n\n\n******************************** RUN " + (run + 1) + " ********************************\n\n");
System.out.println("Making dag");
out.println(new Date());
Graph dag = makeDag(numVars, edgeFactor);
System.out.println(new Date());
System.out.println("Calculating pattern for DAG");
Graph pattern = SearchGraphUtils.patternForDag(dag);
List<Node> vars = dag.getNodes();
int[] tiers = new int[vars.size()];
for (int i = 0; i < vars.size(); i++) {
tiers[i] = i;
}
System.out.println("Graph done");
long time1 = System.currentTimeMillis();
out.println("Graph done");
System.out.println(new Date());
System.out.println("Starting simulation");
Graph estPattern;
long elapsed;
if (continuous) {
LargeScaleSimulation simulator = new LargeScaleSimulation(dag, vars, tiers);
simulator.setVerbose(false);
simulator.setOut(out);
DataSet data = simulator.simulateDataFisher(numCases);
System.out.println("Finishing simulation");
System.out.println(new Date());
long time2 = System.currentTimeMillis();
out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
System.out.println(new Date());
System.out.println("Making covariance matrix");
long time3 = System.currentTimeMillis();
ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data, true);
System.out.println("Covariance matrix done");
out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms\n");
SemBicScore score = new SemBicScore(cov);
score.setPenaltyDiscount(penaltyDiscount);
System.out.println(new Date());
System.out.println("\nStarting FGES");
long timea = System.currentTimeMillis();
Fges fges = new Fges(score);
// fges.setVerbose(false);
fges.setNumPatternsToStore(0);
fges.setOut(System.out);
fges.setFaithfulnessAssumed(faithfulness);
fges.setCycleBound(-1);
long timeb = System.currentTimeMillis();
estPattern = fges.search();
long timec = System.currentTimeMillis();
out.println("Time for FGES constructor " + (timeb - timea) + " ms");
out.println("Time for FGES search " + (timec - timea) + " ms");
out.println();
out.flush();
elapsed = timec - timea;
} else {
BayesPm pm = new BayesPm(dag, 3, 3);
MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
DataSet data = im.simulateData(numCases, false, tiers);
System.out.println("Finishing simulation");
long time2 = System.currentTimeMillis();
out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
long time3 = System.currentTimeMillis();
BDeuScore score = new BDeuScore(data);
score.setStructurePrior(1);
score.setSamplePrior(1);
System.out.println(new Date());
System.out.println("\nStarting FGES");
long timea = System.currentTimeMillis();
Fges fges = new Fges(score);
// fges.setVerbose(false);
fges.setNumPatternsToStore(0);
fges.setOut(System.out);
fges.setFaithfulnessAssumed(faithfulness);
fges.setCycleBound(-1);
long timeb = System.currentTimeMillis();
estPattern = fges.search();
long timec = System.currentTimeMillis();
out.println("Time consructing BDeu score " + (timea - time3) + " ms");
out.println("Time for FGES constructor " + (timeb - timea) + " ms");
out.println("Time for FGES search " + (timec - timea) + " ms");
out.println();
elapsed = timec - timea;
}
System.out.println("Done with FGES");
System.out.println(new Date());
// System.out.println("Replacing nodes");
//
// estPattern = GraphUtils.replaceNodes(estPattern, dag.getNodes());
// System.out.println("Calculating degree");
//
// double degree = GraphUtils.degree(estPattern);
// degrees.add(degree);
//
// out.println("Degree out output graph = " + degree);
double[] comparison = new double[4];
// int adjFn = GraphUtils.countAdjErrors(pattern, estPattern);
// int adjFp = GraphUtils.countAdjErrors(estPattern, pattern);
// int trueAdj = pattern.getNumEdges();
//
// comparison[0] = trueAdj / (double) (trueAdj + adjFp);
// comparison[1] = trueAdj / (double) (trueAdj + adjFn);
System.out.println("Counting misclassifications.");
estPattern = GraphUtils.replaceNodes(estPattern, pattern.getNodes());
int[][] counts = GraphUtils.edgeMisclassificationCounts(pattern, estPattern, false);
allCounts.add(counts);
System.out.println(new Date());
int sumRow = counts[4][0] + counts[4][3] + counts[4][5];
int sumCol = counts[0][3] + counts[4][3] + counts[5][3] + counts[7][3];
int trueArrow = counts[4][3];
int sumTrueAdjacencies = 0;
for (int i = 0; i < 7; i++) {
for (int j = 0; j < 5; j++) {
sumTrueAdjacencies += counts[i][j];
}
}
int falsePositiveAdjacencies = 0;
for (int j = 0; j < 5; j++) {
falsePositiveAdjacencies += counts[7][j];
}
int falseNegativeAdjacencies = 0;
for (int i = 0; i < 5; i++) {
falseNegativeAdjacencies += counts[i][5];
}
comparison[0] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falsePositiveAdjacencies);
comparison[1] = sumTrueAdjacencies / (double) (sumTrueAdjacencies + falseNegativeAdjacencies);
comparison[2] = trueArrow / (double) sumCol;
comparison[3] = trueArrow / (double) sumRow;
comparisons.add(comparison);
out.println(GraphUtils.edgeMisclassifications(counts));
out.println(precisionRecall(comparison));
elapsedTimes.add(elapsed);
out.println("\nElapsed: " + elapsed + " ms");
}
printAverageConfusion("Average", allCounts);
printAveragePrecisionRecall(comparisons);
printAverageStatistics(elapsedTimes, degrees);
out.close();
}
Aggregations