Use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.
In class IndTestRegression, method getVariableNames.
/**
 * Collects the name of every variable returned by {@link #getVariables()}, preserving
 * that method's ordering.
 *
 * @return a newly allocated, mutable list holding one name per variable.
 */
public List<String> getVariableNames() {
    List<String> names = new ArrayList<>();
    getVariables().forEach(node -> names.add(node.getName()));
    return names;
}
Use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.
In class FgesMbRunner, method execute.
// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Runs an FGES-MB search around the variable named by the "targetName" parameter and
 * stores the resulting graph as this runner's result graph.
 * <p>
 * The input is the data model or, if there is none, the source graph. It may be a
 * {@link Graph}, a continuous or discrete {@link DataSet}, an {@link ICovarianceMatrix},
 * or a {@link DataModelList} whose members are all continuous or all discrete.
 * Parameters read: "knowledge", "targetName", "penaltyDiscount" (default 4),
 * "samplePrior" and "structurePrior" (default 1, discrete data only),
 * "numPatternsToSave" (default 1) and "depth" (default -1, used as max indegree).
 *
 * @throws RuntimeException         if neither a data model nor a source graph is available.
 * @throws IllegalArgumentException if the input type is unsupported, a data model list
 *                                  mixes data types, or the target variable is not found.
 * @throws IllegalStateException    if a data set is neither continuous nor discrete.
 */
public void execute() {
    Parameters params = getParams();
    IKnowledge knowledge = (IKnowledge) params.get("knowledge", new Knowledge2());
    String targetName = params.getString("targetName", null);

    Object model = getDataModel();
    if (model == null && getSourceGraph() != null) {
        model = getSourceGraph();
    }
    if (model == null) {
        throw new RuntimeException("Data source is unspecified. You may need to double click all your data boxes, \n" + "then click Save, and then right click on them and select Propagate Downstream. \n" + "The issue is that we use a seed to simulate from IM's, so your data is not saved to \n" + "file when you save the session. It can, however, be recreated from the saved seed.");
    }

    // Each branch builds the score that matches the input type, resolves the target
    // against that score's variables, and creates the search over it.
    Node target;
    if (model instanceof Graph) {
        GraphScore score = new GraphScore((Graph) model);
        target = score.getVariable(targetName);
        fges = new FgesMb(score);
    } else if (model instanceof DataSet) {
        DataSet dataSet = (DataSet) model;
        if (dataSet.isContinuous()) {
            SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
            target = score.getVariable(targetName);
            score.setPenaltyDiscount(params.getDouble("penaltyDiscount", 4));
            fges = new FgesMb(score);
        } else if (dataSet.isDiscrete()) {
            // Read the priors the same way the multi-data-set branch below does.
            BDeuScore score = new BDeuScore(dataSet);
            score.setSamplePrior(params.getDouble("samplePrior", 1));
            score.setStructurePrior(params.getDouble("structurePrior", 1));
            target = score.getVariable(targetName);
            fges = new FgesMb(score);
        } else {
            throw new IllegalStateException("Data set must either be continuous or discrete.");
        }
    } else if (model instanceof ICovarianceMatrix) {
        SemBicScore score = new SemBicScore((ICovarianceMatrix) model);
        score.setPenaltyDiscount(params.getDouble("penaltyDiscount", 4));
        target = score.getVariable(targetName);
        fges = new FgesMb(score);
    } else if (model instanceof DataModelList) {
        DataModelList list = (DataModelList) model;
        for (DataModel dataModel : list) {
            if (!(dataModel instanceof DataSet || dataModel instanceof ICovarianceMatrix)) {
                throw new IllegalArgumentException("Need a combination of all continuous data sets or " + "covariance matrices, or else all discrete data sets, or else a single initialGraph.");
            }
        }
        if (allContinuous(list)) {
            SemBicScoreImages score = new SemBicScoreImages(list);
            score.setPenaltyDiscount(params.getDouble("penaltyDiscount", 4));
            target = score.getVariable(targetName);
            fges = new FgesMb(score);
        } else if (allDiscrete(list)) {
            BdeuScoreImages score = new BdeuScoreImages(list);
            score.setSamplePrior(params.getDouble("samplePrior", 1));
            score.setStructurePrior(params.getDouble("structurePrior", 1));
            target = score.getVariable(targetName);
            fges = new FgesMb(score);
        } else {
            throw new IllegalArgumentException("Data must be either all discrete or all continuous.");
        }
    } else {
        // Previously this only printed a message and went on to use a null (or a
        // previous run's) search object.
        throw new IllegalArgumentException("No viable input: unsupported data model type "
                + model.getClass().getName() + ".");
    }

    if (target == null) {
        throw new IllegalArgumentException("Target variable '" + targetName
                + "' was not found among the input variables.");
    }

    fges.setKnowledge(knowledge);
    fges.setNumPatternsToStore(params.getInt("numPatternsToSave", 1));
    fges.setVerbose(true);
    fges.setMaxIndegree(params.getInt("depth", -1));
    Graph graph = fges.search(target);

    if (getSourceGraph() != null) {
        GraphUtils.arrangeBySourceGraph(graph, getSourceGraph());
    } else if (knowledge.isDefaultToKnowledgeLayout()) {
        SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge);
    } else {
        GraphUtils.circleLayout(graph, 200, 200, 150);
    }

    this.topGraphs = new ArrayList<>(fges.getTopGraphs());
    if (topGraphs.isEmpty()) {
        // Use the graph just found. getResultGraph() would still return the previous
        // result here, because setResultGraph(graph) has not been called yet.
        topGraphs.add(new ScoredGraph(graph, Double.NaN));
    }
    setIndex(topGraphs.size() - 1);
    setResultGraph(graph);
}
Use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.
In class FofcRunner, method execute.
// ===================PUBLIC METHODS OVERRIDING ABSTRACT================//
/**
 * Runs Find One Factor Clusters (FOFC) on the current data and stores the resulting
 * graph and its clusters as this runner's result.
 * <p>
 * The "tetradTestType" parameter must be TETRAD_DELTA or TETRAD_WISHART. Any other
 * value is replaced by TETRAD_DELTA and written back to the parameters. Parameters
 * also read: "fofcAlgorithm" (default GAP) and "alpha" (default 0.001). If a
 * SemIm or true graph is available, the discovered clusters are renamed to match it.
 *
 * @throws IllegalArgumentException if the data are neither a DataSet nor a CovarianceMatrix.
 * @throws RuntimeException         if the search graph cannot be copied via serialization.
 */
