Search in sources :

Example 1 with SemGraph

Use of edu.cmu.tetrad.graph.SemGraph in project tetrad by cmu-phil.

From class Ricf, method ricf.

// =============================PUBLIC METHODS=========================//
/**
 * Fits the graph {@code mag} to the covariance matrix {@code covMatrix} by iteratively
 * updating an error covariance Omega and a coefficient matrix B. (By its name this is
 * presumably RICF, Drton, Eichler &amp; Richardson 2009 — TODO confirm.)
 *
 * <p>Omega starts as diag(S) and B as the identity. The undirected (UG) part of the
 * graph is fitted once with {@code fitConGraph}; the loop then revisits every non-UG
 * variable each sweep and tests {@code diff < tolerance} afterwards, where diff is the
 * change in Omega plus the change in B over that sweep.
 *
 * <p>NOTE(review): this listing is an excerpt; closing braces and some statements are
 * not shown.
 *
 * @param mag       graph whose parent/spouse structure constrains B and Omega
 * @param covMatrix sample covariance matrix S
 * @param tolerance convergence threshold on the per-sweep change
 * @return sigma-hat = B^-1 Omega B^-T, lambda-hat (Omega with the non-UG block zeroed),
 *         B-hat, omega-hat (Omega with the UG block zeroed), the counter {@code i} and the
 *         last diff; for a 1x1 matrix, S is returned as both sigma-hat and lambda-hat
 */
