Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.
From class TestSemXml, method testRoundtrip2.
/**
 * Round-trips a sample SemIm through its XML rendering and back.
 *
 * <p>Checks that parsing yields a SemIm and that it has the same number of
 * variable nodes as the original. Previously the parsed result was discarded,
 * so the test could only fail by throwing.
 */
public void testRoundtrip2() {
    SemIm semIm = sampleSemIm1();
    Element element = SemXmlRenderer.getElement(semIm);
    SemXmlParser parser = new SemXmlParser();
    SemIm semIm2 = parser.getSemIm(element);

    // Plain AssertionError so this works regardless of the JUnit style of the enclosing class.
    if (semIm2 == null) {
        throw new AssertionError("Parsed SemIm is null");
    }
    int expected = semIm.getVariableNodes().size();
    int actual = semIm2.getVariableNodes().size();
    if (expected != actual) {
        throw new AssertionError("Variable count changed in XML round trip: expected "
                + expected + " but was " + actual);
    }
}
Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.
From class TestMatrixUtils, method testImpiedCovar.
/**
 * Builds a random DAG over ten continuous variables and a SemIm on it.
 *
 * <p>Checks that the covariance matrix implied by the edge coefficients and
 * error covariance is positive definite. Also checks that the correlation
 * matrix derived from it is positive definite.
 */
@Test
public void testImpiedCovar() {
    final int numVars = 10;
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= numVars; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }

    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 10, 30, 15, 15, false));
    SemIm semIm = new SemIm(new SemPm(dag));

    TetradMatrix errorCovariance = semIm.getErrCovar();
    TetradMatrix edgeCoefficients = semIm.getEdgeCoef();
    TetradMatrix impliedCovariance = MatrixUtils.impliedCovar(edgeCoefficients, errorCovariance);
    assertTrue(MatrixUtils.isPositiveDefinite(impliedCovariance));

    // The conversion is given a fresh TetradMatrix built from the implied covariance
    // (presumably a copy, so the original matrix is not touched).
    TetradMatrix correlation = MatrixUtils.convertCovToCorr(new TetradMatrix(impliedCovariance));
    assertTrue(MatrixUtils.isPositiveDefinite(correlation));
}
Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.
From class TestMimbuild2, method test1.
/**
 * End-to-end check of FOFC clustering followed by Mimbuild2.
 *
 * <p>Steps:
 * <ol>
 *   <li>Simulate 300 rows from a random single-factor measurement model.</li>
 *   <li>Cluster the measured variables with FindOneFactorClusters.</li>
 *   <li>Re-identify the latents and rebuild the latent structure with Mimbuild2.</li>
 *   <li>Compare the result with the true latent structure by structural Hamming distance.</li>
 * </ol>
 *
 * <p>Only the FOFC + Mimbuild2 path ever executed in the previous version. The
 * single-iteration outer loop and the unreachable BPC / MimbuildTrek branches
 * have been removed. The sequence of random draws is unchanged.
 */
@Test
public void test1() {
    // Fixed seed: the expected SHD below is only valid for this exact random draw.
    RandomUtil.getInstance().setSeed(49283494L);

    Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 6, 0, 0, 0);
    Graph mimStructure = structure(mim);

    Parameters params = new Parameters();
    params.set("coefLow", .5);
    params.set("coefHigh", 1.5);
    SemPm pm = new SemPm(mim);
    SemIm im = new SemIm(pm, params);
    DataSet data = im.simulateData(300, false);

    FindOneFactorClusters fofc = new FindOneFactorClusters(data, TestType.TETRAD_WISHART, FindOneFactorClusters.Algorithm.GAP, 0.001);
    // The returned graph is not needed; search() presumably populates the clusters read next.
    fofc.search();
    List<List<Node>> partition = fofc.getClusters();

    List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(0.001);
    mimbuild.setMinClusterSize(3);
    Graph mimbuildStructure = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data));

    int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure);
    assertEquals(7, shd);
}
Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.
From class SemUpdaterEditor, method getUpdatePanel.
/**
 * Builds the evidence-entry panel for the SEM updater.
 *
 * <p>The panel contains, from top to bottom:
 * <ul>
 *   <li>an instruction line;</li>
 *   <li>a header row;</li>
 *   <li>one row per variable in the current {@code SemEvidence}, each with a value text
 *       field and a "Manipulated" check box;</li>
 *   <li>a "Do Update Now" button.</li>
 * </ul>
 *
 * <p>Accepting a value or toggling a check box writes the change into the evidence and
 * redisplays the updated SEM in {@code semImEditor}.
 *
 * <p>A focus traversal policy is installed on this component. It cycles through the
 * value fields in row order.
 *
 * @return the assembled vertical box
 */
