Usage of edu.cmu.tetrad.util.Parameters in the tetrad project (cmu-phil).
Class FgesRunner, method getParamSettings:
/**
 * Returns the settings map after recording the penalty discount in it.
 * <p>
 * The "penaltyDiscount" parameter (default 4) is rendered with one decimal digit
 * and stored under the key "Penalty Discount".
 *
 * @return the {@code paramSettings} map of this runner.
 */
@Override
public Map<String, String> getParamSettings() {
    // The superclass result is not used directly; presumably it fills paramSettings
    // with the common entries -- TODO confirm.
    super.getParamSettings();

    final Parameters runnerParams = getParams();
    final DecimalFormat oneDecimal = new DecimalFormat("0.0");
    final String penaltyText = oneDecimal.format(runnerParams.getDouble("penaltyDiscount", 4));

    paramSettings.put("Penalty Discount", penaltyText);
    return paramSettings;
}
Usage of edu.cmu.tetrad.util.Parameters in the tetrad project (cmu-phil).
Class ForbiddenGraphModel, method createKnowledge:
/**
 * Rebuilds the current knowledge from {@code resultGraph}.
 * <p>
 * The knowledge is cleared. Every input variable whose name does not start with "E_"
 * is re-registered, both in {@code getVarNames()} and in the knowledge. Then, for each
 * directed edge between two nodes whose names do not start with "E_",
 * {@code setForbidden(node2, node1)} is recorded. Presumably this forbids the reverse of
 * the edge's orientation -- confirm against IKnowledge.setForbidden.
 *
 * @param params not used by this method.
 * @throws NullPointerException if {@code resultGraph} is null. This is thrown only after
 *         the knowledge has already been cleared and the variables re-added.
 */
private void createKnowledge(Parameters params) {
    final IKnowledge knowledge = getKnowledge();
    if (knowledge == null) {
        return;
    }

    knowledge.clear();

    // Names with the "E_" prefix are skipped (presumably error terms -- confirm).
    final List<String> names = getVarNames();
    for (String variable : getKnowledgeBoxInput().getVariableNames()) {
        if (variable.startsWith("E_")) {
            continue;
        }
        names.add(variable);
        knowledge.addVariable(variable);
    }

    if (resultGraph == null) {
        throw new NullPointerException("I couldn't find a parent graph.");
    }

    final List<Node> graphNodes = resultGraph.getNodes();
    final int count = graphNodes.size();

    // Visit each unordered pair of nodes once.
    for (int a = 0; a < count; a++) {
        final Node first = graphNodes.get(a);
        if (first.getName().startsWith("E_")) {
            continue;
        }

        for (int b = a + 1; b < count; b++) {
            final Node second = graphNodes.get(b);
            if (second.getName().startsWith("E_")) {
                continue;
            }

            final Edge edge = resultGraph.getEdge(first, second);
            if (edge == null || !edge.isDirected()) {
                continue;
            }

            knowledge.setForbidden(edge.getNode2().getName(), edge.getNode1().getName());
        }
    }
}
Usage of edu.cmu.tetrad.util.Parameters in the tetrad project (cmu-phil).
Class GFciRunner, method getIndependenceTest:
/**
 * Returns the independence test selected by the "indTestType" parameter.
 * <p>
 * The parameter defaults to {@code IndTestType.FISHER_Z}. The test is built over the
 * data model, or over the source graph when no data model is set.
 *
 * @return the test produced by {@code IndTestChooser} for the current input.
 */
public IndependenceTest getIndependenceTest() {
    Object dataModel = getDataModel();
    if (dataModel == null) {
        dataModel = getSourceGraph();
    }

    // The former `instanceof Parameters` check had two identical branches, so a single
    // lookup is equivalent (a null params object still fails with an NPE here).
    Parameters params = getParams();
    IndTestType testType = (IndTestType) params.get("indTestType", IndTestType.FISHER_Z);

    return new IndTestChooser().getTest(dataModel, params, testType);
}
Usage of edu.cmu.tetrad.util.Parameters in the tetrad project (cmu-phil).
Class GFciRunner, method execute:
// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//

/**
 * Runs GFCI on the current input and stores the resulting graph.
 * <p>
 * The input is the data model if one is set, otherwise the source graph. The search
 * instance is built by {@link #createGFci}. The result is then laid out in one of three
 * ways, in this order of preference:
 * <ul>
 *   <li>by the source graph, if there is one;</li>
 *   <li>by knowledge tiers, if the knowledge asks for that layout;</li>
 *   <li>otherwise in a circle.</li>
 * </ul>
 * Finally the graph is passed to {@code setResultGraph}.
 *
 * @throws RuntimeException if there is neither a data model nor a source graph.
 * @throws IllegalStateException if a data set is not continuous.
 * @throws IllegalArgumentException if the input type is not supported, or if a data model
 *         list is not a single continuous data set or covariance matrix.
 */
