Search in sources :

Example 6 with DoubleMatrix2D

use of cern.colt.matrix.DoubleMatrix2D in project tetrad by cmu-phil.

the class Ricf method ricf.

// =============================PUBLIC METHODS=========================//
/**
 * Iteratively fits the matrices {@code B} and {@code omega} for the given graph against the
 * supplied covariance matrix, stopping once the combined 1-norm change of {@code omega} and
 * {@code B} between two consecutive sweeps falls below {@code tolerance}.
 *
 * <p>Outline of what the code does:
 * <ol>
 *   <li>{@code omega} starts as the diagonal of the covariance matrix {@code S}; {@code B} starts
 *       as the identity.</li>
 *   <li>If {@code ugNodes} reports any indices, that sub-block of {@code omega} is replaced by the
 *       {@code shat} result of {@code fitConGraph} on the corresponding subgraph; those nodes are
 *       skipped in the sweeps.</li>
 *   <li>Every remaining node is updated per sweep, with a different update depending on whether it
 *       has parents and/or spouses (as reported by {@code parentIndices}/{@code spouseIndices}).</li>
 * </ol>
 *
 * <p>NOTE(review): there is no iteration cap; the loop only ends when the change drops below
 * {@code tolerance}.
 *
 * @param mag       the graph to fit; its error terms are hidden as a side effect of this call
 * @param covMatrix the covariance matrix; its variable names are used to look up nodes in {@code mag}
 * @param tolerance convergence threshold on the summed 1-norm change of {@code omega} and {@code B}
 * @return a {@link RicfResult} built from the fitted sigma, lambda, B and omega matrices, the number
 *         of sweeps performed, the last measured change, and {@code covMatrix}; for a
 *         one-dimensional covariance matrix the result is built directly from {@code S}
 */