public void execute() {
    Object source = getData();

    TestType tetradTestType = (TestType) getParams().get("tetradTestType", TestType.TETRAD_WISHART);
    if (tetradTestType == null || (!(tetradTestType == TestType.TETRAD_DELTA || tetradTestType == TestType.TETRAD_WISHART))) {
        tetradTestType = TestType.TETRAD_DELTA;
        getParams().set("tetradTestType", tetradTestType);
    }
    FindOneFactorClusters.Algorithm algorithm = (FindOneFactorClusters.Algorithm) getParams().get("fofcAlgorithm", FindOneFactorClusters.Algorithm.GAP);
    double alpha = getParams().getDouble("alpha", 0.001);

    FindOneFactorClusters fofc;
    if (source instanceof DataSet) {
        fofc = new FindOneFactorClusters((DataSet) source, tetradTestType, algorithm, alpha);
    } else if (source instanceof CovarianceMatrix) {
        fofc = new FindOneFactorClusters((CovarianceMatrix) source, tetradTestType, algorithm, alpha);
    } else {
        throw new IllegalArgumentException("Unrecognized data type.");
    }
    Graph searchGraph = fofc.search();

    if (semIm != null && source instanceof DataSet) {
        List<List<Node>> partition = MimUtils.convertToClusters2(searchGraph);
        List<String> variableNames = ReidentifyVariables.reidentifyVariables2(partition, trueGraph, (DataSet) source);
        rename(searchGraph, partition, variableNames);
    } else if (trueGraph != null) {
        // Also used when a SemIm is present but the source is a covariance matrix.
        // reidentifyVariables2 needs a DataSet, and the old unconditional cast threw
        // ClassCastException in that case.
        List<List<Node>> partition = MimUtils.convertToClusters2(searchGraph);
        List<String> variableNames = ReidentifyVariables.reidentifyVariables1(partition, trueGraph);
        rename(searchGraph, partition, variableNames);
    }
    System.out.println("Search Graph " + searchGraph);

    Graph graph;
    try {
        // MarshalledObject serializes the graph and get() deserializes it, giving an
        // independent copy of searchGraph.
        graph = new MarshalledObject<>(searchGraph).get();
    } catch (java.io.IOException | ClassNotFoundException e) {
        throw new RuntimeException(e);
    }
    GraphUtils.circleLayout(graph, 200, 200, 150);
    GraphUtils.fruchtermanReingoldLayout(graph);
    setResultGraph(graph);
    setClusters(MimUtils.convertToClusters(graph, getData().getVariables()));
}
Use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.
In class FofcRunner, method getVariables.
/**
 * Builds one latent continuous variable for each name from {@code getVariableNames()}.
 * Fresh node objects are created on every call.
 *
 * @return a new, mutable list of latent {@code ContinuousVariable} nodes, in name order.
 */
