Search in sources:

Example 36 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

The class TestSemXml, method testRoundtrip2.

/**
 * Round-trips a sample SEM IM through XML: render it to an element, then parse it back.
 * No assertion is made on the reparsed model; only the absence of exceptions in either
 * direction is exercised.
 */
public void testRoundtrip2() {
    SemIm original = sampleSemIm1();
    Element xml = SemXmlRenderer.getElement(original);
    // NOTE(review): `reparsed` is never compared with `original`; consider asserting equality
    // of the relevant parameters once the comparison API is confirmed.
    SemIm reparsed = new SemXmlParser().getSemIm(xml);
}
Also used : SemXmlParser(edu.cmu.tetrad.sem.SemXmlParser) Element(nu.xom.Element) SemIm(edu.cmu.tetrad.sem.SemIm)

Example 37 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

The class TestMatrixUtils, method testImpiedCovar.

/**
 * Checks that the covariance matrix implied by a random SEM is positive definite, and that
 * the correlation matrix derived from it is positive definite as well.
 */
@Test
public void testImpiedCovar() {
    // Ten continuous variables named X1 .. X10.
    List<Node> variables = new ArrayList<>();
    for (int index = 1; index <= 10; index++) {
        variables.add(new ContinuousVariable("X" + index));
    }

    // Random DAG over those variables, wrapped into a SemPm and then a SemIm.
    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 10, 30, 15, 15, false));
    SemIm semIm = new SemIm(new SemPm(dag));

    // Fetch the error covariance first, then the edge coefficients (same order as before).
    TetradMatrix errCov = semIm.getErrCovar();
    TetradMatrix edgeCoef = semIm.getEdgeCoef();

    TetradMatrix impliedCov = MatrixUtils.impliedCovar(edgeCoef, errCov);
    assertTrue(MatrixUtils.isPositiveDefinite(impliedCov));

    // Convert a copy so the implied covariance itself is not handed to the conversion.
    TetradMatrix impliedCorr = MatrixUtils.convertCovToCorr(new TetradMatrix(impliedCov));
    assertTrue(MatrixUtils.isPositiveDefinite(impliedCorr));
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) Graph(edu.cmu.tetrad.graph.Graph) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) SemPm(edu.cmu.tetrad.sem.SemPm) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) Dag(edu.cmu.tetrad.graph.Dag) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 38 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

The class TestMimbuild2, method test1.

/**
 * Simulates data from a random single-factor measurement model, clusters the indicators,
 * runs Mimbuild2 over the clusters, and asserts the structural Hamming distance between the
 * true latent structure and the recovered one.
 *
 * The RNG is seeded, so the asserted SHD value depends on the exact sequence of random draws
 * below; reordering calls will change the expected value.
 */