public RicfResult ricf(SemGraph mag, ICovarianceMatrix covMatrix, double tolerance) {
    mag.setShowErrorTerms(false);
    DoubleFactory2D factory = DoubleFactory2D.dense;
    Algebra algebra = new Algebra();
    DoubleMatrix2D S = new DenseDoubleMatrix2D(covMatrix.getMatrix().toArray());
    int p = covMatrix.getDimension();
    if (p == 1) {
        // Nothing to iterate over with a single variable.
        return new RicfResult(S, S, null, null, 1, Double.NaN, covMatrix);
    }
    // Nodes of the graph in the same order as the covariance matrix variables.
    List<Node> nodes = new ArrayList<>();
    for (String name : covMatrix.getVariableNames()) {
        nodes.add(mag.getNode(name));
    }
    // omega = diag(S), B = I.
    DoubleMatrix2D omega = factory.diagonal(factory.diagonal(S));
    DoubleMatrix2D B = factory.identity(p);
    int[] ug = ugNodes(mag, nodes);
    int[] ugComp = complement(p, ug);
    if (ug.length > 0) {
        List<Node> _ugNodes = new LinkedList<>();
        for (int i : ug) {
            _ugNodes.add(nodes.get(i));
        }
        Graph ugGraph = mag.subgraph(_ugNodes);
        ICovarianceMatrix ugCov = covMatrix.getSubmatrix(ug);
        DoubleMatrix2D lambdaInv = fitConGraph(ugGraph, ugCov, p + 1, tolerance).shat;
        omega.viewSelection(ug, ug).assign(lambdaInv);
    }
    // Prepare lists of parents and spouses.
    int[][] pars = parentIndices(p, mag, nodes);
    int[][] spo = spouseIndices(p, mag, nodes);
    // Index selection covering every variable; invariant across sweeps, so built once.
    int[] all = range(0, p - 1);
    int i = 0;
    double _diff;
    while (true) {
        i++;
        // Snapshots used to measure how much this sweep changed omega and B.
        DoubleMatrix2D omegaOld = omega.copy();
        DoubleMatrix2D bOld = B.copy();
        for (int _v = 0; _v < p; _v++) {
            // Exclude the UG part.
            // NOTE(review): binarySearch assumes ug is sorted ascending — confirm against ugNodes.
            if (Arrays.binarySearch(ug, _v) >= 0) {
                continue;
            }
            int[] v = new int[] { _v };
            int[] vcomp = complement(p, v);
            int[] parv = pars[_v];
            int[] spov = spo[_v];
            // Writable view of row _v of B restricted to the parent columns.
            DoubleMatrix2D a6 = B.viewSelection(v, parv);
            if (spov.length == 0) {
                if (parv.length != 0) {
                    // Parents but no spouses: the update does not depend on omega or B,
                    // so it is only performed on the first sweep.
                    if (i == 1) {
                        DoubleMatrix2D a1 = S.viewSelection(parv, parv);
                        DoubleMatrix2D a2 = S.viewSelection(v, parv);
                        DoubleMatrix2D a3 = algebra.inverse(a1);
                        DoubleMatrix2D a4 = algebra.mult(a2, a3);
                        a4.assign(Mult.mult(-1));
                        // B[v, pa] = -S[v, pa] * inv(S[pa, pa])
                        a6.assign(a4);
                        DoubleMatrix2D a7 = S.viewSelection(parv, v);
                        DoubleMatrix2D a9 = algebra.mult(a6, a7);
                        DoubleMatrix2D a8 = S.viewSelection(v, v);
                        DoubleMatrix2D a8b = omega.viewSelection(v, v);
                        // omega[v, v] = S[v, v] + B[v, pa] * S[pa, v]
                        a8b.assign(a8);
                        omega.viewSelection(v, v).assign(a9, PlusMult.plusMult(1));
                    }
                }
            } else {
                if (parv.length != 0) {
                    // Both parents and spouses.
                    // oInv holds inv(omega[-v, -v]) embedded in a p x p matrix of zeros.
                    DoubleMatrix2D oInv = new DenseDoubleMatrix2D(p, p);
                    DoubleMatrix2D a2 = omega.viewSelection(vcomp, vcomp);
                    DoubleMatrix2D a3 = algebra.inverse(a2);
                    oInv.viewSelection(vcomp, vcomp).assign(a3);
                    DoubleMatrix2D Z = algebra.mult(oInv.viewSelection(spov, vcomp), B.viewSelection(vcomp, all));
                    int lpa = parv.length;
                    int lspo = spov.length;
                    // Build XX
                    DoubleMatrix2D XX = new DenseDoubleMatrix2D(lpa + lspo, lpa + lspo);
                    int[] range1 = range(0, lpa - 1);
                    int[] range2 = range(lpa, lpa + lspo - 1);
                    // Upper left quadrant
                    XX.viewSelection(range1, range1).assign(S.viewSelection(parv, parv));
                    // Upper right quadrant
                    DoubleMatrix2D a11 = algebra.mult(S.viewSelection(parv, all), algebra.transpose(Z));
                    XX.viewSelection(range1, range2).assign(a11);
                    // Lower left quadrant
                    DoubleMatrix2D a12 = XX.viewSelection(range2, range1);
                    DoubleMatrix2D a13 = algebra.transpose(XX.viewSelection(range1, range2));
                    a12.assign(a13);
                    // Lower right quadrant
                    DoubleMatrix2D a14 = XX.viewSelection(range2, range2);
                    DoubleMatrix2D a15 = algebra.mult(Z, S);
                    DoubleMatrix2D a16 = algebra.mult(a15, algebra.transpose(Z));
                    a14.assign(a16);
                    // Build XY
                    DoubleMatrix1D YX = new DenseDoubleMatrix1D(lpa + lspo);
                    DoubleMatrix1D a17 = YX.viewSelection(range1);
                    DoubleMatrix1D a18 = S.viewSelection(v, parv).viewRow(0);
                    a17.assign(a18);
                    DoubleMatrix1D a19 = YX.viewSelection(range2);
                    DoubleMatrix2D a20 = S.viewSelection(v, all);
                    DoubleMatrix1D a21 = algebra.mult(a20, algebra.transpose(Z)).viewRow(0);
                    a19.assign(a21);
                    // Temp
                    DoubleMatrix2D a22 = algebra.inverse(XX);
                    DoubleMatrix1D temp = algebra.mult(algebra.transpose(a22), YX);
                    // Assign to b.
                    DoubleMatrix1D a23 = a6.viewRow(0);
                    DoubleMatrix1D a24 = temp.viewSelection(range1);
                    a23.assign(a24);
                    a23.assign(Mult.mult(-1));
                    // Assign to omega.
                    omega.viewSelection(v, spov).viewRow(0).assign(temp.viewSelection(range2));
                    omega.viewSelection(spov, v).viewColumn(0).assign(temp.viewSelection(range2));
                    // Variance.
                    double tempVar = S.get(_v, _v) - algebra.mult(temp, YX);
                    DoubleMatrix2D a27 = omega.viewSelection(v, spov);
                    DoubleMatrix2D a28 = oInv.viewSelection(spov, spov);
                    DoubleMatrix2D a29 = omega.viewSelection(spov, v).copy();
                    DoubleMatrix2D a30 = algebra.mult(a27, a28);
                    DoubleMatrix2D a31 = algebra.mult(a30, a29);
                    omega.viewSelection(v, v).assign(tempVar);
                    omega.viewSelection(v, v).assign(a31, PlusMult.plusMult(1));
                } else {
                    // Spouses but no parents: only omega is updated.
                    DoubleMatrix2D oInv = new DenseDoubleMatrix2D(p, p);
                    DoubleMatrix2D a2 = omega.viewSelection(vcomp, vcomp);
                    DoubleMatrix2D a3 = algebra.inverse(a2);
                    oInv.viewSelection(vcomp, vcomp).assign(a3);
                    DoubleMatrix2D a4 = oInv.viewSelection(spov, vcomp);
                    DoubleMatrix2D a5 = B.viewSelection(vcomp, all);
                    DoubleMatrix2D Z = algebra.mult(a4, a5);
                    // Build XX
                    DoubleMatrix2D XX = algebra.mult(algebra.mult(Z, S), Z.viewDice());
                    // Build XY
                    DoubleMatrix2D a20 = S.viewSelection(v, all);
                    DoubleMatrix1D YX = algebra.mult(a20, Z.viewDice()).viewRow(0);
                    // Temp
                    DoubleMatrix2D a22 = algebra.inverse(XX);
                    DoubleMatrix1D a23 = algebra.mult(algebra.transpose(a22), YX);
                    // Assign to omega.
                    DoubleMatrix1D a24 = omega.viewSelection(v, spov).viewRow(0);
                    a24.assign(a23);
                    DoubleMatrix1D a25 = omega.viewSelection(spov, v).viewColumn(0);
                    a25.assign(a23);
                    // Variance.
                    double tempVar = S.get(_v, _v) - algebra.mult(a24, YX);
                    DoubleMatrix2D a27 = omega.viewSelection(v, spov);
                    DoubleMatrix2D a28 = oInv.viewSelection(spov, spov);
                    DoubleMatrix2D a29 = omega.viewSelection(spov, v).copy();
                    DoubleMatrix2D a30 = algebra.mult(a27, a28);
                    DoubleMatrix2D a31 = algebra.mult(a30, a29);
                    omega.set(_v, _v, tempVar + a31.get(0, 0));
                }
            }
        }
        // Change in omega during this sweep.
        DoubleMatrix2D a32 = omega.copy();
        a32.assign(omegaOld, PlusMult.plusMult(-1));
        double diff1 = algebra.norm1(a32);
        // Change in B during this sweep. The norm must be taken of a33 (the B difference);
        // taking it of a32 again would ignore B entirely and merely double the omega change.
        DoubleMatrix2D a33 = B.copy();
        a33.assign(bOld, PlusMult.plusMult(-1));
        double diff2 = algebra.norm1(a33);
        double diff = diff1 + diff2;
        _diff = diff;
        if (diff < tolerance) {
            break;
        }
    }
    // sigmahat = inv(B) * omega * inv(B')
    DoubleMatrix2D a34 = algebra.inverse(B);
    DoubleMatrix2D a35 = algebra.inverse(B.viewDice());
    DoubleMatrix2D sigmahat = algebra.mult(algebra.mult(a34, omega), a35);
    // lambdahat keeps only the entries of omega outside the non-UG block (non-UG block zeroed).
    DoubleMatrix2D lambdahat = omega.copy();
    DoubleMatrix2D a36 = lambdahat.viewSelection(ugComp, ugComp);
    a36.assign(factory.make(ugComp.length, ugComp.length, 0.0));
    // omegahat is omega with the UG block zeroed.
    DoubleMatrix2D omegahat = omega.copy();
    DoubleMatrix2D a37 = omegahat.viewSelection(ug, ug);
    a37.assign(factory.make(ug.length, ug.length, 0.0));
    DoubleMatrix2D bhat = B.copy();
    return new RicfResult(sigmahat, lambdahat, bhat, omegahat, i, _diff, covMatrix);
}
Also used : ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) Node(edu.cmu.tetrad.graph.Node) DoubleFactory2D(cern.colt.matrix.DoubleFactory2D) Endpoint(edu.cmu.tetrad.graph.Endpoint) Algebra(cern.colt.matrix.linalg.Algebra) SemGraph(edu.cmu.tetrad.graph.SemGraph) Graph(edu.cmu.tetrad.graph.Graph) DoubleMatrix2D(cern.colt.matrix.DoubleMatrix2D) DenseDoubleMatrix2D(cern.colt.matrix.impl.DenseDoubleMatrix2D) DoubleMatrix1D(cern.colt.matrix.DoubleMatrix1D) DenseDoubleMatrix1D(cern.colt.matrix.impl.DenseDoubleMatrix1D) DenseDoubleMatrix2D(cern.colt.matrix.impl.DenseDoubleMatrix2D) DenseDoubleMatrix1D(cern.colt.matrix.impl.DenseDoubleMatrix1D)

