Use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.
In the class Comparison, method compare.
/**
 * Simulates data from a model parameterizing a random DAG (or loads a data set and its true
 * graph from files), runs the configured search algorithm on that data, and records the
 * estimated graph, the reference ("correct") graph, the true DAG and the search time.
 *
 * @param params the comparison settings. As a side effect, its data type is set to continuous
 *               or discrete when the chosen independence test or score implies one.
 * @return the populated comparison result.
 * @throws IllegalArgumentException if a required setting is missing, inconsistent with the data
 *                                  type, or unrecognized.
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet;
    Graph trueDag;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);

    if (params.getDataFile() != null) {
        dataSet = loadDataFile(params.getDataFile());
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        }
        trueDag = loadGraphFile(params.getGraphFile());
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }
        boolean continuous = params.getDataType() == ComparisonParameters.DataType.Continuous;
        boolean discrete = params.getDataType() == ComparisonParameters.DataType.Discrete;
        if (!continuous && !discrete) {
            throw new IllegalArgumentException("Unrecognized data type.");
        }
        // Validate before generating the graph so no work is spent on a run that must fail.
        if (params.getSampleSize() == -1) {
            throw new IllegalArgumentException("Sample size not set.");
        }

        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < params.getNumVars(); i++) {
            String name = "X" + (i + 1);
            if (continuous) {
                nodes.add(new ContinuousVariable(name));
            } else {
                nodes.add(new DiscreteVariable(name, 3));
            }
        }
        trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);

        if (continuous) {
            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            dataSet = sim.simulateDataFisher(params.getSampleSize());
        } else {
            // tiers[i] = i, i.e. presumably one tier per variable in node-list order
            // (NOTE(review): confirm against MlBayesIm.simulateData).
            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < tiers.length; i++) {
                tiers[i] = i;
            }
            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        }
        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
    }

    // Build the independence test; each test also fixes the data type it requires.
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    // Build the score; as with the test, the score fixes the data type it requires.
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        BDeuScore bDeuScore = new BDeuScore(dataSet);
        bDeuScore.setSamplePrior(params.getSamplePrior());
        bDeuScore.setStructurePrior(params.getStructurePrior());
        score = bDeuScore;
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    if (params.getAlgorithm() == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }

    // The elapsed time covers the search and the construction of the reference graph.
    long time1 = System.currentTimeMillis();
    if (params.getAlgorithm() == ComparisonParameters.Algorithm.PC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Pc search = new Pc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Cpc search = new Cpc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcLocal search = new PcLocal(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCStableMax) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcStableMax search = new PcStableMax(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES
            || params.getAlgorithm() == ComparisonParameters.Algorithm.FGES2) {
        // FGES and FGES2 currently run the same search.
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        Fges search = new Fges(score);
        search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Fci search = new Fci(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.GFCI) {
        // GFCI needs both a test and a score.
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        GFci search = new GFci(test, score);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }
    long time2 = System.currentTimeMillis();
    result.setElapsed(time2 - time1);
    result.setTrueDag(trueDag);
    return result;
}
Use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.
In the class TestLargeSemSimulator, method test1.
/**
 * Checks that a Fisher-model simulation over a small random graph returns exactly the
 * requested number of rows.
 */
@Test
public void test1() {
    // Ten continuous variables named X1 .. X10.
    List<Node> variables = new ArrayList<>();
    int count = 0;
    while (count < 10) {
        count++;
        variables.add(new ContinuousVariable("X" + count));
    }

    Graph dag = GraphUtils.randomGraph(variables, 0, 10, 5, 5, 5, false);
    DataSet simulated = new LargeScaleSimulation(dag).simulateDataFisher(1000);

    assertEquals(1000, simulated.getNumRows());
}
Use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.
In the class LinearFisherModel, method createData.
/**
 * Creates {@code numRuns} simulated data sets, and the graph each was drawn from, into the
 * {@code dataSets} and {@code graphs} fields using {@link LargeScaleSimulation}.
 * <p>
 * When shock data sets are present, the graph is built over the first shock data set's
 * variables, and run {@code i} is simulated from shock data set {@code i}. Afterwards,
 * optional Gaussian measurement noise is added, a percentage of the columns may be
 * discretized, and the columns may be randomly reordered.
 *
 * @param parameters the simulation parameters (read by name); {@code numVars} is overwritten
 *                   with the shock variable count when shocks are present.
 */
