use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
the class SemEstimatorEditor method reestimate.
/**
 * Re-runs estimation for every estimator in the wrapper's multiple-result list,
 * using the optimizer type, restart count and score type currently configured
 * on the wrapper, then stores the new estimators back on the wrapper and
 * refreshes the SEM IM editor.
 *
 * @throws IllegalArgumentException if the wrapper reports an optimizer type
 *                                  other than "Regression", "EM", "Powell",
 *                                  "Random Search" or "RICF"
 * @throws IllegalStateException    if an existing estimator has neither a data
 *                                  set nor a covariance matrix
 */
private void reestimate() {
    String type = wrapper.getSemOptimizerType();

    // Map the optimizer name chosen on the wrapper to a concrete optimizer.
    SemOptimizer optimizer;
    switch (type) {
        case "Regression":
            optimizer = new SemOptimizerRegression();
            break;
        case "EM":
            optimizer = new SemOptimizerEm();
            break;
        case "Powell":
            optimizer = new SemOptimizerPowell();
            break;
        case "Random Search":
            optimizer = new SemOptimizerScattershot();
            break;
        case "RICF":
            optimizer = new SemOptimizerRicf();
            break;
        default:
            throw new IllegalArgumentException("Unexpected optimizer type: " + type);
    }

    int numRestarts = wrapper.getNumRestarts();
    optimizer.setNumRestarts(numRestarts);

    java.util.List<SemEstimator> previousEstimators = wrapper.getMultipleResultList();
    java.util.List<SemEstimator> reestimated = new ArrayList<>();

    for (SemEstimator previous : previousEstimators) {
        SemPm semPm = previous.getSemPm();
        DataSet dataSet = previous.getDataSet();
        ICovarianceMatrix covMatrix = previous.getCovMatrix();

        // An estimator with no data source at all cannot be re-estimated.
        if (dataSet == null && covMatrix == null) {
            throw new IllegalStateException("Only continuous rectangular" + " data sets and covariance matrices can be processed.");
        }

        // The data set takes precedence over the covariance matrix when both are present.
        SemEstimator fresh = (dataSet != null)
                ? new SemEstimator(dataSet, semPm, optimizer)
                : new SemEstimator(covMatrix, semPm, optimizer);

        fresh.setNumRestarts(numRestarts);
        fresh.setScoreType(wrapper.getScoreType());
        fresh.estimate();
        reestimated.add(fresh);
    }

    // The first re-estimated model becomes the wrapper's primary estimator.
    wrapper.setSemEstimator(reestimated.get(0));
    wrapper.setMultipleResultList(reestimated);
    resetSemImEditor();
}
use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
the class PatternFitEditor method setup.
// ============================ Private Methods =========================//
/**
 * Builds the tabbed view of estimator editors for the current comparison.
 *
 * One tab is created per model, titled with its 1-based index. Bayes models
 * get a BayesEstimatorEditor wrapped in a scroll pane; SEM models get a
 * SemEstimatorEditor added directly. The i-th model is paired with the i-th
 * entry of the comparison's data model list, cast to DataSet.
 *
 * @throws IllegalArgumentException if the comparison reports both Bayes IMs
 *                                  and SEM PMs (only one kind is expected)
 */