Example 7 with DoubleMatrix2D

use of cern.colt.matrix.DoubleMatrix2D in project tetrad by cmu-phil.

the class MixedUtils method graphToMatrix.

/**
 * Encodes the edges of a graph as a dense square matrix indexed by node position.
 *
 * <p>Nodes are numbered in the order returned by {@code graph.getNodes()}. For an edge that is not
 * directed, or whose two endpoints are both {@code Endpoint.ARROW}, both symmetric entries are set
 * to {@code undirectedWeight}. For every other edge a single entry is set to
 * {@code directedWeight}: {@code [node2][node1]} when the edge points towards its first node,
 * otherwise {@code [node1][node2]}. All remaining entries are 0.
 *
 * @param graph            the graph to encode
 * @param undirectedWeight value written for undirected and bidirected edges
 * @param directedWeight   value written for the remaining (directed) edges
 * @return an n-by-n matrix, where n is {@code graph.getNumNodes()}
 */
public static DoubleMatrix2D graphToMatrix(Graph graph, double undirectedWeight, double directedWeight) {
    final int size = graph.getNumNodes();
    final DoubleMatrix2D adjacency = DoubleFactory2D.dense.make(size, size, 0.0);

    // Position of every node, in order of appearance.
    final HashMap<Node, Integer> indexOf = new HashMap<>();
    int next = 0;
    for (Node n : graph.getNodes()) {
        indexOf.put(n, next++);
    }

    for (Edge e : graph.getEdges()) {
        final Node first = e.getNode1();
        final Node second = e.getNode2();

        // Bidirected edges are treated like undirected ones.
        final boolean bidirected = e.getEndpoint1() == Endpoint.ARROW && e.getEndpoint2() == Endpoint.ARROW;
        final boolean symmetric = !e.isDirected() || bidirected;

        if (symmetric) {
            final int a = indexOf.get(first);
            final int b = indexOf.get(second);
            adjacency.set(a, b, undirectedWeight);
            adjacency.set(b, a, undirectedWeight);
            continue;
        }

        if (e.pointsTowards(first)) {
            final int from = indexOf.get(second);
            final int to = indexOf.get(first);
            adjacency.set(from, to, directedWeight);
        } else {
            final int from = indexOf.get(first);
            final int to = indexOf.get(second);
            adjacency.set(from, to, directedWeight);
        }
    }
    return adjacency;
}
Also used : DoubleMatrix2D(cern.colt.matrix.DoubleMatrix2D)

