use of edu.cmu.tetrad.util.Parameters in project tetrad by cmu-phil.
the class TsGFciRunner method execute.
// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
* Executes the algorithm, producing (at least) a result workbench. Must be
* implemented in the extending class.
*/
// public void execute() {
// IKnowledge knowledge = getParameters().getKnowledge();
// Parameters searchParams = getParameters();
//
// Parameters params = (Parameters) searchParams;
//
// Graph graph;
//
// if (getIndependenceTest() instanceof IndTestDSep) {
// GFci gfci = new GFci(getIndependenceTest());
// graph = gfci.search();
// } else {
// GFci fci = new GFci(getIndependenceTest());
// fci.setKnowledge(knowledge);
// fci.setCompleteRuleSetUsed(params.isCompleteRuleSetUsed());
// fci.setMaxPathLength(params.getMaxReachablePathLength());
// fci.setMaxIndegree(params.getMaxIndegree());
// double penaltyDiscount = params.getPenaltyDiscount();
//
// fci.setCorrErrorsAlpha(penaltyDiscount);
// fci.setSamplePrior(params.getSamplePrior());
// fci.setStructurePrior(params.getStructurePrior());
// fci.setCompleteRuleSetUsed(false);
// fci.setFaithfulnessAssumed(params.isFaithfulnessAssumed());
// graph = fci.search();
// }
//
// if (getSourceGraph() != null) {
// GraphUtils.arrangeBySourceGraph(graph, getSourceGraph());
// } else if (knowledge.isDefaultToKnowledgeLayout()) {
// SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge);
// } else {
// GraphUtils.circleLayout(graph, 200, 200, 150);
// }
//
// setResultGraph(graph);
// }
/**
 * Executes the algorithm, producing (at least) a result workbench. Must be
 * implemented in the extending class.
 * <p>
 * The input is the runner's data model or, when that is absent, its source graph.
 * A graph input is searched with a {@code GraphScore} and a d-separation test built
 * from that graph. A continuous data set, a covariance matrix, or a one-element list
 * of continuous data models is searched with a SEM BIC score whose penalty discount
 * is read from the "penaltyDiscount" parameter (default 4). The search result is laid
 * out and stored through {@code setResultGraph}.
 *
 * @throws RuntimeException         if neither a data model nor a source graph is available.
 * @throws IllegalStateException    if a data set is not continuous, or the input is of an
 *                                  unsupported type.
 * @throws IllegalArgumentException if a data model list contains unsupported entries, does not
 *                                  hold exactly one entry, or is not all continuous.
 */