public List<Node> getVariables() {
    final List<Node> nodes = new ArrayList<>();
    for (final String varName : getVariableNames()) {
        final Node latentNode = new ContinuousVariable(varName);
        latentNode.setNodeType(NodeType.LATENT);
        nodes.add(latentNode);
    }
    return nodes;
}
Use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.
In class FtfcRunner, method execute.
// ===================PUBLIC METHODS OVERRIDING ABSTRACT================//
/**
 * Runs Find Two Factor Clusters (FTFC) on the current data and stores the resulting
 * graph and its clusters as this runner's result.
 * <p>
 * Parameters read: "tetradTestType", "ftfcAlgorithm" (default GAP) and "alpha"
 * (default 0.001). If a SemIm or true graph is available, the discovered clusters
 * are renamed to match it.
 *
 * @throws IllegalArgumentException if the data are neither a DataSet nor a CovarianceMatrix.
 * @throws RuntimeException         if the search graph cannot be copied via serialization.
 */
public void execute() {
    Object source = getData();

    // NOTE(review): tetradTestType is validated and written back to the parameters,
    // but it is not passed to FindTwoFactorClusters below. Confirm this is intended.
    TestType tetradTestType = (TestType) getParams().get("tetradTestType", TestType.TETRAD_WISHART);
    if (tetradTestType == null || (!(tetradTestType == TestType.TETRAD_DELTA || tetradTestType == TestType.TETRAD_WISHART))) {
        tetradTestType = TestType.TETRAD_DELTA;
        getParams().set("tetradTestType", tetradTestType);
    }
    FindTwoFactorClusters.Algorithm algorithm = (FindTwoFactorClusters.Algorithm) getParams().get("ftfcAlgorithm", FindTwoFactorClusters.Algorithm.GAP);
    double alpha = getParams().getDouble("alpha", 0.001);

    FindTwoFactorClusters ftfc;
    if (source instanceof DataSet) {
        ftfc = new FindTwoFactorClusters((DataSet) source, algorithm, alpha);
    } else if (source instanceof CovarianceMatrix) {
        ftfc = new FindTwoFactorClusters((CovarianceMatrix) source, algorithm, alpha);
    } else {
        throw new IllegalArgumentException("Unrecognized data type.");
    }
    ftfc.setVerbose(true);
    Graph searchGraph = ftfc.search();

    if (semIm != null && source instanceof DataSet) {
        List<List<Node>> partition = MimUtils.convertToClusters2(searchGraph);
        List<String> variableNames = ReidentifyVariables.reidentifyVariables2(partition, trueGraph, (DataSet) source);
        rename(searchGraph, partition, variableNames);
    } else if (trueGraph != null) {
        // Also used when a SemIm is present but the source is a covariance matrix.
        // reidentifyVariables2 needs a DataSet, and the old unconditional cast threw
        // ClassCastException in that case.
        List<List<Node>> partition = MimUtils.convertToClusters2(searchGraph);
        List<String> variableNames = ReidentifyVariables.reidentifyVariables1(partition, trueGraph);
        rename(searchGraph, partition, variableNames);
    }
    System.out.println("Search Graph " + searchGraph);

    Graph graph;
    try {
        // MarshalledObject serializes the graph and get() deserializes it, giving an
        // independent copy of searchGraph.
        graph = new MarshalledObject<>(searchGraph).get();
    } catch (java.io.IOException | ClassNotFoundException e) {
        throw new RuntimeException(e);
    }
    GraphUtils.circleLayout(graph, 200, 200, 150);
    GraphUtils.fruchtermanReingoldLayout(graph);
    setResultGraph(graph);
    setClusters(MimUtils.convertToClusters(graph, getData().getVariables()));
}
Aggregations