@Override
public void createData(Parameters parameters) {
    boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
    // An empty shock list is treated exactly like no shocks. Otherwise shocks.get(i) below
    // would fail even though the variable checks ignored the empty list.
    boolean useShocks = shocks != null && shocks.size() > 0;
    dataSets = new ArrayList<>();
    graphs = new ArrayList<>();
    if (useShocks) {
        // The graph must be built over the shock variables. Set numVars before the FIRST graph
        // is created, not only before graphs regenerated for later runs.
        parameters.set("numVars", shocks.get(0).getVariables().size());
    }
    Graph graph = randomGraph.createGraph(parameters);
    System.out.println("degree = " + GraphUtils.getDegree(graph));

    for (int i = 0; i < parameters.getInt("numRuns"); i++) {
        System.out.println("Simulating dataset #" + (i + 1));
        if (parameters.getBoolean("differentGraphs") && i > 0) {
            graph = randomGraph.createGraph(parameters);
        }
        if (useShocks) {
            graph.setNodes(shocks.get(0).getVariables());
        }
        graphs.add(graph);

        int[] tiers = new int[graph.getNodes().size()];
        for (int j = 0; j < tiers.length; j++) {
            tiers[j] = j;
        }
        LargeScaleSimulation simulator = new LargeScaleSimulation(graph, graph.getNodes(), tiers);
        simulator.setCoefRange(parameters.getDouble("coefLow"), parameters.getDouble("coefHigh"));
        simulator.setVarRange(parameters.getDouble("varLow"), parameters.getDouble("varHigh"));
        simulator.setIncludePositiveCoefs(parameters.getBoolean("includePositiveCoefs"));
        simulator.setIncludeNegativeCoefs(parameters.getBoolean("includeNegativeCoefs"));
        simulator.setBetaLeftValue(parameters.getDouble("betaLeftValue"));
        simulator.setBetaRightValue(parameters.getDouble("betaRightValue"));
        simulator.setSelfLoopCoef(parameters.getDouble("selfLoopCoef"));
        simulator.setMeanRange(parameters.getDouble("meanLow"), parameters.getDouble("meanHigh"));
        simulator.setErrorsNormal(parameters.getBoolean("errorsNormal"));
        simulator.setVerbose(parameters.getBoolean("verbose"));

        DataSet dataSet;
        if (useShocks) {
            DataSet shockData = (DataSet) shocks.get(i);
            dataSet = simulator.simulateDataFisher(shockData.getDoubleData().toArray(), parameters.getInt("intervalBetweenShocks"), parameters.getDouble("fisherEpsilon"));
        } else {
            dataSet = simulator.simulateDataFisher(parameters.getInt("intervalBetweenShocks"), parameters.getInt("intervalBetweenRecordings"), parameters.getInt("sampleSize"), parameters.getDouble("fisherEpsilon"), saveLatentVars);
        }

        // Optional additive noise on every cell, drawn with standard deviation sqrt(variance).
        double variance = parameters.getDouble("measurementVariance");
        if (variance > 0) {
            double sd = Math.sqrt(variance);
            for (int k = 0; k < dataSet.getNumRows(); k++) {
                for (int j = 0; j < dataSet.getNumColumns(); j++) {
                    double d = dataSet.getDouble(k, j);
                    double delta = RandomUtil.getInstance().nextNormal(0, sd);
                    dataSet.setDouble(k, j, d + delta);
                }
            }
        }
        dataSet.setName("" + (i + 1));

        double percentDiscrete = parameters.getDouble("percentDiscrete");
        if (percentDiscrete > 0.0) {
            // The shuffled variable order is fixed on first use, so the same variables are
            // discretized in every run.
            if (this.shuffledOrder == null) {
                List<Node> shuffledNodes = new ArrayList<>(dataSet.getVariables());
                Collections.shuffle(shuffledNodes);
                this.shuffledOrder = shuffledNodes;
            }
            Discretizer discretizer = new Discretizer(dataSet);
            // Clamp so a percentage above 100 cannot index past the end of the variable list.
            double numToDiscretize = Math.min(shuffledOrder.size(), shuffledOrder.size() * percentDiscrete * 0.01);
            for (int k = 0; k < numToDiscretize; k++) {
                discretizer.equalIntervals(dataSet.getVariable(shuffledOrder.get(k).getName()), parameters.getInt("numCategories"));
            }
            String name = dataSet.getName();
            dataSet = discretizer.discretize();
            dataSet.setName(name);
        }

        if (parameters.getBoolean("randomizeColumns")) {
            dataSet = DataUtils.reorderColumns(dataSet);
        }
        dataSets.add(saveLatentVars ? dataSet : DataUtils.restrictToMeasured(dataSet));
    }
}
Use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.
In the class Comparison2, method compare.
/**
 * Runs one comparison trial: obtains a true graph and data, runs the configured search
 * algorithm, and records the estimated graph, the reference ("correct") graph, the true DAG
 * and the search time.
 * <p>
 * Inputs come from one of three sources:
 * <ul>
 * <li>a hard-coded example directory, when {@code isDataFromFile()} is set;</li>
 * <li>a test and score built directly on a random graph (no data), when {@code isNoData()} is set;</li>
 * <li>otherwise, data simulated from a random graph.</li>
 * </ul>
 *
 * @param params the comparison settings. As side effects, the graph/data file names found in
 *               the example directory are recorded on it, and its data type is set when the
 *               chosen independence test or score implies one.
 * @return the populated comparison result.
 * @throws IllegalArgumentException if a required setting is missing, inconsistent with the data
 *                                  type or unrecognized, or if the data/graph cannot be obtained.
 * @throws NullPointerException     if the example data directory cannot be listed.
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet = null;
    Graph trueDag = null;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);

    if (params.isDataFromFile()) {
        // Machine-specific location of the example graphs and data sets.
        String path = "/Users/dmalinsky/Documents/research/data/danexamples";
        File[] files = new File(path).listFiles();
        if (files == null) {
            throw new NullPointerException("No files in " + path);
        }
        String graphNum = String.valueOf(params.getGraphNum());
        for (File file : files) {
            String name = file.getName();
            // The graph number must not be part of a longer number (graph 1 must not pick up graph10).
            if (name.startsWith("graph") && containsStandaloneNumber(name, graphNum) && name.endsWith(".g.txt")) {
                params.setGraphFile(name);
                trueDag = GraphUtils.loadGraphTxt(file);
                break;
            }
        }
        String trialSuffix = graphNum + "-" + params.getTrial() + ".dat.txt";
        for (File file : files) {
            String name = file.getName();
            // Same rule for data files: "graph11-1.dat.txt" must not match graph 1, trial 1.
            int prefixEnd = name.length() - trialSuffix.length();
            boolean matches = name.startsWith("graph") && name.endsWith(trialSuffix)
                    && (prefixEnd == 0 || !Character.isDigit(name.charAt(prefixEnd - 1)));
            if (matches) {
                try {
                    TabularDataReader dataReader;
                    if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
                        dataReader = new ContinuousTabularDataFileReader(file, Delimiter.TAB);
                    } else {
                        dataReader = new VerticalDiscreteTabularDataReader(file, Delimiter.TAB);
                    }
                    dataSet = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData());
                } catch (IOException e) {
                    // Fail here with the cause rather than continuing with a null data set.
                    throw new IllegalArgumentException("Could not read data file " + file, e);
                }
                params.setDataFile(name);
                break;
            }
        }
        System.out.println("current graph file = " + params.getGraphFile());
        System.out.println("current data set file = " + params.getDataFile());
    }

    if (params.isNoData()) {
        // No data: the test and score are built directly from the true graph.
        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < params.getNumVars(); i++) {
            nodes.add(new ContinuousVariable("X" + (i + 1)));
        }
        trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
        if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
            trueDag = TimeSeriesUtils.graphToLagGraph(trueDag, 2);
            System.out.println("Creating Time Lag Graph : " + trueDag);
        }
        return runConfiguredSearch(params, new IndTestDSep(trueDag), new GraphScore(trueDag), trueDag, result, true);
    }

    if (params.getDataFile() != null) {
        System.out.println("Using data from file... ");
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        }
        System.out.println("Using graph from file... ");
        // Only a graph found by the directory scan above is available here; nothing is loaded
        // from getGraphFile() directly.
        if (trueDag == null) {
            throw new IllegalArgumentException("True graph not loaded.");
        }
    } else {
        // NOTE(review): this branch is also reached when isDataFromFile() found no data file; a
        // random graph and simulated data then replace whatever graph was loaded.
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }
        if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new ContinuousVariable("X" + (i + 1)));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            boolean timeSeries = params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI;
            if (timeSeries) {
                trueDag = TimeSeriesUtils.graphToLagGraph(trueDag, 2);
                System.out.println("Creating Time Lag Graph : " + trueDag);
            }
            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            if (timeSeries) {
                dataSet = simulateStableLagData(sim, params.getNumVars(), params.getSampleSize());
            } else {
                // Non-time-series algorithms also need simulated data.
                dataSet = sim.simulateDataFisher(params.getSampleSize());
            }
        } else if (params.getDataType() == ComparisonParameters.DataType.Discrete) {
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new DiscreteVariable("X" + (i + 1), 3));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < tiers.length; i++) {
                tiers[i] = i;
            }
            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        } else {
            throw new IllegalArgumentException("Unrecognized data type.");
        }
    }
    if (dataSet == null) {
        throw new IllegalArgumentException("No data set.");
    }

    // Build the independence test; each test also fixes the data type it requires.
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    // Build the score; as with the test, the score fixes the data type it requires.
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        BDeuScore bDeuScore = new BDeuScore(dataSet);
        bDeuScore.setSamplePrior(params.getSamplePrior());
        bDeuScore.setStructurePrior(params.getStructurePrior());
        score = bDeuScore;
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }

    return runConfiguredSearch(params, test, score, trueDag, result, false);
}

/**
 * Returns true if {@code number} occurs in {@code text} without a digit immediately before or
 * after it, so that "1" matches "graph1.g.txt" but not "graph10.g.txt".
 */