public void execute() {
    Object model = getDataModel();
    if (model == null && getSourceGraph() != null) {
        model = getSourceGraph();
    }
    if (model == null) {
        throw new RuntimeException("Data source is unspecified. You may need to double click all your data boxes, \n" + "then click Save, and then right click on them and select Propagate Downstream. \n" + "The issue is that we use a seed to simulate from IM's, so your data is not saved to \n" + "file when you save the session. It can, however, be recreated from the saved seed.");
    }

    Parameters params = getParams();
    double penaltyDiscount = params.getDouble("penaltyDiscount", 4);
    // Looked up once so the search and the layout step below share the same knowledge object.
    IKnowledge knowledge = (IKnowledge) params.get("knowledge", new Knowledge2());

    if (model instanceof Graph) {
        Graph inputGraph = (Graph) model;
        GraphScore gesScore = new GraphScore(inputGraph);
        IndependenceTest test = new IndTestDSep(inputGraph);
        gfci = new TsGFci(test, gesScore);
    } else if (model instanceof DataSet) {
        DataSet dataSet = (DataSet) model;
        if (!dataSet.isContinuous()) {
            // Only continuous data sets are handled by this runner.
            throw new IllegalStateException("Data set must either be continuous or discrete.");
        }
        SemBicScore gesScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        gesScore.setPenaltyDiscount(penaltyDiscount);
        System.out.println("Score done");
        // The model is not a Graph in this branch, so a d-separation test cannot be built
        // from it (the former (Graph) cast always failed with a ClassCastException).
        // NOTE(review): relies on this runner's getIndependenceTest() - confirm it yields
        // the intended test for data input.
        IndependenceTest test = getIndependenceTest();
        gfci = new TsGFci(test, gesScore);
    } else if (model instanceof ICovarianceMatrix) {
        SemBicScore gesScore = new SemBicScore((ICovarianceMatrix) model);
        gesScore.setPenaltyDiscount(penaltyDiscount);
        // See the data set branch: the model is not a Graph here either.
        IndependenceTest test = getIndependenceTest();
        gfci = new TsGFci(test, gesScore);
    } else if (model instanceof DataModelList) {
        DataModelList list = (DataModelList) model;
        for (DataModel dataModel : list) {
            if (!(dataModel instanceof DataSet || dataModel instanceof ICovarianceMatrix)) {
                throw new IllegalArgumentException("Need a combination of all continuous data sets or " + "covariance matrices, or else all discrete data sets, or else a single initialGraph.");
            }
        }
        if (list.size() != 1) {
            throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or initialGraph " + "as input. For multiple data sets as input, use IMaGES.");
        }
        if (!allContinuous(list)) {
            // Only all-continuous lists are handled by this runner.
            throw new IllegalArgumentException("Data must be either all discrete or all continuous.");
        }
        SemBicScoreImages fgesScore = new SemBicScoreImages(list);
        fgesScore.setPenaltyDiscount(penaltyDiscount);
        // See the data set branch: the model is not a Graph here either.
        IndependenceTest test = getIndependenceTest();
        gfci = new TsGFci(test, fgesScore);
    } else {
        // Fail here instead of continuing with a search object that was never built for
        // this input (it would be null or left over from an earlier run).
        throw new IllegalStateException("No viable input.");
    }

    gfci.setKnowledge(knowledge);
    gfci.setVerbose(true);
    gfci.setFaithfulnessAssumed(params.getBoolean("faithfulnessAssumed", true));
    Graph graph = gfci.search();

    // Lay the result out: prefer the source graph's layout, then knowledge tiers, then a circle.
    if (getSourceGraph() != null) {
        GraphUtils.arrangeBySourceGraph(graph, getSourceGraph());
    } else if (knowledge.isDefaultToKnowledgeLayout()) {
        SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge);
    } else {
        GraphUtils.circleLayout(graph, 200, 200, 150);
    }
    setResultGraph(graph);
}
use of edu.cmu.tetrad.util.Parameters in project tetrad by cmu-phil.
the class MeasurementSimulatorParams method serializableInstance.
/**
 * Builds a minimal example instance of this class, intended for serialization tests.
 *
 * @return a new instance backed by fresh {@code Parameters}, with its history set to a
 * {@code GeneHistory} assembled from the exemplar initializer and update function.
 */
public static MeasurementSimulatorParams serializableInstance() {
    final MeasurementSimulatorParams exemplar = new MeasurementSimulatorParams(new Parameters());
    final GeneHistory history = new GeneHistory(
            BasalInitializer.serializableInstance(),
            BooleanGlassFunction.serializableInstance());
    exemplar.setHistory(history);
    return exemplar;
}
use of edu.cmu.tetrad.util.Parameters in project tetrad by cmu-phil.
the class CcdRunner2 method getIndependenceTest.
/**
 * Builds the independence test for this runner's input.
 * <p>
 * The input is the data model, or the source graph when no data model is set. The kind of
 * test is read from the "indTestType" parameter, defaulting to {@code IndTestType.FISHER_Z}.
 *
 * @return the test produced by {@code IndTestChooser} for the input, parameters and test type.
 */
public IndependenceTest getIndependenceTest() {
    Object input = getDataModel();
    if (input == null) {
        // Fall back to the source graph when there is no data model.
        input = getSourceGraph();
    }

    final Parameters settings = getParams();
    final IndTestType chosenType = (IndTestType) settings.get("indTestType", IndTestType.FISHER_Z);

    final IndTestChooser chooser = new IndTestChooser();
    return chooser.getTest(input, settings, chosenType);
}
use of edu.cmu.tetrad.util.Parameters in project tetrad by cmu-phil.
the class CfciRunner method execute.
// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Executes the algorithm, producing (at least) a result workbench. Must be
 * implemented in the extending class.
 * <p>
 * Runs a {@code Cfci} search over this runner's independence test, using the "knowledge"
 * parameter (a new {@code Knowledge2} when absent) and the "depth" parameter (default -1),
 * then lays out the resulting graph and stores it through {@code setResultGraph}.
 */
