Use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.
The class FtfcRunner, method getVariables.
/**
 * Builds one latent continuous node for each name returned by {@code getVariableNames()}.
 *
 * @return a new mutable list of {@code ContinuousVariable} nodes, each marked {@code NodeType.LATENT},
 *         in the same order as the variable names
 */
public List<Node> getVariables() {
    List<Node> result = new ArrayList<>();

    for (String variableName : getVariableNames()) {
        ContinuousVariable latent = new ContinuousVariable(variableName);
        latent.setNodeType(NodeType.LATENT);
        result.add(latent);
    }

    return result;
}
Use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.
The class GlassoRunner, method execute.
// ===================PUBLIC METHODS OVERRIDING ABSTRACT================//
/**
 * Runs Glasso on the covariance matrix of the current data set and stores an undirected
 * result graph with an edge i -- j whenever the {@code wwi} matrix returned by Glasso has
 * non-zero entries at both (i, j) and (j, i).
 *
 * <p>Glasso settings come from the runner's parameters:
 * <ul>
 *   <li>{@code maxit} defaults to 10000.</li>
 *   <li>{@code ia}, {@code is}, {@code itr} and {@code ipen} default to false.</li>
 *   <li>{@code thr} defaults to 1e-4.</li>
 * </ul>
 * Rho is fixed at 1.0 for all entries via {@code setRhoAllEqual(1.0)}.
 *
 * <p>If the data model is not a {@code DataSet}, nothing is computed and no result graph is set.
 */
public void execute() {
    Object dataModel = getDataModel();
    Parameters params = getParams();

    if (dataModel instanceof DataSet) {
        DataSet dataSet = (DataSet) dataModel;
        DoubleMatrix2D cov = new DenseDoubleMatrix2D(dataSet.getCovarianceMatrix().toArray());

        Glasso glasso = new Glasso(cov);
        // getInt (like the getBoolean/getDouble calls below) avoids a ClassCastException when
        // "maxit" is stored as something other than an Integer; a bare (int) cast of Object requires Integer.
        glasso.setMaxit(params.getInt("maxit", 10000));
        glasso.setIa(params.getBoolean("ia", false));
        glasso.setIs(params.getBoolean("is", false));
        glasso.setItr(params.getBoolean("itr", false));
        glasso.setIpen(params.getBoolean("ipen", false));
        glasso.setThr(params.getDouble("thr", 1e-4));
        glasso.setRhoAllEqual(1.0);

        Glasso.Result result = glasso.search();
        // NOTE(review): wwi is presumably the estimated inverse covariance matrix -- confirm against Glasso.
        TetradMatrix wwi = new TetradMatrix(result.getWwi().toArray());

        List<Node> variables = dataSet.getVariables();
        Graph resultGraph = new EdgeListGraph(variables);
        int numVars = variables.size();

        for (int i = 0; i < numVars; i++) {
            for (int j = i + 1; j < numVars; j++) {
                // Check both mirror entries. Previously (i, j) was tested twice and (j, i)
                // was never consulted.
                if (wwi.get(i, j) != 0.0 && wwi.get(j, i) != 0.0) {
                    resultGraph.addUndirectedEdge(variables.get(i), variables.get(j));
                }
            }
        }

        setResultGraph(resultGraph);
    }
}
Use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.
The class SemEstimatorWrapper, method serializableInstance.
// public SemEstimatorWrapper(DataWrapper dataWrapper,
// SemPmWrapper semPmWrapper,
// SemImWrapper semImWrapper,
// Parameters params) {
// if (dataWrapper == null) {
// throw new NullPointerException();
// }
//
// if (semPmWrapper == null) {
// throw new NullPointerException();
// }
//
// if (semImWrapper == null) {
// throw new NullPointerException();
// }
//
// DataSet dataSet =
// (DataSet) dataWrapper.getSelectedDataModel();
// SemPm semPm = semPmWrapper.getSemPm();
// SemIm semIm = semImWrapper.getSemIm();
//
// this.semEstimator = new SemEstimator(dataSet, semPm, getOptimizer());
// if (!degreesOfFreedomCheck(semPm)) return;
// this.semEstimator.setTrueSemIm(semIm);
// this.semEstimator.setNumRestarts(getParams().getInt("numRestarts", 1));
// this.semEstimator.estimate();
//
// this.params = params;
//
// log();
// }
/**
 * Generates a simple exemplar of this class to test serialization.
 *
 * <p>The exemplar pairs two objects:
 * <ul>
 *   <li>a single-variable ("X") {@code ColtDataSet} constructed with 10 rows and filled with
 *       values from {@code RandomUtil};</li>
 *   <li>a {@code SemPm} built over a one-node {@code Dag} containing that same variable.</li>
 * </ul>
 * It uses default {@code Parameters}.
 *
 * @see TetradSerializableUtils
 */