private static boolean containsStandaloneNumber(String text, String number) {
    for (int i = text.indexOf(number); i >= 0; i = text.indexOf(number, i + 1)) {
        int end = i + number.length();
        boolean digitBefore = i > 0 && Character.isDigit(text.charAt(i - 1));
        boolean digitAfter = end < text.length() && Character.isDigit(text.charAt(end));
        if (!digitBefore && !digitAfter) {
            return true;
        }
    }
    return false;
}

/**
 * Simulates lag-graph data with coefficients in [0.20, 0.50], retrying up to five times until
 * A1 = (I - B)^-1 * Gamma1 has all eigenvalues smaller than one in modulus. B is the
 * {@code numVars x numVars} block of the coefficient matrix at rows/cols 0..numVars-1. Gamma1
 * is the block at rows numVars..2*numVars-1, cols 0..numVars-1 (presumably the lag-1 effects;
 * NOTE(review): confirm the node ordering of graphToLagGraph). If no attempt is stable, the
 * coefficient range is narrowed to [0.15, 0.3] and data are simulated once more without
 * re-checking.
 */
private static DataSet simulateStableLagData(LargeScaleSimulation sim, int numVars, int sampleSize) {
    sim.setCoefRange(0.20, 0.50);
    int[] current = new int[numVars];
    int[] lagged = new int[numVars];
    for (int i = 0; i < numVars; i++) {
        current[i] = i;
        lagged[i] = numVars + i;
    }
    DataSet dataSet;
    boolean isStableTetradMatrix;
    int attempt = 1;
    do {
        dataSet = sim.simulateDataFisher(sampleSize);
        TetradMatrix coefMat = new TetradMatrix(sim.getCoefficientMatrix());
        TetradMatrix b = coefMat.getSelection(current, current);
        TetradMatrix gamma1 = coefMat.getSelection(lagged, current);
        TetradMatrix gamma0 = TetradMatrix.identity(numVars).minus(b);
        TetradMatrix a1 = gamma0.inverse().times(gamma1);
        isStableTetradMatrix = TimeSeriesUtils.allEigenvaluesAreSmallerThanOneInModulus(a1);
        System.out.println("isStableTetradMatrix? : " + isStableTetradMatrix);
        attempt++;
    } while (!isStableTetradMatrix && attempt <= 5);
    if (!isStableTetradMatrix) {
        System.out.println("%%%%%%%%%% WARNING %%%%%%%% not a stable coefficient matrix, forcing coefs to [0.15,0.3]");
        System.out.println("Made " + (attempt - 1) + " attempts to get stable matrix.");
        sim.setCoefRange(0.15, 0.3);
        dataSet = sim.simulateDataFisher(sampleSize);
    } else {
        System.out.println("Coefficient matrix is stable.");
    }
    return dataSet;
}