@Test
public void test1() {
    // Fixed seed so the simulated model, data and therefore the SHD are reproducible.
    RandomUtil.getInstance().setSeed(49283494L);
    // Single iteration (r < 1); the loop is kept only as a harness for repeated runs.
    for (int r = 0; r < 1; r++) {
        Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 6, 0, 0, 0);
        // Latent-only structure of the true model, used as the reference for the SHD.
        Graph mimStructure = structure(mim);
        Parameters params = new Parameters();
        params.set("coefLow", .5);
        params.set("coefHigh", 1.5);
        SemPm pm = new SemPm(mim);
        SemIm im = new SemIm(pm, params);
        DataSet data = im.simulateData(300, false);
        // Hard-coded to "FOFC"; the BPC branch below is not exercised by this test.
        String algorithm = "FOFC";
        Graph searchGraph;
        List<List<Node>> partition;
        if (algorithm.equals("FOFC")) {
            FindOneFactorClusters fofc = new FindOneFactorClusters(data, TestType.TETRAD_WISHART, FindOneFactorClusters.Algorithm.GAP, 0.001);
            searchGraph = fofc.search();
            partition = fofc.getClusters();
        } else if (algorithm.equals("BPC")) {
            TestType testType = TestType.TETRAD_WISHART;
            TestType purifyType = TestType.TETRAD_BASED;
            BuildPureClusters bpc = new BuildPureClusters(data, 0.001, testType, purifyType);
            searchGraph = bpc.search();
            partition = MimUtils.convertToClusters2(searchGraph);
        } else {
            throw new IllegalStateException();
        }
        // Map each found cluster back to a latent name of the true model so the
        // recovered structure is comparable with mimStructure.
        List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);
        // System.out.println(partition);
        // System.out.println(latentVarList);
        // 
        // System.out.println("True\n" + mimStructure);
        Graph mimbuildStructure;
        // Only method 2 (Mimbuild2) is run; the MimbuildTrek branch (3) is not exercised.
        for (int mimbuildMethod : new int[] { 2 }) {
            if (mimbuildMethod == 2) {
                Mimbuild2 mimbuild = new Mimbuild2();
                mimbuild.setAlpha(0.001);
                mimbuild.setMinClusterSize(3);
                mimbuildStructure = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data));
                int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure);
                // Regression value for the seed above, not a theoretical bound.
                assertEquals(7, shd);
            } else if (mimbuildMethod == 3) {
                // System.out.println("Mimbuild Trek\n");
                MimbuildTrek mimbuild = new MimbuildTrek();
                mimbuild.setAlpha(0.1);
                mimbuild.setMinClusterSize(3);
                mimbuildStructure = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data));
                // ICovarianceMatrix latentcov = mimbuild.getLatentsCov();
                // System.out.println("\nCovariance over the latents");
                // System.out.println(latentcov);
                // System.out.println("Estimated\n" + mimbuildStructure);
                int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure);
                // System.out.println("SHD = " + shd);
                // System.out.println();
                assertEquals(3, shd);
            } else {
                throw new IllegalStateException();
            }
        }
    }
}
Also used : Mimbuild2(edu.cmu.tetrad.search.Mimbuild2) Parameters(edu.cmu.tetrad.util.Parameters) SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm) Test(org.junit.Test)

Example 39 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

The class SemUpdaterEditor, method getUpdatePanel.

/**
 * Builds the evidence-entry panel. It contains one row per variable, with a value text
 * field and a "Manipulated" checkbox, followed by a "Do Update Now" button. Editing a value
 * or toggling a checkbox immediately redisplays the updated SEM IM in {@code semImEditor}.
 *
 * Also installs a focus traversal policy on this component that cycles through the value
 * fields in row order, wrapping around at both ends.
 *
 * @return the assembled vertical {@link Box}
 */