Example 8 with DoubleMatrix2D

use of cern.colt.matrix.DoubleMatrix2D in project tetrad by cmu-phil.

the class MGM method nonSmoothValue.

/**
 * Calculates the penalty (non-smooth) term of the objective function for a parameter vector.
 *
 * <p>The value is {@code lambda[0] * betaNorms + lambda[1] * thetaNorms + lambda[2] * phiNorms},
 * where
 * <ul>
 *   <li>{@code betaNorms} is the sum of {@code |beta|} entries, each scaled by the product of the
 *       two corresponding weights;</li>
 *   <li>{@code thetaNorms} sums, over each of the first {@code p} columns of theta and each group
 *       {@code j}, the weighted value {@code sqrt(alg.norm2(segment))} of that column's group
 *       segment;</li>
 *   <li>{@code phiNorms} sums the weighted Frobenius norm of each phi block {@code (i, j)} with
 *       {@code i < j}.</li>
 * </ul>
 *
 * <p>If the current thread is interrupted, the loops stop early and the value accumulated so far
 * is used.
 *
 * @param parIn flat parameter vector; it is unpacked into a new {@code MGMParams} and not modified here
 * @return the weighted penalty value
 */
public double nonSmoothValue(DoubleMatrix1D parIn) {
    // Unpack into a separate parameter object (dimension is checked in the constructor).
    final MGMParams unpacked = new MGMParams(parIn, p, lsum);

    // Outer product of the weight vector with itself.
    final DoubleMatrix2D pairWeights = alg.multOuter(weights, weights, null);

    // beta part: sum over weights .* |beta|
    final DoubleMatrix2D weightedAbsBeta = unpacked.beta.copy();
    weightedAbsBeta.assign(Functions.abs);
    weightedAbsBeta.assign(pairWeights.viewPart(0, 0, p, p), Functions.mult);
    final double betaPenalty = weightedAbsBeta.zSum();

    // theta part: one segment per (column, group) pair.
    double thetaPenalty = 0;
    for (int col = 0; col < p && !Thread.currentThread().isInterrupted(); col++) {
        for (int grp = 0; grp < lcumsum.length - 1 && !Thread.currentThread().isInterrupted(); grp++) {
            final DoubleMatrix1D segment = unpacked.theta.viewColumn(col).viewPart(lcumsum[grp], l[grp]);
            thetaPenalty += pairWeights.get(col, p + grp) * Math.sqrt(alg.norm2(segment));
        }
    }

    // phi part: one block per pair of groups above the diagonal.
    double phiPenalty = 0;
    for (int r = 0; r < lcumsum.length - 1 && !Thread.currentThread().isInterrupted(); r++) {
        for (int c = r + 1; c < lcumsum.length - 1 && !Thread.currentThread().isInterrupted(); c++) {
            final DoubleMatrix2D block = unpacked.phi.viewPart(lcumsum[r], lcumsum[c], l[r], l[c]);
            phiPenalty += pairWeights.get(p + r, p + c) * alg.normF(block);
        }
    }

    return lambda.get(0) * betaPenalty + lambda.get(1) * thetaPenalty + lambda.get(2) * phiPenalty;
}
Also used : DoubleMatrix2D(cern.colt.matrix.DoubleMatrix2D) DoubleMatrix1D(cern.colt.matrix.DoubleMatrix1D)