public void execute() {
    final IKnowledge background = (IKnowledge) getParams().get("knowledge", new Knowledge2());
    final Parameters settings = getParams();

    final Cfci search = new Cfci(getIndependenceTest());
    search.setKnowledge(background);
    search.setDepth(settings.getInt("depth", -1));
    final Graph result = search.search();

    // Layout preference: source graph first, then knowledge tiers, otherwise a circle.
    if (getSourceGraph() == null) {
        if (background.isDefaultToKnowledgeLayout()) {
            SearchGraphUtils.arrangeByKnowledgeTiers(result, background);
        } else {
            GraphUtils.circleLayout(result, 200, 200, 150);
        }
    } else {
        GraphUtils.arrangeBySourceGraph(result, getSourceGraph());
    }

    setResultGraph(result);
}
use of edu.cmu.tetrad.util.Parameters in project tetrad by cmu-phil.
the class TimeoutComparison method doRun.
/**
 * Performs a single run: executes one algorithm on one simulated data model, saves the
 * estimated graph, and records the requested statistics.
 * <p>
 * If the search throws, the failure is reported on standard output and the run is abandoned
 * without recording statistics. If the simulation provides no true graph for the run, the
 * estimated graph is still saved but no statistics are recorded.
 *
 * @param algorithmSimulationWrappers the algorithm/simulation pairings; the one at
 *                                    {@code run.getAlgSimIndex()} is executed.
 * @param algorithmWrappers           all algorithm wrappers, used to derive the 1-based algorithm
 *                                    index passed to {@code saveGraph}.
 * @param simulationWrappers          all simulation wrappers, used to derive the simulation index.
 * @param statistics                  the statistics to evaluate for this run.
 * @param numGraphTypes               the length of the per-graph-type arrays of estimated and
 *                                    comparison graphs.
 * @param allStats                    output array, written at
 *                                    {@code [graphType][algSimIndex][statIndex][runIndex]}.
 * @param run                         identifies the algorithm/simulation pairing and the run index.
 * @throws IllegalArgumentException if a true graph is available and the configured comparison
 *                                  graph type is not recognized.
 */