private Box getUpdatePanel() {
    final SemEvidence evidence = semUpdater.getEvidence();
    // NOTE(review): the field filter and checkbox listener below capture this `evidence`
    // instance, but "Do Update Now" installs a new SemEvidence on semUpdater. Confirm the
    // panel is rebuilt after that; otherwise later edits write to the stale evidence object.
    focusTraversalOrder.clear();
    Box b = Box.createVerticalBox();
    Box b0 = Box.createHorizontalBox();
    b0.add(new JLabel("<html>" + "In the list below, specify values for variables you have evidence " + "<br>for. Click the 'Do Update Now' button to view updated model."));
    b0.add(Box.createHorizontalGlue());
    b.add(b0);
    b.add(Box.createVerticalStrut(10));
    Box d = Box.createHorizontalBox();
    d.add(new JLabel("Variable = value"));
    d.add(Box.createHorizontalGlue());
    d.add(new JLabel("Manipulated"));
    b.add(d);
    // The SEM IM does not change while the rows are built, so look it up once.
    SemIm semIm = evidence.getSemIm();
    for (int i = 0; i < evidence.getNumNodes(); i++) {
        Box c = Box.createHorizontalBox();
        Node node = semIm.getVariableNodes().get(i);
        String name = node.getName();
        JLabel label = new JLabel(name + " =  ") {

            private static final long serialVersionUID = 820570350956700782L;

            @Override
            public Dimension getMaximumSize() {
                return getPreferredSize();
            }
        };
        c.add(label);
        double mean = evidence.getProposition().getValue(i);
        final DoubleTextField field = new DoubleTextField(mean, 5, NumberFormatUtil.getInstance().getNumberFormat());
        field.setFilter((value, oldValue) -> {
            try {
                final int nodeIndex = labels.get(field);
                // A manipulated variable must have a concrete value; reject NaN for it.
                if (Double.isNaN(value) && evidence.isManipulated(nodeIndex)) {
                    throw new IllegalArgumentException();
                }
                evidence.getProposition().setValue(nodeIndex, value);
                SemIm updatedSem = semUpdater.getUpdatedSemIm();
                semImEditor.displaySemIm(updatedSem, semImEditor.getTabSelectionIndex(), semImEditor.getMatrixSelection());
                return value;
            } catch (IllegalArgumentException e) {
                // Rejected input: keep the field's previous value.
                return oldValue;
            }
        });
        labels.put(field, i);
        variablesToTextFields.put(i, field);
        focusTraversalOrder.add(field);
        c.add(field);
        c.add(Box.createHorizontalStrut(2));
        c.add(Box.createHorizontalGlue());
        JCheckBox checkbox = new JCheckBox() {

            private static final long serialVersionUID = -3808843047563493212L;

            @Override
            public Dimension getMaximumSize() {
                return getPreferredSize();
            }
        };
        checkbox.setSelected(evidence.isManipulated(i));
        checkBoxesToVariables.put(checkbox, i);
        variablesToCheckboxes.put(i, checkbox);
        checkbox.addActionListener((e) -> {
            JCheckBox chkbox = (JCheckBox) e.getSource();
            boolean selected = chkbox.isSelected();
            Integer o = checkBoxesToVariables.get(chkbox);
            // If no value has been set for this variable, set it to
            // the mean.
            double value = evidence.getProposition().getValue(o);
            if (Double.isNaN(value)) {
                DoubleTextField dblTxtField = variablesToTextFields.get(o);
                SemIm semIM = semUpdater.getSemIm();
                Node varNode = semIM.getVariableNodes().get(o);
                double semIMMean = semIM.getMean(varNode);
                dblTxtField.setValue(semIMMean);
            }
            semUpdater.getEvidence().setManipulated(o, selected);
            SemIm updatedSem = semUpdater.getUpdatedSemIm();
            semImEditor.displaySemIm(updatedSem, semImEditor.getTabSelectionIndex(), semImEditor.getMatrixSelection());
        });
        checkbox.setBackground(Color.WHITE);
        checkbox.setBorder(null);
        c.add(checkbox);
        c.setMaximumSize(new Dimension(1000, 30));
        b.add(c);
    }
    b.add(Box.createVerticalGlue());
    Box b2 = Box.createHorizontalBox();
    b2.add(Box.createHorizontalGlue());
    JButton button = new JButton("Do Update Now");
    button.addActionListener((e) -> {
        SemIm updatedSem = semUpdater.getUpdatedSemIm();
        semImEditor.displaySemIm(updatedSem, semImEditor.getTabSelectionIndex(), semImEditor.getMatrixSelection());
        semUpdater.setEvidence(new SemEvidence(updatedSem));
    });
    b2.add(button);
    b.add(b2);
    b.setBorder(new EmptyBorder(5, 5, 5, 5));
    setFocusTraversalPolicy(new FocusTraversalPolicy() {

        @Override
        public Component getComponentAfter(Container focusCycleRoot, Component aComponent) {
            int index = focusTraversalOrder.indexOf(aComponent);
            int size = focusTraversalOrder.size();
            if (index != -1) {
                return focusTraversalOrder.get((index + 1) % size);
            } else {
                return getFirstComponent(focusCycleRoot);
            }
        }

        @Override
        public Component getComponentBefore(Container focusCycleRoot, Component aComponent) {
            int index = focusTraversalOrder.indexOf(aComponent);
            int size = focusTraversalOrder.size();
            if (index != -1) {
                // Add size before taking the remainder: in Java (-1 % size) == -1, which
                // made "previous" from the first field throw instead of wrapping to the last.
                return focusTraversalOrder.get((index - 1 + size) % size);
            } else {
                return getFirstComponent(focusCycleRoot);
            }
        }

        @Override
        public Component getFirstComponent(Container focusCycleRoot) {
            // With zero variables there are no fields; the policy contract allows null.
            return focusTraversalOrder.isEmpty() ? null : focusTraversalOrder.getFirst();
        }

        @Override
        public Component getLastComponent(Container focusCycleRoot) {
            return focusTraversalOrder.isEmpty() ? null : focusTraversalOrder.getLast();
        }

        @Override
        public Component getDefaultComponent(Container focusCycleRoot) {
            return getFirstComponent(focusCycleRoot);
        }
    });
    setFocusCycleRoot(true);
    return b;
}
Also used : SemEvidence(edu.cmu.tetrad.sem.SemEvidence) DoubleTextField(edu.cmu.tetradapp.util.DoubleTextField) Node(edu.cmu.tetrad.graph.Node) JButton(javax.swing.JButton) JLabel(javax.swing.JLabel) FocusTraversalPolicy(java.awt.FocusTraversalPolicy) Box(javax.swing.Box) JCheckBox(javax.swing.JCheckBox) Dimension(java.awt.Dimension) JCheckBox(javax.swing.JCheckBox) Container(java.awt.Container) EmptyBorder(javax.swing.border.EmptyBorder) Component(java.awt.Component) SemIm(edu.cmu.tetrad.sem.SemIm)