/**
 * Runs the algorithm selected in {@code params} with the given test and/or score. It then records
 * on {@code result} the estimated graph and the reference graph. The reference is the pattern of
 * the true DAG for PC-family and FGES searches, the DagToPag conversion for FCI/GFCI, and the
 * TsDagToPag conversion for TsFCI. It also records the elapsed milliseconds (search plus
 * reference construction) and the true DAG.
 *
 * @param printTsFciGraphs whether to print the correct and estimated graphs after a TsFCI search.
 * @return {@code result}.
 * @throws IllegalArgumentException if the algorithm is unset or unrecognized, or the test/score
 *                                  it needs is missing.
 */
private static ComparisonResult runConfiguredSearch(ComparisonParameters params, IndependenceTest test,
        Score score, Graph trueDag, ComparisonResult result, boolean printTsFciGraphs) {
    if (params.getAlgorithm() == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }
    long time1 = System.currentTimeMillis();
    if (params.getAlgorithm() == ComparisonParameters.Algorithm.PC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Pc search = new Pc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Cpc search = new Cpc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCLocal) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcLocal search = new PcLocal(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.PCStableMax) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        PcStableMax search = new PcStableMax(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) {
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        Fges search = new Fges(score);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        Fci search = new Fci(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.GFCI) {
        // GFCI needs both a test and a score.
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        GFci search = new GFci(test, score);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        TsFci search = new TsFci(test);
        IKnowledge knowledge = getKnowledge(trueDag);
        search.setKnowledge(knowledge);
        result.setResultGraph(search.search());
        result.setCorrectResult(new TsDagToPag(trueDag).convert());
        if (printTsFciGraphs) {
            System.out.println("Correct result for trial = " + result.getCorrectResult());
            System.out.println("Search result for trial = " + result.getResultGraph());
        }
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }
    long time2 = System.currentTimeMillis();
    result.setElapsed(time2 - time1);
    result.setTrueDag(trueDag);
    return result;
}
Use of edu.cmu.tetrad.sem.LargeScaleSimulation in project tetrad by cmu-phil.
In the class PerformanceTests, method testPcStable.
/**
 * Benchmarks PC-Stable and writes the results to a log file named from the arguments.
 * <p>
 * It simulates {@code numCases} rows from a DAG built by {@code makeDag(numVars, edgeFactor)}
 * and builds an on-the-fly covariance matrix. It then runs PC-Stable with a Fisher Z test at
 * {@code alpha}, and writes timings plus a comparison of the estimated pattern against the
 * true DAG's pattern. The log stream is closed even if a step fails.
 */
public void testPcStable(int numVars, double edgeFactor, int numCases, double alpha) {
    // Only reported in the log: no setDepth call is made, so PcStable runs with its own default.
    int depth = -1;
    init(new File("long.pcstable." + numVars + "." + edgeFactor + "." + alpha + ".txt"), "Tests performance of the PC Stable algorithm");
    try {
        // time1 is taken before the DAG is built, so "simulating the data" includes graph creation.
        long time1 = System.currentTimeMillis();
        Graph dag = makeDag(numVars, edgeFactor);
        System.out.println("Graph done");
        out.println("Graph done");

        System.out.println("Starting simulation");
        LargeScaleSimulation simulator = new LargeScaleSimulation(dag);
        simulator.setOut(out);
        DataSet data = simulator.simulateDataFisher(numCases);
        System.out.println("Finishing simulation");
        long time2 = System.currentTimeMillis();
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

        System.out.println("Making covariance matrix");
        ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data);
        System.out.println("Covariance matrix done");
        long time3 = System.currentTimeMillis();
        out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms");

        IndTestFisherZ test = new IndTestFisherZ(cov, alpha);
        PcStable pcStable = new PcStable(test);
        Graph estPattern = pcStable.search();
        long time4 = System.currentTimeMillis();

        out.println("# Cases = " + numCases);
        out.println("alpha = " + alpha);
        out.println("depth = " + depth);
        out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");
        out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms");
        out.println("Elapsed (running PC-Stable) " + (time4 - time3) + " ms");
        out.println("Total elapsed (cov + PC-Stable) " + (time4 - time2) + " ms");

        final Graph truePattern = SearchGraphUtils.patternForDag(dag);
        System.out.println("# edges in true pattern = " + truePattern.getNumEdges());
        System.out.println("# edges in est pattern = " + estPattern.getNumEdges());
        SearchGraphUtils.graphComparison(estPattern, truePattern, out);
        out.println("seed = " + RandomUtil.getInstance().getSeed() + "L");
    } finally {
        // Release (and flush) the log even if simulation or search throws.
        out.close();
    }
}
Aggregations