Example 9 with DoubleMatrix2D

use of cern.colt.matrix.DoubleMatrix2D in project tetrad by cmu-phil.

the class MGM method adjMatFromMGM.

/**
 * Collapses the MGM parameters into a single {@code (p + q)} by {@code (p + q)} matrix of doubles.
 *
 * <p>The top-left {@code p}-by-{@code p} block is {@code beta + beta'}. Each entry linking one of
 * the first {@code p} variables with one of the last {@code q} is the {@code norm2} of the matching
 * theta segment, and each entry linking two of the last {@code q} variables is the Frobenius norm
 * of the matching phi block; both are written symmetrically. If {@code initVariables} is set, the
 * result is a row/column-permuted view that follows the order of {@code initVariables}.
 *
 * <p>If the current thread is interrupted, the loops stop early and the entries not yet filled
 * keep their initial values.
 *
 * @return the combined matrix, possibly as a selection view of the underlying matrix
 */
public DoubleMatrix2D adjMatFromMGM() {
    final int dim = p + q;
    DoubleMatrix2D adjacency = DoubleFactory2D.dense.make(dim, dim);

    // Upper-left block: beta + beta'.
    final DoubleMatrix2D symmetricBeta = params.beta.copy();
    symmetricBeta.assign(alg.transpose(params.beta), Functions.plus);
    adjacency.viewPart(0, 0, p, p).assign(symmetricBeta);

    // Off-diagonal blocks: one norm per theta segment.
    for (int c = 0; c < p && !Thread.currentThread().isInterrupted(); c++) {
        for (int d = 0; d < q && !Thread.currentThread().isInterrupted(); d++) {
            final double strength = norm2(params.theta.viewColumn(c).viewPart(lcumsum[d], l[d]));
            adjacency.set(c, p + d, strength);
            adjacency.set(p + d, c, strength);
        }
    }

    // Lower-right block: one Frobenius norm per phi block above the diagonal.
    for (int r = 0; r < q && !Thread.currentThread().isInterrupted(); r++) {
        for (int s = r + 1; s < q && !Thread.currentThread().isInterrupted(); s++) {
            final double strength = alg.normF(params.phi.viewPart(lcumsum[r], lcumsum[s], l[r], l[s]));
            adjacency.set(p + r, p + s, strength);
            adjacency.set(p + s, p + r, strength);
        }
    }

    if (initVariables == null) {
        return adjacency;
    }

    // Reorder rows and columns to follow the ordering of initVariables.
    final int[] order = new int[dim];
    for (int k = 0; k < dim; k++) {
        order[k] = variables.indexOf(initVariables.get(k));
    }
    return adjacency.viewSelection(order, order);
}
Also used : DoubleMatrix2D(cern.colt.matrix.DoubleMatrix2D)