private Box getUpdatePanel() {
    final SemEvidence evidence = semUpdater.getEvidence();
    focusTraversalOrder.clear();

    Box b = Box.createVerticalBox();

    Box b0 = Box.createHorizontalBox();
    b0.add(new JLabel("<html>" + "In the list below, specify values for variables you have evidence " + "<br>for. Click the 'Do Update Now' button to view updated model."));
    b0.add(Box.createHorizontalGlue());
    b.add(b0);
    b.add(Box.createVerticalStrut(10));

    Box d = Box.createHorizontalBox();
    d.add(new JLabel("Variable = value"));
    d.add(Box.createHorizontalGlue());
    d.add(new JLabel("Manipulated"));
    b.add(d);

    // Nothing in the loop below changes the evidence's SEM, so look it up once.
    SemIm semIm = evidence.getSemIm();

    for (int i = 0; i < evidence.getNumNodes(); i++) {
        Box c = Box.createHorizontalBox();
        Node node = semIm.getVariableNodes().get(i);
        String name = node.getName();
        JLabel label = new JLabel(name + " = ") {
            private static final long serialVersionUID = 820570350956700782L;

            @Override
            public Dimension getMaximumSize() {
                return getPreferredSize();
            }
        };
        c.add(label);

        double mean = evidence.getProposition().getValue(i);
        final DoubleTextField field = new DoubleTextField(mean, 5, NumberFormatUtil.getInstance().getNumberFormat());

        // An accepted value is written into the evidence and the editor is refreshed.
        // Any IllegalArgumentException while applying it rejects the edit and keeps oldValue.
        // This includes the explicit rejection of NaN for a manipulated variable.
        field.setFilter((value, oldValue) -> {
            try {
                final int nodeIndex = labels.get(field);
                if (Double.isNaN(value) && evidence.isManipulated(nodeIndex)) {
                    throw new IllegalArgumentException();
                }
                evidence.getProposition().setValue(nodeIndex, value);
                SemIm updatedSem = semUpdater.getUpdatedSemIm();
                semImEditor.displaySemIm(updatedSem, semImEditor.getTabSelectionIndex(), semImEditor.getMatrixSelection());
                return value;
            } catch (IllegalArgumentException e) {
                return oldValue;
            }
        });
        labels.put(field, i);
        variablesToTextFields.put(i, field);
        focusTraversalOrder.add(field);
        c.add(field);
        c.add(Box.createHorizontalStrut(2));
        c.add(Box.createHorizontalGlue());

        JCheckBox checkbox = new JCheckBox() {
            private static final long serialVersionUID = -3808843047563493212L;

            @Override
            public Dimension getMaximumSize() {
                return getPreferredSize();
            }
        };
        checkbox.setSelected(evidence.isManipulated(i));
        checkBoxesToVariables.put(checkbox, i);
        variablesToCheckboxes.put(i, checkbox);
        checkbox.addActionListener((e) -> {
            JCheckBox chkbox = (JCheckBox) e.getSource();
            boolean selected = chkbox.isSelected();
            Integer varIndex = checkBoxesToVariables.get(chkbox);

            // If no value has been set for this variable, seed its text field with the
            // variable's mean in the SEM. setValue runs the field's filter above.
            double value = evidence.getProposition().getValue(varIndex);
            if (Double.isNaN(value)) {
                DoubleTextField valueField = variablesToTextFields.get(varIndex);
                SemIm sem = semUpdater.getSemIm();
                Node varNode = sem.getVariableNodes().get(varIndex);
                valueField.setValue(sem.getMean(varNode));
            }

            semUpdater.getEvidence().setManipulated(varIndex, selected);
            SemIm updatedSem = semUpdater.getUpdatedSemIm();
            semImEditor.displaySemIm(updatedSem, semImEditor.getTabSelectionIndex(), semImEditor.getMatrixSelection());
        });
        checkbox.setBackground(Color.WHITE);
        checkbox.setBorder(null);
        c.add(checkbox);
        c.setMaximumSize(new Dimension(1000, 30));
        b.add(c);
    }

    b.add(Box.createVerticalGlue());

    Box b2 = Box.createHorizontalBox();
    b2.add(Box.createHorizontalGlue());
    JButton button = new JButton("Do Update Now");
    button.addActionListener((e) -> {
        SemIm updatedSem = semUpdater.getUpdatedSemIm();
        semImEditor.displaySemIm(updatedSem, semImEditor.getTabSelectionIndex(), semImEditor.getMatrixSelection());
        semUpdater.setEvidence(new SemEvidence(updatedSem));
    });
    b2.add(button);
    b.add(b2);
    b.setBorder(new EmptyBorder(5, 5, 5, 5));

    // Tab / Shift-Tab cycle through the value fields in the order the rows were added.
    setFocusTraversalPolicy(new FocusTraversalPolicy() {
        @Override
        public Component getComponentAfter(Container focusCycleRoot, Component aComponent) {
            int index = focusTraversalOrder.indexOf(aComponent);
            int size = focusTraversalOrder.size();
            if (index != -1) {
                return focusTraversalOrder.get((index + 1) % size);
            } else {
                return getFirstComponent(focusCycleRoot);
            }
        }

        @Override
        public Component getComponentBefore(Container focusCycleRoot, Component aComponent) {
            int index = focusTraversalOrder.indexOf(aComponent);
            int size = focusTraversalOrder.size();
            if (index != -1) {
                // Add size before taking the remainder. Java's % keeps the dividend's sign,
                // so (0 - 1) % size == -1, which made Shift-Tab from the first field throw.
                return focusTraversalOrder.get((index - 1 + size) % size);
            } else {
                return getFirstComponent(focusCycleRoot);
            }
        }

        @Override
        public Component getFirstComponent(Container focusCycleRoot) {
            // With no variables there is no field to focus. The policy contract allows null.
            return focusTraversalOrder.isEmpty() ? null : focusTraversalOrder.getFirst();
        }

        @Override
        public Component getLastComponent(Container focusCycleRoot) {
            return focusTraversalOrder.isEmpty() ? null : focusTraversalOrder.getLast();
        }

        @Override
        public Component getDefaultComponent(Container focusCycleRoot) {
            return getFirstComponent(focusCycleRoot);
        }
    });
    setFocusCycleRoot(true);
    return b;
}
Use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.
From class PurifyScoreBased, method gaussianEM.
/**
 * Runs the Gaussian EM procedure on {@code semdag} for three trials and keeps the
 * best-scoring one.
 *
 * <p>Trial 0 starts from {@code initialSemIm} when it is non-null. Every other trial
 * starts from a fresh {@code SemIm} whose covariance matrix is set to
 * {@code this.covarianceMatrix}.
 *
 * <p>Each trial alternates {@code gaussianExpectation} and {@code gaussianMaximization}.
 * It stops when the score changes by at most 1e-3, or when the maximization step returns
 * {@code -Double.MAX_VALUE}.
 *
 * <p>On return, the Cyy/Cyz/Czz fields hold the values saved for the best trial. If no
 * trial improved on the starting score, they hold the values from before
 * {@code initializeGaussianEM} ran.
 *
 * @param semdag       the graph to fit; {@code setShowErrorTerms(true)} is called on it
 * @param initialSemIm optional starting point for the first trial (may be null)
 * @return the best score found, or {@code -Double.MAX_VALUE} if no trial produced a usable score
 */