Example 40 with SemIm

use of edu.cmu.tetrad.sem.SemIm in project tetrad by cmu-phil.

The class PurifyScoreBased, method gaussianEM.

/**
 * Runs Gaussian EM for the given SEM graph from three starting points and keeps the best.
 *
 * Trial 0 starts from {@code initialSemIm} when it is non-null. The other trials start from
 * a fresh {@link SemIm} whose covariance matrix is set to {@code covarianceMatrix}. Each
 * trial alternates {@code gaussianExpectation} / {@code gaussianMaximization} until the score
 * changes by at most 1e-3, or until maximization reports {@code -Double.MAX_VALUE}.
 *
 * The sufficient statistics ({@code Cyy}, {@code Cyz}, {@code Czz}) of the best-scoring
 * finite trial are restored into the instance fields before returning. If no trial produced
 * a better finite score, the statistics snapshotted on entry are restored instead.
 *
 * @param semdag       graph to fit; error terms are switched on as a side effect
 * @param initialSemIm optional starting point for the first trial (may be null)
 * @return the best finite score, or {@code -Double.MAX_VALUE} if no trial produced one
 */
private double gaussianEM(SemGraph semdag, SemIm initialSemIm) {
    double score, newScore = -Double.MAX_VALUE, bestScore = -Double.MAX_VALUE;
    SemPm semPm = new SemPm(semdag);
    semdag.setShowErrorTerms(true);

    // Snapshot the current statistics as the fallback "best" in case no trial improves.
    for (int p = 0; p < numObserved; p++) {
        for (int q = 0; q < numObserved; q++) {
            this.bestCyy[p][q] = this.Cyy[p][q];
        }
        if (this.Cyz != null) {
            for (int q = 0; q < numLatent; q++) {
                this.bestCyz[p][q] = this.Cyz[p][q];
            }
        }
    }
    if (this.Czz != null) {
        for (int p = 0; p < numLatent; p++) {
            for (int q = 0; q < numLatent; q++) {
                this.bestCzz[p][q] = this.Czz[p][q];
            }
        }
    }

    initializeGaussianEM(semdag);
    for (int i = 0; i < 3; i++) {
        System.out.println("--Trial " + i);
        SemIm semIm;
        if (i == 0 && initialSemIm != null) {
            semIm = initialSemIm;
        } else {
            semIm = new SemIm(semPm);
            semIm.setCovMatrix(this.covarianceMatrix);
        }
        // Reset per trial. Otherwise the first convergence test of a restart compares
        // against the previous trial's final score and may stop after one iteration.
        newScore = -Double.MAX_VALUE;
        do {
            score = newScore;
            gaussianExpectation(semIm);
            newScore = gaussianMaximization(semIm);
            if (newScore == -Double.MAX_VALUE) {
                break;
            }
        } while (Math.abs(score - newScore) > 1.E-3);
        System.out.println(newScore);
        if (newScore > bestScore && !Double.isInfinite(newScore)) {
            bestScore = newScore;
            for (int p = 0; p < numObserved; p++) {
                for (int q = 0; q < numObserved; q++) {
                    this.bestCyy[p][q] = this.Cyy[p][q];
                }
                if (this.Cyz != null) {
                    for (int q = 0; q < numLatent; q++) {
                        this.bestCyz[p][q] = this.Cyz[p][q];
                    }
                }
            }
            if (this.Czz != null) {
                for (int p = 0; p < numLatent; p++) {
                    for (int q = 0; q < numLatent; q++) {
                        this.bestCzz[p][q] = this.Czz[p][q];
                    }
                }
            }
        }
    }

    // Restore the statistics of the best trial (or the entry snapshot).
    for (int p = 0; p < numObserved; p++) {
        for (int q = 0; q < numObserved; q++) {
            this.Cyy[p][q] = this.bestCyy[p][q];
        }
        if (this.Cyz != null) {
            for (int q = 0; q < numLatent; q++) {
                this.Cyz[p][q] = this.bestCyz[p][q];
            }
        }
    }
    if (this.Czz != null) {
        for (int p = 0; p < numLatent; p++) {
            for (int q = 0; q < numLatent; q++) {
                this.Czz[p][q] = this.bestCzz[p][q];
            }
        }
    }

    // bestScore is only ever assigned finite values, so "no trial succeeded" shows up as
    // it still holding its -Double.MAX_VALUE sentinel. The old isInfinite-only test could
    // never fire.
    if (bestScore == -Double.MAX_VALUE || Double.isInfinite(bestScore)) {
        System.out.println("* * Warning: Heywood case in this step");
        return -Double.MAX_VALUE;
    }
    return bestScore;
}
Also used : SemPm(edu.cmu.tetrad.sem.SemPm) SemIm(edu.cmu.tetrad.sem.SemIm)

Aggregations

SemIm (edu.cmu.tetrad.sem.SemIm)81 SemPm (edu.cmu.tetrad.sem.SemPm)71 Test (org.junit.Test)46 DataSet (edu.cmu.tetrad.data.DataSet)28 ArrayList (java.util.ArrayList)28 Graph (edu.cmu.tetrad.graph.Graph)26 Node (edu.cmu.tetrad.graph.Node)19 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)16 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)16 SemEstimator (edu.cmu.tetrad.sem.SemEstimator)15 Dag (edu.cmu.tetrad.graph.Dag)10 DMSearch (edu.cmu.tetrad.search.DMSearch)9 StandardizedSemIm (edu.cmu.tetrad.sem.StandardizedSemIm)9 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)7 NumberFormat (java.text.NumberFormat)7 GraphNode (edu.cmu.tetrad.graph.GraphNode)5 IndependenceTest (edu.cmu.tetrad.search.IndependenceTest)4 DecimalFormat (java.text.DecimalFormat)4 List (java.util.List)4 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)3