Example 10 with DoubleMatrix2D

use of cern.colt.matrix.DoubleMatrix2D in project tetrad by cmu-phil.

the class MGM method initParameters.

/**
 * Initializes the cumulative level offsets and the parameter object.
 *
 * <p>{@code lcumsum[k]} becomes the sum of {@code l[0..k-1]} and {@code lsum} the total. All
 * parameters are created by the factories with their default fill, except {@code betad}, which is
 * filled with 1.0.
 */
private void initParameters() {
    final int groups = l.length;
    lcumsum = new int[groups + 1];
    // lcumsum[0] stays 0; every later slot is the running total of l.
    int running = 0;
    for (int k = 0; k < groups; k++) {
        running += l[k];
        lcumsum[k + 1] = running;
    }
    lsum = running;

    final int numCont = xDat.columns();

    // LH init to zeros, maybe should be random init?
    final DoubleMatrix2D beta = factory2D.make(numCont, numCont);   // continuous-continuous
    final DoubleMatrix1D betad = factory1D.make(numCont, 1.0);      // cont squared node pot
    final DoubleMatrix2D theta = factory2D.make(lsum, numCont);     // continuous-discrete
    final DoubleMatrix2D phi = factory2D.make(lsum, lsum);          // discrete-discrete
    final DoubleMatrix1D alpha1 = factory1D.make(numCont);          // cont linear node pot
    final DoubleMatrix1D alpha2 = factory1D.make(lsum);             // disc node pot

    params = new MGMParams(beta, betad, theta, phi, alpha1, alpha2);
}
Also used : DoubleMatrix2D(cern.colt.matrix.DoubleMatrix2D) DoubleMatrix1D(cern.colt.matrix.DoubleMatrix1D)

Aggregations

DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D)137 DenseDoubleMatrix2D (cern.colt.matrix.impl.DenseDoubleMatrix2D)39 DoubleMatrix1D (cern.colt.matrix.DoubleMatrix1D)37 Algebra (cern.colt.matrix.linalg.Algebra)16 DoubleFactory2D (cern.colt.matrix.DoubleFactory2D)13 DenseDoubleMatrix1D (cern.colt.matrix.impl.DenseDoubleMatrix1D)13 Node (edu.cmu.tetrad.graph.Node)11 Graph (edu.cmu.tetrad.graph.Graph)8 Test (org.junit.Test)6 DoubleMatrixReader (ubic.basecode.io.reader.DoubleMatrixReader)6 StringMatrixReader (ubic.basecode.io.reader.StringMatrixReader)6 DataSet (edu.cmu.tetrad.data.DataSet)5 DoubleArrayList (cern.colt.list.DoubleArrayList)4 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)4 DenseDoubleMatrix (ubic.basecode.dataStructure.matrix.DenseDoubleMatrix)4 AbstractFormatter (cern.colt.matrix.impl.AbstractFormatter)3 RobustEigenDecomposition (dr.math.matrixAlgebra.RobustEigenDecomposition)3 Endpoint (edu.cmu.tetrad.graph.Endpoint)3 ExpressionDataDoubleMatrix (ubic.gemma.core.datastructure.matrix.ExpressionDataDoubleMatrix)3 BioMaterial (ubic.gemma.model.expression.biomaterial.BioMaterial)3