public RicfResult ricf(SemGraph mag, ICovarianceMatrix covMatrix, double tolerance) {
    DoubleFactory2D factory = DoubleFactory2D.dense;
    Algebra algebra = new Algebra();
    DoubleMatrix2D S = new DenseDoubleMatrix2D(covMatrix.getMatrix().toArray());
    int p = covMatrix.getDimension();
    if (p == 1) {
        return new RicfResult(S, S, null, null, 1, Double.NaN, covMatrix);
    List<Node> nodes = new ArrayList<>();
    for (String name : covMatrix.getVariableNames()) {
    // Starting values: Omega = diag(S), B = identity.
    DoubleMatrix2D omega = factory.diagonal(factory.diagonal(S));
    DoubleMatrix2D B = factory.identity(p);
    int[] ug = ugNodes(mag, nodes);
    int[] ugComp = complement(p, ug);
    // Fit the UG block of Omega once, up front; the sweep below skips UG variables.
    if (ug.length > 0) {
        List<Node> _ugNodes = new LinkedList<>();
        for (int i : ug) {
        Graph ugGraph = mag.subgraph(_ugNodes);
        ICovarianceMatrix ugCov = covMatrix.getSubmatrix(ug);
        DoubleMatrix2D lambdaInv = fitConGraph(ugGraph, ugCov, p + 1, tolerance).shat;
        omega.viewSelection(ug, ug).assign(lambdaInv);
    // Prepare lists of parents and spouses.
    int[][] pars = parentIndices(p, mag, nodes);
    int[][] spo = spouseIndices(p, mag, nodes);
    int i = 0;
    double _diff;
    while (true) {
        // Snapshots used for the convergence test at the end of the sweep.
        DoubleMatrix2D omegaOld = omega.copy();
        DoubleMatrix2D bOld = B.copy();
        for (int _v = 0; _v < p; _v++) {
            // Exclude the UG part.
            // NOTE(review): binarySearch assumes ug is sorted ascending — confirm ugNodes guarantees it.
            if (Arrays.binarySearch(ug, _v) >= 0) {
            int[] v = new int[] { _v };
            int[] vcomp = complement(p, v);
            int[] all = range(0, p - 1);
            int[] parv = pars[_v];
            int[] spov = spo[_v];
            // View of B's row for v restricted to v's parents.
            DoubleMatrix2D a6 = B.viewSelection(v, parv);
            if (spov.length == 0) {
                // No spouses: the update is only computed when i == 1.
                if (parv.length != 0) {
                    if (i == 1) {
                        DoubleMatrix2D a1 = S.viewSelection(parv, parv);
                        DoubleMatrix2D a2 = S.viewSelection(v, parv);
                        DoubleMatrix2D a3 = algebra.inverse(a1);
                        DoubleMatrix2D a4 = algebra.mult(a2, a3);
                        DoubleMatrix2D a7 = S.viewSelection(parv, v);
                        DoubleMatrix2D a9 = algebra.mult(a6, a7);
                        DoubleMatrix2D a8 = S.viewSelection(v, v);
                        DoubleMatrix2D a8b = omega.viewSelection(v, v);
                        omega.viewSelection(v, v).assign(a9, PlusMult.plusMult(1));
            } else {
                if (parv.length != 0) {
                    // Spouses and parents: Z = Omega^-1[spov, -v] * B[-v, :].
                    DoubleMatrix2D oInv = new DenseDoubleMatrix2D(p, p);
                    DoubleMatrix2D a2 = omega.viewSelection(vcomp, vcomp);
                    DoubleMatrix2D a3 = algebra.inverse(a2);
                    oInv.viewSelection(vcomp, vcomp).assign(a3);
                    DoubleMatrix2D Z = algebra.mult(oInv.viewSelection(spov, vcomp), B.viewSelection(vcomp, all));
                    int lpa = parv.length;
                    int lspo = spov.length;
                    // Build XX
                    DoubleMatrix2D XX = new DenseDoubleMatrix2D(lpa + lspo, lpa + lspo);
                    int[] range1 = range(0, lpa - 1);
                    int[] range2 = range(lpa, lpa + lspo - 1);
                    // Upper left quadrant
                    XX.viewSelection(range1, range1).assign(S.viewSelection(parv, parv));
                    // Upper right quadrant
                    DoubleMatrix2D a11 = algebra.mult(S.viewSelection(parv, all), algebra.transpose(Z));
                    XX.viewSelection(range1, range2).assign(a11);
                    // Lower left quadrant
                    DoubleMatrix2D a12 = XX.viewSelection(range2, range1);
                    DoubleMatrix2D a13 = algebra.transpose(XX.viewSelection(range1, range2));
                    // Lower right quadrant
                    DoubleMatrix2D a14 = XX.viewSelection(range2, range2);
                    DoubleMatrix2D a15 = algebra.mult(Z, S);
                    DoubleMatrix2D a16 = algebra.mult(a15, algebra.transpose(Z));
                    // Build XY
                    DoubleMatrix1D YX = new DenseDoubleMatrix1D(lpa + lspo);
                    DoubleMatrix1D a17 = YX.viewSelection(range1);
                    DoubleMatrix1D a18 = S.viewSelection(v, parv).viewRow(0);
                    DoubleMatrix1D a19 = YX.viewSelection(range2);
                    DoubleMatrix2D a20 = S.viewSelection(v, all);
                    DoubleMatrix1D a21 = algebra.mult(a20, algebra.transpose(Z)).viewRow(0);
                    // Temp
                    DoubleMatrix2D a22 = algebra.inverse(XX);
                    DoubleMatrix1D temp = algebra.mult(algebra.transpose(a22), YX);
                    // Assign to b.
                    DoubleMatrix1D a23 = a6.viewRow(0);
                    DoubleMatrix1D a24 = temp.viewSelection(range1);
                    // Assign to omega.
                    omega.viewSelection(v, spov).viewRow(0).assign(temp.viewSelection(range2));
                    omega.viewSelection(spov, v).viewColumn(0).assign(temp.viewSelection(range2));
                    // Variance.
                    double tempVar = S.get(_v, _v) - algebra.mult(temp, YX);
                    DoubleMatrix2D a27 = omega.viewSelection(v, spov);
                    DoubleMatrix2D a28 = oInv.viewSelection(spov, spov);
                    DoubleMatrix2D a29 = omega.viewSelection(spov, v).copy();
                    DoubleMatrix2D a30 = algebra.mult(a27, a28);
                    DoubleMatrix2D a31 = algebra.mult(a30, a29);
                    // Omega[v,v] = tempVar + Omega[v,spov] * Oinv[spov,spov] * Omega[spov,v].
                    omega.viewSelection(v, v).assign(tempVar);
                    omega.viewSelection(v, v).assign(a31, PlusMult.plusMult(1));
                } else {
                    // Spouses only: same update with Z as the sole regressor set.
                    DoubleMatrix2D oInv = new DenseDoubleMatrix2D(p, p);
                    DoubleMatrix2D a2 = omega.viewSelection(vcomp, vcomp);
                    DoubleMatrix2D a3 = algebra.inverse(a2);
                    oInv.viewSelection(vcomp, vcomp).assign(a3);
                    // System.out.println("O.inv = " + oInv);
                    DoubleMatrix2D a4 = oInv.viewSelection(spov, vcomp);
                    DoubleMatrix2D a5 = B.viewSelection(vcomp, all);
                    DoubleMatrix2D Z = algebra.mult(a4, a5);
                    // System.out.println("Z = " + Z);
                    // Build XX
                    DoubleMatrix2D XX = algebra.mult(algebra.mult(Z, S), Z.viewDice());
                    // System.out.println("XX = " + XX);
                    // Build XY
                    DoubleMatrix2D a20 = S.viewSelection(v, all);
                    DoubleMatrix1D YX = algebra.mult(a20, Z.viewDice()).viewRow(0);
                    // System.out.println("YX = " + YX);
                    // Temp
                    DoubleMatrix2D a22 = algebra.inverse(XX);
                    DoubleMatrix1D a23 = algebra.mult(algebra.transpose(a22), YX);
                    // Assign to omega.
                    DoubleMatrix1D a24 = omega.viewSelection(v, spov).viewRow(0);
                    DoubleMatrix1D a25 = omega.viewSelection(spov, v).viewColumn(0);
                    // System.out.println("Omega 2 " + omega);
                    // Variance.
                    double tempVar = S.get(_v, _v) - algebra.mult(a24, YX);
                    // System.out.println("tempVar = " + tempVar);
                    DoubleMatrix2D a27 = omega.viewSelection(v, spov);
                    DoubleMatrix2D a28 = oInv.viewSelection(spov, spov);
                    DoubleMatrix2D a29 = omega.viewSelection(spov, v).copy();
                    DoubleMatrix2D a30 = algebra.mult(a27, a28);
                    DoubleMatrix2D a31 = algebra.mult(a30, a29);
                    omega.set(_v, _v, tempVar + a31.get(0, 0));
                // System.out.println("Omega final " + omega);
        // Convergence: change in Omega plus change in B over this sweep.
        // NOTE(review): Colt's norm1 is the maximum absolute column sum, not the sum of all
        // absolute entries — confirm this is the intended measure.
        DoubleMatrix2D a32 = omega.copy();
        a32.assign(omegaOld, PlusMult.plusMult(-1));
        double diff1 = algebra.norm1(a32);
        DoubleMatrix2D a33 = B.copy();
        a33.assign(bOld, PlusMult.plusMult(-1));
        // Fixed: measure the change in B (a33); previously this re-measured Omega (a32),
        // so updates to B never counted toward convergence.
        double diff2 = algebra.norm1(a33);
        double diff = diff1 + diff2;
        _diff = diff;
        if (diff < tolerance)
    // Sigma-hat = B^-1 * Omega * B^-T.
    DoubleMatrix2D a34 = algebra.inverse(B);
    DoubleMatrix2D a35 = algebra.inverse(B.viewDice());
    DoubleMatrix2D sigmahat = algebra.mult(algebra.mult(a34, omega), a35);
    // Lambda-hat: Omega with the non-UG x non-UG block zeroed.
    DoubleMatrix2D lambdahat = omega.copy();
    DoubleMatrix2D a36 = lambdahat.viewSelection(ugComp, ugComp);
    a36.assign(factory.make(ugComp.length, ugComp.length, 0.0));
    // Omega-hat: Omega with the UG x UG block zeroed.
    DoubleMatrix2D omegahat = omega.copy();
    DoubleMatrix2D a37 = omegahat.viewSelection(ug, ug);
    a37.assign(factory.make(ug.length, ug.length, 0.0));
    DoubleMatrix2D bhat = B.copy();
    return new RicfResult(sigmahat, lambdahat, bhat, omegahat, i, _diff, covMatrix);
Also used : ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) Node(edu.cmu.tetrad.graph.Node) DoubleFactory2D(cern.colt.matrix.DoubleFactory2D) Endpoint(edu.cmu.tetrad.graph.Endpoint) Algebra(cern.colt.matrix.linalg.Algebra) SemGraph(edu.cmu.tetrad.graph.SemGraph) Graph(edu.cmu.tetrad.graph.Graph) DoubleMatrix2D(cern.colt.matrix.DoubleMatrix2D) DenseDoubleMatrix2D(cern.colt.matrix.impl.DenseDoubleMatrix2D) DoubleMatrix1D(cern.colt.matrix.DoubleMatrix1D) DenseDoubleMatrix1D(cern.colt.matrix.impl.DenseDoubleMatrix1D)

Example 2 with SemGraph

Use of edu.cmu.tetrad.graph.SemGraph in project tetrad by cmu-phil.

From class GeneralizedSemEstimatorEditor, method layoutByGraph.

/**
 * Layout hook of GeneralizedSemEstimatorEditor. NOTE(review): excerpt only — the rest
 * of the body is not shown, so what "layout by graph" does here cannot be confirmed.
 *
 * @param graph graph to lay out by — NOTE(review): not referenced in the visible lines
 */
public void layoutByGraph(Graph graph) {
    // Casts the workbench's current graph to SemGraph (throws ClassCastException if it is not one).
    SemGraph _graph = (SemGraph) graphicalEditor().getWorkbench().getGraph();
    // graphicalEditor().getWorkbench().setGraph(_graph);
    // errorTerms is presumably the show/hide error-terms toggle; TODO confirm.
    errorTerms.setText("Show Error Terms");
Also used : SemGraph(edu.cmu.tetrad.graph.SemGraph)

Example 3 with SemGraph

Use of edu.cmu.tetrad.graph.SemGraph in project tetrad by cmu-phil.

From class GeneralizedSemImEditor, method layoutByGraph.

/**
 * Layout hook of GeneralizedSemImEditor. NOTE(review): excerpt only — the rest of the
 * body is not shown, so what "layout by graph" does here cannot be confirmed.
 *
 * @param graph graph to lay out by — NOTE(review): not referenced in the visible lines
 */
public void layoutByGraph(Graph graph) {
    // Casts the workbench's current graph to SemGraph (throws ClassCastException if it is not one).
    SemGraph _graph = (SemGraph) graphicalEditor().getWorkbench().getGraph();
    // graphicalEditor().getWorkbench().setGraph(_graph);
    // errorTerms is presumably the show/hide error-terms toggle; TODO confirm.
    errorTerms.setText("Show Error Terms");
Also used : SemGraph(edu.cmu.tetrad.graph.SemGraph)

Example 4 with SemGraph

Use of edu.cmu.tetrad.graph.SemGraph in project tetrad by cmu-phil.

From class SemOptimizerRicf, method optimize.

// ==============================PUBLIC METHODS========================//
/**
 * Optimizes the free parameters of the given SEM by running RICF
 * ({@code Ricf.ricf}, tolerance 0.001) on the model graph and the sample
 * covariance matrix, then copying the fitted edge coefficients, error variances
 * and error covariances back into {@code semIm}.
 *
 * <p>Only a single run is supported: {@code numRestarts} is raised to 1 if it is
 * below 1, and any other value is rejected.
 *
 * <p>NOTE(review): this listing is an excerpt; closing braces are not shown.
 *
 * @param semIm the SEM to optimize; its sample covariance must be set and contain
 *              no missing values
 * @throws IllegalArgumentException if numRestarts != 1 or the sample covariance
 *                                  contains missing values
 * @throws NullPointerException     if the sample covariance has not been set
 */
public void optimize(SemIm semIm) {
    if (numRestarts < 1)
        numRestarts = 1;
    if (numRestarts != 1) {
        throw new IllegalArgumentException("Number of restarts must be 1 for this method.");
    TetradMatrix sampleCovar = semIm.getSampleCovar();
    if (sampleCovar == null) {
        throw new NullPointerException("Sample covar has not been set.");
    // Checked once; this identical test was previously repeated back to back.
    if (DataUtils.containsMissingValue(sampleCovar)) {
        throw new IllegalArgumentException("Please remove or impute missing values.");
    // NOTE(review): no EM step runs here (its call was removed), so this message is
    // misleading — consider rewording it.
    TetradLogger.getInstance().log("info", "Trying EM...");
    CovarianceMatrix cov = new CovarianceMatrix(semIm.getMeasuredNodes(), sampleCovar, semIm.getSampleSize());
    SemGraph graph = semIm.getSemPm().getGraph();
    Ricf.RicfResult result = new Ricf().ricf(graph, cov, 0.001);
    TetradMatrix bHat = new TetradMatrix(result.getBhat().toArray());
    TetradMatrix lHat = new TetradMatrix(result.getLhat().toArray());
    TetradMatrix oHat = new TetradMatrix(result.getOhat().toArray());
    for (Parameter param : semIm.getFreeParameters()) {
        if (param.getType() == ParamType.COEF) {
            int i = semIm.getSemPm().getVariableNodes().indexOf(param.getNodeA());
            int j = semIm.getSemPm().getVariableNodes().indexOf(param.getNodeB());
            // The coefficient of A -> B is read from B-hat at (j, i) with its sign flipped.
            semIm.setEdgeCoef(param.getNodeA(), param.getNodeB(), -bHat.get(j, i));
        if (param.getType() == ParamType.VAR) {
            int i = semIm.getSemPm().getVariableNodes().indexOf(param.getNodeA());
            // Prefer Lambda-hat (the UG block); fall back to Omega-hat when that entry is zero.
            if (lHat.get(i, i) != 0) {
                semIm.setErrVar(param.getNodeA(), lHat.get(i, i));
            } else if (oHat.get(i, i) != 0) {
                semIm.setErrVar(param.getNodeA(), oHat.get(i, i));
        if (param.getType() == ParamType.COVAR) {
            int i = semIm.getSemPm().getVariableNodes().indexOf(param.getNodeA());
            int j = semIm.getSemPm().getVariableNodes().indexOf(param.getNodeB());
            // NOTE(review): the source matrix is chosen from node A's diagonal entry only.
            if (lHat.get(i, i) != 0) {
                semIm.setErrCovar(param.getNodeA(), param.getNodeB(), lHat.get(j, i));
            } else if (oHat.get(i, i) != 0) {
                semIm.setErrCovar(param.getNodeA(), param.getNodeB(), oHat.get(j, i));
Also used : SemGraph(edu.cmu.tetrad.graph.SemGraph) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) CovarianceMatrix(edu.cmu.tetrad.data.CovarianceMatrix)

Example 5 with SemGraph

Use of edu.cmu.tetrad.graph.SemGraph in project tetrad by cmu-phil.

From class SemXmlRenderer, method makeMarginalErrorDistribution.

/**
 * Builds the marginal-error-distribution XML element for a SEM. For each exogenous
 * node of the SEM graph it creates a NORMAL element carrying the node's name,
 * mean "0.0", and variance {@code semIm.getParamValue(node, node)}.
 *
 * <p>NOTE(review): excerpt only — the line attaching each NORMAL element to
 * {@code marginalErrorElement} is not visible here; confirm it is appended.
 *
 * @param semIm the SEM whose exogenous error distributions are rendered
 * @return the MARGINAL_ERROR_DISTRIBUTION element
 */
private static Element makeMarginalErrorDistribution(SemIm semIm) {
    Element marginalErrorElement = new Element(SemXmlConstants.MARGINAL_ERROR_DISTRIBUTION);
    Element normal;
    SemGraph semGraph = semIm.getSemPm().getGraph();
    for (Node node : getExogenousNodes(semGraph)) {
        normal = new Element(SemXmlConstants.NORMAL);
        normal.addAttribute(new Attribute(SemXmlConstants.VARIABLE, node.getName()));
        // Mean is always written as zero.
        normal.addAttribute(new Attribute(SemXmlConstants.MEAN, "0.0"));
        // The (node, node) parameter value — presumably the error variance; TODO confirm.
        normal.addAttribute(new Attribute(SemXmlConstants.VARIANCE, Double.toString(semIm.getParamValue(node, node))));
    return marginalErrorElement;
Also used : Attribute(nu.xom.Attribute) Element(nu.xom.Element) SemGraph(edu.cmu.tetrad.graph.SemGraph) Node(edu.cmu.tetrad.graph.Node)


SemGraph (edu.cmu.tetrad.graph.SemGraph)19 Node (edu.cmu.tetrad.graph.Node)8 DoubleFactory2D (cern.colt.matrix.DoubleFactory2D)2 DoubleMatrix1D (cern.colt.matrix.DoubleMatrix1D)2 DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D)2 DenseDoubleMatrix1D (cern.colt.matrix.impl.DenseDoubleMatrix1D)2 DenseDoubleMatrix2D (cern.colt.matrix.impl.DenseDoubleMatrix2D)2 Algebra (cern.colt.matrix.linalg.Algebra)2 ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix) Endpoint (edu.cmu.tetrad.graph.Endpoint)2 Graph (edu.cmu.tetrad.graph.Graph)2 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)2 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix) GraphNode (edu.cmu.tetrad.graph.GraphNode)1 TetradVector (edu.cmu.tetrad.util.TetradVector)1 LinkedList (java.util.LinkedList)1 Attribute (nu.xom.Attribute)1 Element (nu.xom.Element)1