private void setup() {
// Tabs are placed along the left edge.
JTabbedPane pane = new JTabbedPane(JTabbedPane.LEFT);
DataModelList data = comparison.getDataModelList();
List<BayesIm> bayesIms = comparison.getBayesIms();
List<SemPm> semPms = comparison.getSemPms();
// The two result lists are treated as mutually exclusive.
if (bayesIms != null && semPms != null) {
throw new IllegalArgumentException("That's weird; both Bayes and SEM estimations were done. Please complain.");
}
if (bayesIms != null) {
for (int i = 0; i < bayesIms.size(); i++) {
BayesEstimatorEditor editor = new BayesEstimatorEditor(bayesIms.get(i), (DataSet) data.get(i));
// NOTE(review): 'panel' and 'box' are populated below but never added to
// 'pane' or to this component within this method — looks like dead code; confirm.
JPanel panel = new JPanel();
JScrollPane scroll = new JScrollPane(editor);
scroll.setPreferredSize(new Dimension(900, 600));
panel.add(Box.createVerticalStrut(10));
Box box = Box.createHorizontalBox();
panel.add(box);
panel.add(Box.createVerticalStrut(10));
Box box1 = Box.createHorizontalBox();
box1.add(new JLabel("Graph Comparison: "));
box1.add(Box.createHorizontalGlue());
// NOTE(review): a new label box is added to this component and the layout is
// replaced on every iteration, and the layout is set only AFTER the add.
// Presumably this was meant to run once, outside the loop — verify.
add(box1);
setLayout(new BorderLayout());
// Tab title is the 1-based model index.
pane.add("" + (i + 1), scroll);
}
}
if (semPms != null) {
for (int i = 0; i < semPms.size(); i++) {
SemEstimatorEditor editor = new SemEstimatorEditor(semPms.get(i), (DataSet) data.get(i));
// Unlike the Bayes branch, the SEM editor is added without a scroll pane.
pane.add("" + (i + 1), editor);
}
}
// The pane is added even when both lists are null, in which case it has no tabs.
add(pane);
}
use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
the class TestDM method test2.
/**
 * Simulates 100000 rows from a SEM over an 8-variable graph in which inputs
 * X0, X1, X4, X5 point directly at outputs X2, X3, X6, X7, runs DMSearch on
 * the simulated data, and asserts that the graph it returns equals the
 * expected structure with two intermediate variables L0 and L1.
 */
@Test
public void test2() {
    // Fixed seed, kept for debugging.
    RandomUtil.getInstance().setSeed(29483818483L);

    // Generating graph: direct input -> output edges, listed in insertion order.
    final String[][] generatingEdges = {
            {"X0", "X2"}, {"X0", "X3"},
            {"X1", "X2"}, {"X1", "X3"},
            {"X0", "X6"}, {"X0", "X7"},
            {"X1", "X6"}, {"X1", "X7"},
            {"X4", "X6"}, {"X4", "X7"},
            {"X5", "X6"}, {"X5", "X7"},
    };
    Graph graph = emptyGraph(8);
    for (String[] edge : generatingEdges) {
        graph.addDirectedEdge(new ContinuousVariable(edge[0]), new ContinuousVariable(edge[1]));
    }

    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(100000, false);

    DMSearch search = new DMSearch();
    search.setInputs(new int[] { 0, 1, 4, 5 });
    search.setOutputs(new int[] { 2, 3, 6, 7 });
    search.setData(data);
    search.setTrueInputs(search.getInputs());
    Graph foundGraph = search.search();

    print("Test Case 2");

    // Expected graph: X0..X7 plus L0 and L1, with every input -> output path
    // routed through one of those two nodes.
    Graph trueGraph = new EdgeListGraph();
    for (int i = 0; i < 8; i++) {
        trueGraph.addNode(new ContinuousVariable("X" + i));
    }
    trueGraph.addNode(new ContinuousVariable("L0"));
    trueGraph.addNode(new ContinuousVariable("L1"));

    final String[][] expectedEdges = {
            {"X0", "L0"}, {"X1", "L0"},
            {"L0", "X2"}, {"L0", "X3"},
            {"X0", "L1"}, {"X1", "L1"}, {"X4", "L1"}, {"X5", "L1"},
            {"L1", "X6"}, {"L1", "X7"},
    };
    for (String[] edge : expectedEdges) {
        trueGraph.addDirectedEdge(new ContinuousVariable(edge[0]), new ContinuousVariable(edge[1]));
    }

    assertTrue(foundGraph.equals(trueGraph));
}
use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
the class TestDM method rtest13.
/**
 * Exploratory (ignored) run of DMSearch on data simulated from a 12-variable
 * graph with inputs X0..X5 and outputs X6..X11. It only prints the resulting
 * DM structure and whether the derived graph contains a directed cycle; the
 * final assertion is a placeholder that always passes.
 */