public void execute() {
    Object model = getDataModel();
    if (model == null) {
        model = getSourceGraph();
    }

    if (model == null) {
        throw new RuntimeException("Data source is unspecified. You may need to double click all your data boxes, \n" + "then click Save, and then right click on them and select Propagate Downstream. \n" + "The issue is that we use a seed to simulate from IM's, so your data is not saved to \n" + "file when you save the session. It can, however, be recreated from the saved seed.");
    }

    Parameters params = getParams();
    IKnowledge knowledge = (IKnowledge) params.get("knowledge", new Knowledge2());

    gfci = createGFci(model, params, knowledge);
    gfci.setVerbose(true);
    gfci.setFaithfulnessAssumed(params.getBoolean("faithfulnessAssumed", true));

    Graph graph = gfci.search();

    if (getSourceGraph() != null) {
        GraphUtils.arrangeBySourceGraph(graph, getSourceGraph());
    } else if (knowledge.isDefaultToKnowledgeLayout()) {
        SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge);
    } else {
        GraphUtils.circleLayout(graph, 200, 200, 150);
    }

    setResultGraph(graph);
}

/**
 * Builds the GFci search that matches the type of {@code model}.
 * <ul>
 *   <li>Graph: d-separation test plus {@code GraphScore}.</li>
 *   <li>Continuous DataSet or ICovarianceMatrix: Fisher Z test (alpha 0.001) plus
 *       {@code SemBicScore}.</li>
 *   <li>DataModelList with exactly one continuous element: {@code IndTestScore} over
 *       {@code SemBicScoreImages}.</li>
 * </ul>
 * Score-based variants use the "penaltyDiscount" parameter (default 4).
 *
 * @param model the non-null search input.
 * @param params the runner parameters.
 * @param knowledge the background knowledge. It is applied only to graph input.
 * @return a configured, not-yet-run GFci instance.
 * @throws IllegalStateException if a data set is not continuous.
 * @throws IllegalArgumentException if the input is unsupported or an unusable list.
 */
private GFci createGFci(Object model, Parameters params, IKnowledge knowledge) {
    double penaltyDiscount = params.getDouble("penaltyDiscount", 4);

    if (model instanceof Graph) {
        IndependenceTest test = new IndTestDSep((Graph) model);
        GraphScore score = new GraphScore((Graph) model);
        GFci search = new GFci(test, score);
        // NOTE(review): knowledge reaches the search only for graph input; for data input it
        // is used only for layout. Confirm whether that is intended.
        search.setKnowledge(knowledge);
        return search;
    }

    if (model instanceof DataSet) {
        DataSet dataSet = (DataSet) model;
        if (!dataSet.isContinuous()) {
            // Discrete data is not supported here: the earlier BDeu branch was disabled,
            // so anything non-continuous is rejected.
            throw new IllegalStateException("Data set must either be continuous or discrete.");
        }
        IndependenceTest test = new IndTestFisherZ(new CovarianceMatrixOnTheFly(dataSet), 0.001);
        SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        score.setPenaltyDiscount(penaltyDiscount);
        return new GFci(test, score);
    }

    if (model instanceof ICovarianceMatrix) {
        ICovarianceMatrix covariances = (ICovarianceMatrix) model;
        IndependenceTest test = new IndTestFisherZ(covariances, 0.001);
        SemBicScore score = new SemBicScore(covariances);
        score.setPenaltyDiscount(penaltyDiscount);
        return new GFci(test, score);
    }

    if (model instanceof DataModelList) {
        DataModelList list = (DataModelList) model;
        for (DataModel dataModel : list) {
            if (!(dataModel instanceof DataSet || dataModel instanceof ICovarianceMatrix)) {
                throw new IllegalArgumentException("Need a combination of all continuous data sets or " + "covariance matrices, or else all discrete data sets, or else a single initialGraph.");
            }
        }

        if (list.size() != 1) {
            throw new IllegalArgumentException("FGES takes exactly one data set, covariance matrix, or initialGraph " + "as input. For multiple data sets as input, use IMaGES.");
        }

        if (!allContinuous(list)) {
            // Discrete lists (BDeu images score) are not currently supported.
            throw new IllegalArgumentException("Data must be either all discrete or all continuous.");
        }

        SemBicScoreImages score = new SemBicScoreImages(list);
        IndTestScore test = new IndTestScore(score);
        score.setPenaltyDiscount(penaltyDiscount);
        return new GFci(test, score);
    }

    // Previously this only printed a message and continued, which either threw an NPE or
    // silently re-ran a stale search left over from an earlier call.
    throw new IllegalArgumentException("No viable input: " + model.getClass().getName());
}
Usage of edu.cmu.tetrad.util.Parameters in the tetrad project (cmu-phil).
Class TsGFciRunner, method getIndependenceTest:
/**
 * Returns the independence test selected by the "indTestType" parameter.
 * <p>
 * The parameter defaults to {@code IndTestType.FISHER_Z}. The test is built over the
 * data model, or over the source graph when no data model is set.
 *
 * @return the test produced by {@code IndTestChooser} for the current input.
 */
public IndependenceTest getIndependenceTest() {
    Object dataModel = getDataModel();
    if (dataModel == null) {
        dataModel = getSourceGraph();
    }

    // The former `instanceof Parameters` check had two identical branches, so a single
    // lookup is equivalent (a null params object still fails with an NPE here).
    Parameters params = getParams();
    IndTestType testType = (IndTestType) params.get("indTestType", IndTestType.FISHER_Z);

    return new IndTestChooser().getTest(dataModel, params, testType);
}
Aggregations