private double gaussianEM(SemGraph semdag, SemIm initialSemIm) {
    double bestScore = -Double.MAX_VALUE;
    SemPm semPm = new SemPm(semdag);
    semdag.setShowErrorTerms(true);

    // Snapshot the matrices before initialization. This snapshot is restored if no trial improves.
    copyCurrentMatricesToBest();
    initializeGaussianEM(semdag);

    for (int i = 0; i < 3; i++) {
        System.out.println("--Trial " + i);
        SemIm semIm;
        if (i == 0 && initialSemIm != null) {
            semIm = initialSemIm;
        } else {
            semIm = new SemIm(semPm);
            semIm.setCovMatrix(this.covarianceMatrix);
        }

        // Reset per trial. Otherwise the convergence test compares this trial's first score
        // with the previous trial's final score, and a trial could stop after a single step.
        double score;
        double newScore = -Double.MAX_VALUE;
        do {
            score = newScore;
            gaussianExpectation(semIm);
            newScore = gaussianMaximization(semIm);
            if (newScore == -Double.MAX_VALUE) {
                break;
            }
        } while (Math.abs(score - newScore) > 1.E-3);
        System.out.println(newScore);

        if (newScore > bestScore && !Double.isInfinite(newScore)) {
            bestScore = newScore;
            copyCurrentMatricesToBest();
        }
    }

    copyBestMatricesToCurrent();

    // bestScore only ever takes non-infinite values. The old isInfinite test alone could never
    // fire, so also treat "no trial produced a score" (still -Double.MAX_VALUE) as the failure
    // case. NOTE(review): presumably -Double.MAX_VALUE from gaussianMaximization signals a
    // Heywood case; confirm.
    if (bestScore == -Double.MAX_VALUE || Double.isInfinite(bestScore)) {
        System.out.println("* * Warning: Heywood case in this step");
        return -Double.MAX_VALUE;
    }
    return bestScore;
}