private void doRun(List<AlgorithmSimulationWrapper> algorithmSimulationWrappers, List<AlgorithmWrapper> algorithmWrappers, List<SimulationWrapper> simulationWrappers, Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
    System.out.println();
    System.out.println("Run " + (run.getRunIndex() + 1));
    System.out.println();
    AlgorithmSimulationWrapper algorithmSimulationWrapper = algorithmSimulationWrappers.get(run.getAlgSimIndex());
    AlgorithmWrapper algorithmWrapper = algorithmSimulationWrapper.getAlgorithmWrapper();
    SimulationWrapper simulationWrapper = algorithmSimulationWrapper.getSimulationWrapper();
    DataModel data = simulationWrapper.getDataModel(run.getRunIndex());
    Graph trueGraph = simulationWrapper.getTrueGraph(run.getRunIndex());
    System.out.println((run.getAlgSimIndex() + 1) + ". " + algorithmWrapper.getDescription() + " simulationWrapper: " + simulationWrapper.getDescription());
    long start = System.currentTimeMillis();
    Graph out;
    try {
        Algorithm algorithm = algorithmWrapper.getAlgorithm();
        Simulation simulation = simulationWrapper.getSimulation();
        // Hand the simulation's knowledge to the algorithm when both sides support it.
        if (algorithm instanceof HasKnowledge && simulation instanceof HasKnowledge) {
            ((HasKnowledge) algorithm).setKnowledge(((HasKnowledge) simulation).getKnowledge());
        }
        if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
            ExternalAlgorithm external = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
            external.setSimulation(simulationWrapper.getSimulation());
            external.setPath(resultsPath);
            external.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        }
        if (algorithm instanceof MultiDataSetAlgorithm) {
            // Pick up to "randomSelectionSize" of the simulation's data models in shuffled order.
            List<Integer> indices = new ArrayList<>();
            int numDataModels = simulationWrapper.getSimulation().getNumDataModels();
            for (int i = 0; i < numDataModels; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);
            List<DataModel> dataModels = new ArrayList<>();
            int randomSelectionSize = algorithmWrapper.getAlgorithmSpecificParameters().getInt("randomSelectionSize");
            for (int i = 0; i < Math.min(numDataModels, randomSelectionSize); i++) {
                dataModels.add(simulationWrapper.getSimulation().getDataModel(indices.get(i)));
            }
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = ((MultiDataSetAlgorithm) algorithm).search(dataModels, _params);
        } else {
            DataModel dataModel = copyData ? data.copy() : data;
            Parameters _params = algorithmWrapper.getAlgorithmSpecificParameters();
            out = algorithm.search(dataModel, _params);
        }
    } catch (Exception e) {
        // A failed search abandons this run; nothing is saved or scored.
        System.out.println("Could not run " + algorithmWrapper.getDescription());
        e.printStackTrace();
        return;
    }
    int simIndex = simulationWrappers.indexOf(simulationWrapper) + 1;
    int algIndex = algorithmWrappers.indexOf(algorithmWrapper) + 1;
    long stop = System.currentTimeMillis();
    long elapsed = stop - start;
    saveGraph(resultsPath, out, run.getRunIndex(), simIndex, algIndex, algorithmWrapper, elapsed);
    if (trueGraph != null) {
        out = GraphUtils.replaceNodes(out, trueGraph.getNodes());
    }
    if (algorithmWrapper.getAlgorithm() instanceof ExternalAlgorithm) {
        // For external algorithms the elapsed time comes from the algorithm itself
        // rather than from the wall-clock measurement above.
        ExternalAlgorithm extAlg = (ExternalAlgorithm) algorithmWrapper.getAlgorithm();
        extAlg.setSimIndex(simulationWrappers.indexOf(simulationWrapper));
        extAlg.setSimulation(simulationWrapper.getSimulation());
        extAlg.setPath(resultsPath);
        elapsed = extAlg.getElapsedTime(data, simulationWrapper.getSimulationSpecificParameters());
    }
    Graph[] est = new Graph[numGraphTypes];
    Graph comparisonGraph;
    if (trueGraph == null) {
        // No true graph for this run: leave the comparison graph unset so the null checks
        // below skip the statistics, instead of copying a null graph.
        comparisonGraph = null;
    } else if (this.comparisonGraph == ComparisonGraph.true_DAG) {
        comparisonGraph = new EdgeListGraph(trueGraph);
    } else if (this.comparisonGraph == ComparisonGraph.Pattern_of_the_true_DAG) {
        comparisonGraph = SearchGraphUtils.patternForDag(new EdgeListGraph(trueGraph));
    } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) {
        comparisonGraph = new DagToPag(new EdgeListGraph(trueGraph)).convert();
    } else {
        throw new IllegalArgumentException("Unrecognized graph type.");
    }
    est[0] = out;
    graphTypeUsed[0] = true;
    if (data.isMixed()) {
        // Mixed data additionally gets three subgraph variants in slots 1-3.
        est[1] = getSubgraph(out, true, true, data);
        est[2] = getSubgraph(out, true, false, data);
        est[3] = getSubgraph(out, false, false, data);
        graphTypeUsed[1] = true;
        graphTypeUsed[2] = true;
        graphTypeUsed[3] = true;
    }
    Graph[] truth = new Graph[numGraphTypes];
    truth[0] = comparisonGraph;
    if (data.isMixed() && comparisonGraph != null) {
        truth[1] = getSubgraph(comparisonGraph, true, true, data);
        truth[2] = getSubgraph(comparisonGraph, true, false, data);
        truth[3] = getSubgraph(comparisonGraph, false, false, data);
    }
    if (comparisonGraph != null) {
        for (int u = 0; u < numGraphTypes; u++) {
            if (!graphTypeUsed[u]) {
                continue;
            }
            // statIndex advances for every statistic, including skipped parameter columns,
            // so positions in allStats line up with the statistics list.
            int statIndex = -1;
            for (Statistic _stat : statistics.getStatistics()) {
                statIndex++;
                if (_stat instanceof ParameterColumn) {
                    continue;
                }
                double stat;
                if (_stat instanceof ElapsedTime) {
                    // Elapsed time is reported in seconds.
                    stat = elapsed / 1000.0;
                } else {
                    stat = _stat.getValue(truth[u], est[u]);
                }
                allStats[u][run.getAlgSimIndex()][statIndex][run.getRunIndex()] = stat;
            }
        }
    }
}
Aggregations