public static SemEstimatorWrapper serializableInstance() {
    ContinuousVariable x = new ContinuousVariable("X");
    List<Node> nodes = new LinkedList<>();
    nodes.add(x);

    DataSet data = new ColtDataSet(10, nodes);
    int rowCount = data.getNumRows();
    int columnCount = data.getNumColumns();

    // Fill row by row so random draws are consumed in row-major order.
    for (int row = 0; row < rowCount; row++) {
        for (int column = 0; column < columnCount; column++) {
            data.setDouble(row, column, RandomUtil.getInstance().nextDouble());
        }
    }

    Dag graph = new Dag();
    graph.addNode(x);
    SemPm semPm = new SemPm(graph);

    return new SemEstimatorWrapper(data, semPm, new Parameters());
}
Use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.
The class BayesEstimatorWrapper, method estimate.
/**
 * Estimates {@code this.bayesIm} from {@code dataSet} using {@code MlBayesEstimator}.
 *
 * @param dataSet the data to estimate from; it must contain no missing values
 * @param bayesPm the parametric model; its DAG must contain no latent nodes
 * @throws IllegalArgumentException if the PM's DAG has a latent node or the data contain missing values
 * @throws RuntimeException if the estimator throws {@code ArrayIndexOutOfBoundsException}; this is
 *         reported as a mismatch between the PM's value assignments and the data set, with the
 *         original exception attached as the cause
 */
private void estimate(DataSet dataSet, BayesPm bayesPm) {
    Graph graph = bayesPm.getDag();

    // Latent nodes are rejected up front; this estimation path does not support them.
    for (Node node : graph.getNodes()) {
        if (node.getNodeType() == NodeType.LATENT) {
            throw new IllegalArgumentException("Estimation of Bayes IM's " + "with latents is not supported.");
        }
    }

    if (DataUtils.containsMissingValue(dataSet)) {
        throw new IllegalArgumentException("Please remove or impute missing values.");
    }

    try {
        MlBayesEstimator estimator = new MlBayesEstimator();
        this.bayesIm = estimator.estimate(bayesPm, dataSet);
    } catch (ArrayIndexOutOfBoundsException e) {
        // Chain the original exception so the failing index and stack trace reach the caller,
        // instead of printing it to stderr and discarding it.
        throw new RuntimeException("Value assignments between Bayes PM " + "and discrete data set do not match.", e);
    }
}
Use of edu.cmu.tetrad.graph.Node in project tetrad by cmu-phil.
The class IdentifiabilityWrapper, method setup.
// ===============================PRIVATE METHODS======================//
/**
 * Stores {@code params} and builds the {@code Identifiability} updater for {@code bayesIm}.
 *
 * <p>The "evidence" parameter is passed to the updater only when it is present and compatible
 * with {@code bayesIm}. If a "variable" parameter is set, this method also logs the evidence
 * and, for each category, that variable's prior and updated marginals through
 * {@code TetradLogger}. The logger's per-class configuration is always reset before returning,
 * even if an exception is thrown.
 *
 * @param bayesIm the instantiated Bayes model to update
 * @param params  parameters; reads "evidence" (an {@code Evidence}, may be absent) and "variable" (a {@code Node})
 */
private void setup(BayesIm bayesIm, Parameters params) {
    TetradLogger.getInstance().setConfigForClass(this.getClass());
    try {
        this.params = params;

        // Read the evidence once. It may be absent, or it may not fit this IM.
        Evidence evidence = (Evidence) params.get("evidence", null);
        boolean evidenceUsable = evidence != null && !evidence.isIncompatibleWith(bayesIm);

        if (evidenceUsable) {
            bayesUpdater = new Identifiability(bayesIm, evidence);
        } else {
            bayesUpdater = new Identifiability(bayesIm);
        }

        Node node = (Node) getParams().get("variable", null);
        if (node != null) {
            // Log the evidence the updater actually received. Without usable evidence, log a
            // tautology over bayesIm instead of dereferencing null (which previously threw an NPE).
            // NOTE(review): assumes the evidence-less Identifiability constructor behaves like
            // tautological evidence -- confirm.
            Evidence loggedEvidence = evidenceUsable ? evidence : Evidence.tautology(bayesIm);

            NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
            TetradLogger.getInstance().log("info", "\nIdentifiability");

            String nodeName = node.getName();
            int nodeIndex = bayesIm.getNodeIndex(bayesIm.getNode(nodeName));
            double[] priors = getBayesUpdater().calculatePriorMarginals(nodeIndex);
            double[] marginals = getBayesUpdater().calculateUpdatedMarginals(nodeIndex);

            TetradLogger.getInstance().log("details", "\nVariable = " + nodeName);
            TetradLogger.getInstance().log("details", "\nEvidence:");

            Proposition proposition = loggedEvidence.getProposition();
            List<Node> propositionVariables = proposition.getVariableSource().getVariables();
            for (int i = 0; i < proposition.getNumVariables(); i++) {
                // -1 means the proposition does not pin variable i to a single category.
                int category = proposition.getSingleCategory(i);
                if (category != -1) {
                    TetradLogger.getInstance().log("details", "\t" + propositionVariables.get(i) + " = " + category);
                }
            }

            TetradLogger.getInstance().log("details", "\nCat.\tPrior\tMarginal");
            for (int i = 0; i < priors.length; i++) {
                TetradLogger.getInstance().log("details", category(loggedEvidence, nodeName, i) + "\t" + nf.format(priors[i]) + "\t" + nf.format(marginals[i]));
            }
        }
    } finally {
        TetradLogger.getInstance().reset();
    }
}
Aggregations