@Ignore
public void rtest13() {
    // Fixed seed, kept for debugging.
    RandomUtil.getInstance().setSeed(29483818483L);

    Graph graph = emptyGraph(12);

    // Look up X0..X11 once, in index order, before adding any edges.
    Node[] x = new Node[12];
    for (int i = 0; i < x.length; i++) {
        x[i] = graph.getNode("X" + i);
    }

    // {from, to} index pairs in insertion order. X3 -> X9 appears twice,
    // exactly as in the original sequence of calls.
    final int[][] edges = {
            {0, 6},
            {1, 6}, {1, 7}, {1, 8},
            {2, 8},
            {3, 8}, {3, 9}, {3, 10}, {3, 9},
            {4, 10}, {4, 11},
            {5, 11},
    };
    for (int[] edge : edges) {
        graph.addDirectedEdge(x[edge[0]], x[edge[1]]);
    }

    // Seed is reset immediately before the model is built and data simulated.
    RandomUtil.getInstance().setSeed(29483818483L);
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(100000, false);

    DMSearch search = new DMSearch();
    search.setInputs(new int[] { 0, 1, 2, 3, 4, 5 });
    search.setOutputs(new int[] { 6, 7, 8, 9, 10, 11 });
    search.setData(data);
    search.setTrueInputs(search.getInputs());
    search.search();

    print("");
    print("" + search.getDmStructure());
    print("graph.existsDirectedCycle: "
            + search.getDmStructure()
                    .latentStructToEdgeListGraph(search.getDmStructure())
                    .existsDirectedCycle());
    print("Graph structure: " + search);

    // Placeholder: this method verifies nothing beyond running to completion.
    assertTrue(true);
}
use of edu.cmu.tetrad.sem.SemPm in project tetrad by cmu-phil.
the class TestDM method rtest12.
/**
 * Exploratory (ignored) run of DMSearch on data simulated from a 9-variable
 * graph with inputs X0..X5 and outputs X6..X8. It only prints the resulting
 * DM structure and whether the derived graph contains a directed cycle; the
 * final assertion is a placeholder that always passes.
 */
@Ignore
public void rtest12() {
    // Fixed seed, kept for debugging.
    RandomUtil.getInstance().setSeed(29483818483L);

    Graph graph = emptyGraph(9);

    // Look up X0..X8 once, in index order, before adding any edges.
    Node[] x = new Node[9];
    for (int i = 0; i < x.length; i++) {
        x[i] = graph.getNode("X" + i);
    }

    // {from, to} index pairs in insertion order.
    final int[][] edges = {
            {0, 6}, {0, 7}, {0, 8},
            {1, 6}, {1, 7}, {1, 8},
            {2, 6}, {2, 7}, {2, 8},
            {3, 8}, {3, 7},
            {4, 8}, {4, 7},
            {5, 8},
    };
    for (int[] edge : edges) {
        graph.addDirectedEdge(x[edge[0]], x[edge[1]]);
    }

    // Seed is reset immediately before the model is built and data simulated.
    RandomUtil.getInstance().setSeed(29483818483L);
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(100000, false);

    DMSearch search = new DMSearch();
    search.setInputs(new int[] { 0, 1, 2, 3, 4, 5 });
    search.setOutputs(new int[] { 6, 7, 8 });
    search.setData(data);
    search.setTrueInputs(search.getInputs());
    search.search();

    print("");
    print("" + search.getDmStructure());
    print("graph.existsDirectedCycle: "
            + search.getDmStructure()
                    .latentStructToEdgeListGraph(search.getDmStructure())
                    .existsDirectedCycle());
    print("Graph structure: " + search);

    // Placeholder: this method verifies nothing beyond running to completion.
    assertTrue(true);
}
Aggregations