/** Copies Cyy, and Cyz/Czz when allocated, into the best-so-far buffers. */
private void copyCurrentMatricesToBest() {
    for (int p = 0; p < numObserved; p++) {
        for (int q = 0; q < numObserved; q++) {
            this.bestCyy[p][q] = this.Cyy[p][q];
        }
        if (this.Cyz != null) {
            for (int q = 0; q < numLatent; q++) {
                this.bestCyz[p][q] = this.Cyz[p][q];
            }
        }
    }
    if (this.Czz != null) {
        for (int p = 0; p < numLatent; p++) {
            for (int q = 0; q < numLatent; q++) {
                this.bestCzz[p][q] = this.Czz[p][q];
            }
        }
    }
}

/** Copies the best-so-far buffers back into Cyy, and into Cyz/Czz when allocated. */
private void copyBestMatricesToCurrent() {
    for (int p = 0; p < numObserved; p++) {
        for (int q = 0; q < numObserved; q++) {
            this.Cyy[p][q] = this.bestCyy[p][q];
        }
        if (this.Cyz != null) {
            for (int q = 0; q < numLatent; q++) {
                this.Cyz[p][q] = this.bestCyz[p][q];
            }
        }
    }
    if (this.Czz != null) {
        for (int p = 0; p < numLatent; p++) {
            for (int q = 0; q < numLatent; q++) {
                this.Czz[p][q] = this.bestCzz[p][q];
            }
        }
    }
}
Aggregations