Search in sources :

Example 16 with ICovarianceMatrix

use of edu.cmu.tetrad.data.ICovarianceMatrix in project tetrad by cmu-phil.

the class MimbuildTrek method search.

// =================================== PUBLIC METHODS =========================================//
public Graph search(List<List<Node>> clustering, List<String> latentNames, ICovarianceMatrix measuresCov) {
    List<String> _latentNames = new ArrayList<>(latentNames);
    List<String> allVarNames = new ArrayList<>();
    for (List<Node> cluster : clustering) {
        for (Node node : cluster) allVarNames.add(node.getName());
    }
    measuresCov = measuresCov.getSubmatrix(allVarNames);
    List<List<Node>> _clustering = new ArrayList<>();
    for (List<Node> cluster : clustering) {
        List<Node> _cluster = new ArrayList<>();
        for (Node node : cluster) {
            _cluster.add(measuresCov.getVariable(node.getName()));
        }
        _clustering.add(_cluster);
    }
    List<Node> latents = defineLatents(_latentNames);
    this.latents = latents;
    // This removes the small clusters and their names.
    removeSmallClusters(latents, _clustering, getMinClusterSize());
    this.clustering = _clustering;
    Node[][] indicators = new Node[latents.size()][];
    for (int i = 0; i < latents.size(); i++) {
        indicators[i] = new Node[_clustering.get(i).size()];
        for (int j = 0; j < _clustering.get(i).size(); j++) {
            indicators[i][j] = _clustering.get(i).get(j);
        }
    }
    TetradMatrix cov = getCov(measuresCov, latents, indicators);
    CovarianceMatrix latentscov = new CovarianceMatrix(latents, cov, measuresCov.getSampleSize());
    this.latentsCov = latentscov;
    Graph graph;
    Cpc search = new Cpc(new IndTestTrekSep(measuresCov, alpha, clustering, latents));
    search.setKnowledge(knowledge);
    graph = search.search();
    // try {
    // Ges search = new Ges(latentscov);
    // search.setCorrErrorsAlpha(penaltyDiscount);
    // search.setKnowledge(knowledge);
    // graph = search.search();
    // } catch (Exception e) {
    // //            e.printStackTrace();
    // CPC search = new CPC(new IndTestFisherZ(latentscov, alpha));
    // search.setKnowledge(knowledge);
    // graph = search.search();
    // }
    this.structureGraph = new EdgeListGraph(graph);
    GraphUtils.fruchtermanReingoldLayout(this.structureGraph);
    return this.structureGraph;
}
Also used : ArrayList(java.util.ArrayList) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) CovarianceMatrix(edu.cmu.tetrad.data.CovarianceMatrix) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) ArrayList(java.util.ArrayList) List(java.util.List)

Example 17 with ICovarianceMatrix

use of edu.cmu.tetrad.data.ICovarianceMatrix in project tetrad by cmu-phil.

the class MimSearchEditor method getIndTestParamBox.

/**
 * Factory to return the correct param editor for independence test params.
 * This will go in a little box in the search editor.
 */
private JComponent getIndTestParamBox(Parameters params) {
    if (params == null) {
        throw new NullPointerException();
    }
    if (params instanceof Parameters) {
        MimRunner runner = getMimRunner();
        params.set("varNames", runner.getParams().get("varNames", null));
        DataModel dataModel = runner.getData();
        if (dataModel instanceof DataSet) {
            DataSet data = (DataSet) runner.getData();
            boolean discrete = data.isDiscrete();
            return new BuildPureClustersIndTestParamsEditor(params, discrete);
        } else if (dataModel instanceof ICovarianceMatrix) {
            return new BuildPureClustersIndTestParamsEditor(params, false);
        }
    }
    if (params instanceof Parameters) {
        MimRunner runner = getMimRunner();
        params.set("varNames", runner.getParams().get("varNames", null));
        boolean discreteData = false;
        if (runner.getData() instanceof DataSet) {
            discreteData = ((DataSet) runner.getData()).isDiscrete();
        }
        return new PurifyIndTestParamsEditor(params, discreteData);
    }
    if (params instanceof Parameters) {
        MimRunner runner = getMimRunner();
        params.set("varNames", runner.getParams().get("varNames", null));
        return new MimBuildIndTestParamsEditor(params);
    }
    throw new IllegalArgumentException("Unrecognized IndTestParams: " + params.getClass());
}
Also used : Parameters(edu.cmu.tetrad.util.Parameters) DataSet(edu.cmu.tetrad.data.DataSet) DataModel(edu.cmu.tetrad.data.DataModel) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) MimRunner(edu.cmu.tetradapp.model.MimRunner)

Example 18 with ICovarianceMatrix

use of edu.cmu.tetrad.data.ICovarianceMatrix in project tetrad by cmu-phil.

the class PerformanceTestsDan method testIdaOutputForDan.

private void testIdaOutputForDan() {
    int numRuns = 100;
    for (int run = 0; run < numRuns; run++) {
        double alphaGFci = 0.01;
        double alphaPc = 0.01;
        int penaltyDiscount = 1;
        int depth = 3;
        int maxPathLength = 3;
        final int numVars = 15;
        final double edgesPerNode = 1.0;
        final int numCases = 1000;
        // final int numLatents = RandomUtil.getInstance().nextInt(3) + 1;
        final int numLatents = 6;
        // writeToFile = false;
        PrintStream out1;
        PrintStream out2;
        PrintStream out3;
        PrintStream out4;
        PrintStream out5;
        PrintStream out6;
        PrintStream out7;
        PrintStream out8;
        PrintStream out9;
        PrintStream out10;
        PrintStream out11;
        PrintStream out12;
        File dir0 = new File("gfci.output");
        dir0.mkdirs();
        File dir = new File(dir0, "" + (run + 1));
        dir.mkdir();
        try {
            out1 = new PrintStream(new File(dir, "hyperparameters.txt"));
            out2 = new PrintStream(new File(dir, "variables.txt"));
            out3 = new PrintStream(new File(dir, "dag.long.txt"));
            out4 = new PrintStream(new File(dir, "dag.matrix.txt"));
            out5 = new PrintStream(new File(dir, "coef.matrix.txt"));
            out6 = new PrintStream(new File(dir, "pag.long.txt"));
            out7 = new PrintStream(new File(dir, "pag.matrix.txt"));
            out8 = new PrintStream(new File(dir, "pattern.long.txt"));
            out9 = new PrintStream(new File(dir, "pattern.matrix.txt"));
            out10 = new PrintStream(new File(dir, "data.txt"));
            out11 = new PrintStream(new File(dir, "true.pag.long.txt"));
            out12 = new PrintStream(new File(dir, "true.pag.matrix.txt"));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            throw new RuntimeException(e);
        }
        out1.println("Num _vars = " + numVars);
        out1.println("Num edges = " + (int) (numVars * edgesPerNode));
        out1.println("Num cases = " + numCases);
        out1.println("Alpha for PC = " + alphaPc);
        out1.println("Alpha for FFCI = " + alphaGFci);
        out1.println("Penalty discount = " + penaltyDiscount);
        out1.println("Depth = " + depth);
        out1.println("Maximum reachable path length for dsep search and discriminating undirectedPaths = " + maxPathLength);
        List<Node> vars = new ArrayList<>();
        for (int i = 0; i < numVars; i++) vars.add(new GraphNode("X" + (i + 1)));
        // Graph dag = DataGraphUtils.randomDagQuick2(varsWithLatents, 0, (int) (varsWithLatents.size() * edgesPerNode));
        Graph dag = GraphUtils.randomGraph(vars, 0, (int) (vars.size() * edgesPerNode), 5, 5, 5, false);
        GraphUtils.fixLatents1(numLatents, dag);
        // List<Node> varsWithLatents = new ArrayList<Node>();
        // 
        // Graph dag = getLatentGraph(_vars, varsWithLatents, edgesPerNode, numLatents);
        out3.println(dag);
        printDanMatrix(vars, dag, out4);
        SemPm pm = new SemPm(dag);
        SemIm im = new SemIm(pm);
        NumberFormat nf = new DecimalFormat("0.0000");
        for (int i = 0; i < vars.size(); i++) {
            for (Node var : vars) {
                if (im.existsEdgeCoef(var, vars.get(i))) {
                    double coef = im.getEdgeCoef(var, vars.get(i));
                    out5.print(nf.format(coef) + "\t");
                } else {
                    out5.print(nf.format(0) + "\t");
                }
            }
            out5.println();
        }
        out5.println();
        String vars_temp = vars.toString();
        vars_temp = vars_temp.replace("[", "");
        vars_temp = vars_temp.replace("]", "");
        vars_temp = vars_temp.replace("X", "");
        out2.println(vars_temp);
        List<Node> _vars = new ArrayList<>();
        for (Node node : vars) {
            if (node.getNodeType() == NodeType.MEASURED) {
                _vars.add(node);
            }
        }
        String _vars_temp = _vars.toString();
        _vars_temp = _vars_temp.replace("[", "");
        _vars_temp = _vars_temp.replace("]", "");
        _vars_temp = _vars_temp.replace("X", "");
        out2.println(_vars_temp);
        DataSet fullData = im.simulateData(numCases, false);
        DataSet data = DataUtils.restrictToMeasured(fullData);
        ICovarianceMatrix cov = new CovarianceMatrix(data);
        final IndTestFisherZ independenceTestGFci = new IndTestFisherZ(cov, alphaGFci);
        final edu.cmu.tetrad.search.SemBicScore scoreGfci = new edu.cmu.tetrad.search.SemBicScore(cov);
        out6.println("GFCI.PAG_of_the_true_DAG");
        GFci gFci = new GFci(independenceTestGFci, scoreGfci);
        gFci.setVerbose(false);
        gFci.setMaxDegree(depth);
        gFci.setMaxPathLength(maxPathLength);
        // gFci.setPossibleDsepSearchDone(true);
        gFci.setCompleteRuleSetUsed(true);
        Graph pag = gFci.search();
        out6.println(pag);
        printDanMatrix(_vars, pag, out7);
        out8.println("Pattern_of_the_true_DAG OVER MEASURED VARIABLES");
        final IndTestFisherZ independencePc = new IndTestFisherZ(cov, alphaPc);
        Pc pc = new Pc(independencePc);
        pc.setVerbose(false);
        pc.setDepth(depth);
        Graph pattern = pc.search();
        out8.println(pattern);
        printDanMatrix(_vars, pattern, out9);
        out10.println(data);
        out11.println("True PAG_of_the_true_DAG");
        final Graph truePag = new DagToPag(dag).convert();
        out11.println(truePag);
        printDanMatrix(_vars, truePag, out12);
        out1.close();
        out2.close();
        out3.close();
        out4.close();
        out5.close();
        out6.close();
        out7.close();
        out8.close();
        out9.close();
        out10.close();
        out11.close();
        out12.close();
    }
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) DecimalFormat(java.text.DecimalFormat) FileNotFoundException(java.io.FileNotFoundException) ArrayList(java.util.ArrayList) Pc(edu.cmu.tetrad.search.Pc) SemPm(edu.cmu.tetrad.sem.SemPm) PrintStream(java.io.PrintStream) IndTestFisherZ(edu.cmu.tetrad.search.IndTestFisherZ) GFci(edu.cmu.tetrad.search.GFci) CovarianceMatrix(edu.cmu.tetrad.data.CovarianceMatrix) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) DagToPag(edu.cmu.tetrad.search.DagToPag) File(java.io.File) SemIm(edu.cmu.tetrad.sem.SemIm) NumberFormat(java.text.NumberFormat)

Example 19 with ICovarianceMatrix

use of edu.cmu.tetrad.data.ICovarianceMatrix in project tetrad by cmu-phil.

the class GeneralAlgorithmRunner method execute.

// ============================PUBLIC METHODS==========================//
@Override
public void execute() {
    List<Graph> graphList = new ArrayList<>();
    int i = 0;
    if (getDataModelList().isEmpty()) {
        if (getSourceGraph() != null) {
            Algorithm algo = getAlgorithm();
            if (algo instanceof HasKnowledge) {
                ((HasKnowledge) algo).setKnowledge(getKnowledge());
            }
            graphList.add(algo.search(null, parameters));
        } else {
            throw new IllegalArgumentException("The parent boxes did not include any datasets or graphs. Try opening\n" + "the editors for those boxes and loading or simulating them.");
        }
    } else {
        if (getAlgorithm() instanceof MultiDataSetAlgorithm) {
            for (int k = 0; k < parameters.getInt("numRuns"); k++) {
                List<DataSet> dataSets = getDataModelList().stream().map(e -> (DataSet) e).collect(Collectors.toCollection(ArrayList::new));
                if (dataSets.size() < parameters.getInt("randomSelectionSize")) {
                    throw new IllegalArgumentException("Sorry, the 'random selection size' is greater than " + "the number of data sets.");
                }
                Collections.shuffle(dataSets);
                List<DataModel> sub = new ArrayList<>();
                for (int j = 0; j < parameters.getInt("randomSelectionSize"); j++) {
                    sub.add(dataSets.get(j));
                }
                Algorithm algo = getAlgorithm();
                if (algo instanceof HasKnowledge) {
                    ((HasKnowledge) algo).setKnowledge(getKnowledge());
                }
                graphList.add(((MultiDataSetAlgorithm) algo).search(sub, parameters));
            }
        } else if (getAlgorithm() instanceof ClusterAlgorithm) {
            for (int k = 0; k < parameters.getInt("numRuns"); k++) {
                getDataModelList().forEach(dataModel -> {
                    if (dataModel instanceof ICovarianceMatrix) {
                        ICovarianceMatrix dataSet = (ICovarianceMatrix) dataModel;
                        graphList.add(algorithm.search(dataSet, parameters));
                    } else if (dataModel instanceof DataSet) {
                        DataSet dataSet = (DataSet) dataModel;
                        if (!dataSet.isContinuous()) {
                            throw new IllegalArgumentException("Sorry, you need a continuous dataset for a cluster algorithm.");
                        }
                        graphList.add(algorithm.search(dataSet, parameters));
                    }
                });
            }
        } else {
            getDataModelList().forEach(data -> {
                IKnowledge knowledgeFromData = data.getKnowledge();
                if (!(knowledgeFromData == null || knowledgeFromData.getVariables().isEmpty())) {
                    this.knowledge = knowledgeFromData;
                }
                Algorithm algo = getAlgorithm();
                if (algo instanceof HasKnowledge) {
                    ((HasKnowledge) algo).setKnowledge(getKnowledge());
                }
                DataType algDataType = algo.getDataType();
                if (data.isContinuous() && (algDataType == DataType.Continuous || algDataType == DataType.Mixed)) {
                    graphList.add(algo.search(data, parameters));
                } else if (data.isDiscrete() && (algDataType == DataType.Discrete || algDataType == DataType.Mixed)) {
                    graphList.add(algo.search(data, parameters));
                } else if (data.isMixed() && algDataType == DataType.Mixed) {
                    graphList.add(algo.search(data, parameters));
                } else {
                    throw new IllegalArgumentException("The type of data changed; try opening up the search editor and " + "running the algorithm there.");
                }
            });
        }
    }
    if (getKnowledge().getVariablesNotInTiers().size() < getKnowledge().getVariables().size()) {
        for (Graph graph : graphList) {
            SearchGraphUtils.arrangeByKnowledgeTiers(graph, getKnowledge());
        }
    } else {
        for (Graph graph : graphList) {
            GraphUtils.circleLayout(graph, 225, 200, 150);
        }
    }
    this.graphList = graphList;
}
Also used : GraphUtils(edu.cmu.tetrad.graph.GraphUtils) ObjectInputStream(java.io.ObjectInputStream) Parameters(edu.cmu.tetrad.util.Parameters) HashMap(java.util.HashMap) Triple(edu.cmu.tetrad.graph.Triple) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) DataType(edu.cmu.tetrad.data.DataType) KnowledgeBoxInput(edu.cmu.tetrad.data.KnowledgeBoxInput) Map(java.util.Map) ClusterAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) Fges(edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges) BdeuScore(edu.cmu.tetrad.algcomparison.score.BdeuScore) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) IKnowledge(edu.cmu.tetrad.data.IKnowledge) Graph(edu.cmu.tetrad.graph.Graph) IOException(java.io.IOException) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) Collectors(java.util.stream.Collectors) DataModel(edu.cmu.tetrad.data.DataModel) DataModelList(edu.cmu.tetrad.data.DataModelList) List(java.util.List) ParamsResettable(edu.cmu.tetrad.session.ParamsResettable) DataSet(edu.cmu.tetrad.data.DataSet) ImpliedOrientation(edu.cmu.tetrad.search.ImpliedOrientation) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) SessionModel(edu.cmu.tetrad.session.SessionModel) IndependenceTest(edu.cmu.tetrad.search.IndependenceTest) SearchGraphUtils(edu.cmu.tetrad.search.SearchGraphUtils) Knowledge2(edu.cmu.tetrad.data.Knowledge2) Unmarshallable(edu.cmu.tetrad.util.Unmarshallable) Collections(java.util.Collections) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) ArrayList(java.util.ArrayList) HasKnowledge(edu.cmu.tetrad.algcomparison.utils.HasKnowledge) ClusterAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm) Algorithm(edu.cmu.tetrad.algcomparison.algorithm.Algorithm) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) IKnowledge(edu.cmu.tetrad.data.IKnowledge) ClusterAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.cluster.ClusterAlgorithm) Graph(edu.cmu.tetrad.graph.Graph) MultiDataSetAlgorithm(edu.cmu.tetrad.algcomparison.algorithm.MultiDataSetAlgorithm) DataModel(edu.cmu.tetrad.data.DataModel) DataType(edu.cmu.tetrad.data.DataType)

Example 20 with ICovarianceMatrix

use of edu.cmu.tetrad.data.ICovarianceMatrix in project tetrad by cmu-phil.

the class Ricf method ricf2.

/**
 * same as above but takes a Graph instead of a SemGraph *
 */
public RicfResult ricf2(Graph mag, ICovarianceMatrix covMatrix, double tolerance) {
    // mag.setShowErrorTerms(false);
    DoubleFactory2D factory = DoubleFactory2D.dense;
    Algebra algebra = new Algebra();
    DoubleMatrix2D S = new DenseDoubleMatrix2D(covMatrix.getMatrix().toArray());
    int p = covMatrix.getDimension();
    if (p == 1) {
        return new RicfResult(S, S, null, null, 1, Double.NaN, covMatrix);
    }
    List<Node> nodes = new ArrayList<>();
    for (String name : covMatrix.getVariableNames()) {
        nodes.add(mag.getNode(name));
    }
    DoubleMatrix2D omega = factory.diagonal(factory.diagonal(S));
    DoubleMatrix2D B = factory.identity(p);
    int[] ug = ugNodes(mag, nodes);
    int[] ugComp = complement(p, ug);
    if (ug.length > 0) {
        List<Node> _ugNodes = new LinkedList<>();
        for (int i : ug) {
            _ugNodes.add(nodes.get(i));
        }
        Graph ugGraph = mag.subgraph(_ugNodes);
        ICovarianceMatrix ugCov = covMatrix.getSubmatrix(ug);
        DoubleMatrix2D lambdaInv = fitConGraph(ugGraph, ugCov, p + 1, tolerance).shat;
        omega.viewSelection(ug, ug).assign(lambdaInv);
    }
    // Prepare lists of parents and spouses.
    int[][] pars = parentIndices(p, mag, nodes);
    int[][] spo = spouseIndices(p, mag, nodes);
    int i = 0;
    double _diff;
    while (true) {
        i++;
        DoubleMatrix2D omegaOld = omega.copy();
        DoubleMatrix2D bOld = B.copy();
        for (int _v = 0; _v < p; _v++) {
            // Exclude the UG part.
            if (Arrays.binarySearch(ug, _v) >= 0) {
                continue;
            }
            int[] v = new int[] { _v };
            int[] vcomp = complement(p, v);
            int[] all = range(0, p - 1);
            int[] parv = pars[_v];
            int[] spov = spo[_v];
            DoubleMatrix2D a6 = B.viewSelection(v, parv);
            if (spov.length == 0) {
                if (parv.length != 0) {
                    if (i == 1) {
                        DoubleMatrix2D a1 = S.viewSelection(parv, parv);
                        DoubleMatrix2D a2 = S.viewSelection(v, parv);
                        DoubleMatrix2D a3 = algebra.inverse(a1);
                        DoubleMatrix2D a4 = algebra.mult(a2, a3);
                        a4.assign(Mult.mult(-1));
                        a6.assign(a4);
                        DoubleMatrix2D a7 = S.viewSelection(parv, v);
                        DoubleMatrix2D a9 = algebra.mult(a6, a7);
                        DoubleMatrix2D a8 = S.viewSelection(v, v);
                        DoubleMatrix2D a8b = omega.viewSelection(v, v);
                        a8b.assign(a8);
                        omega.viewSelection(v, v).assign(a9, PlusMult.plusMult(1));
                    }
                }
            } else {
                if (parv.length != 0) {
                    DoubleMatrix2D oInv = new DenseDoubleMatrix2D(p, p);
                    DoubleMatrix2D a2 = omega.viewSelection(vcomp, vcomp);
                    DoubleMatrix2D a3 = algebra.inverse(a2);
                    oInv.viewSelection(vcomp, vcomp).assign(a3);
                    DoubleMatrix2D Z = algebra.mult(oInv.viewSelection(spov, vcomp), B.viewSelection(vcomp, all));
                    int lpa = parv.length;
                    int lspo = spov.length;
                    // Build XX
                    DoubleMatrix2D XX = new DenseDoubleMatrix2D(lpa + lspo, lpa + lspo);
                    int[] range1 = range(0, lpa - 1);
                    int[] range2 = range(lpa, lpa + lspo - 1);
                    // Upper left quadrant
                    XX.viewSelection(range1, range1).assign(S.viewSelection(parv, parv));
                    // Upper right quadrant
                    DoubleMatrix2D a11 = algebra.mult(S.viewSelection(parv, all), algebra.transpose(Z));
                    XX.viewSelection(range1, range2).assign(a11);
                    // Lower left quadrant
                    DoubleMatrix2D a12 = XX.viewSelection(range2, range1);
                    DoubleMatrix2D a13 = algebra.transpose(XX.viewSelection(range1, range2));
                    a12.assign(a13);
                    // Lower right quadrant
                    DoubleMatrix2D a14 = XX.viewSelection(range2, range2);
                    DoubleMatrix2D a15 = algebra.mult(Z, S);
                    DoubleMatrix2D a16 = algebra.mult(a15, algebra.transpose(Z));
                    a14.assign(a16);
                    // Build XY
                    DoubleMatrix1D YX = new DenseDoubleMatrix1D(lpa + lspo);
                    DoubleMatrix1D a17 = YX.viewSelection(range1);
                    DoubleMatrix1D a18 = S.viewSelection(v, parv).viewRow(0);
                    a17.assign(a18);
                    DoubleMatrix1D a19 = YX.viewSelection(range2);
                    DoubleMatrix2D a20 = S.viewSelection(v, all);
                    DoubleMatrix1D a21 = algebra.mult(a20, algebra.transpose(Z)).viewRow(0);
                    a19.assign(a21);
                    // Temp
                    DoubleMatrix2D a22 = algebra.inverse(XX);
                    DoubleMatrix1D temp = algebra.mult(algebra.transpose(a22), YX);
                    // Assign to b.
                    DoubleMatrix1D a23 = a6.viewRow(0);
                    DoubleMatrix1D a24 = temp.viewSelection(range1);
                    a23.assign(a24);
                    a23.assign(Mult.mult(-1));
                    // Assign to omega.
                    omega.viewSelection(v, spov).viewRow(0).assign(temp.viewSelection(range2));
                    omega.viewSelection(spov, v).viewColumn(0).assign(temp.viewSelection(range2));
                    // Variance.
                    double tempVar = S.get(_v, _v) - algebra.mult(temp, YX);
                    DoubleMatrix2D a27 = omega.viewSelection(v, spov);
                    DoubleMatrix2D a28 = oInv.viewSelection(spov, spov);
                    DoubleMatrix2D a29 = omega.viewSelection(spov, v).copy();
                    DoubleMatrix2D a30 = algebra.mult(a27, a28);
                    DoubleMatrix2D a31 = algebra.mult(a30, a29);
                    omega.viewSelection(v, v).assign(tempVar);
                    omega.viewSelection(v, v).assign(a31, PlusMult.plusMult(1));
                } else {
                    DoubleMatrix2D oInv = new DenseDoubleMatrix2D(p, p);
                    DoubleMatrix2D a2 = omega.viewSelection(vcomp, vcomp);
                    DoubleMatrix2D a3 = algebra.inverse(a2);
                    oInv.viewSelection(vcomp, vcomp).assign(a3);
                    // System.out.println("O.inv = " + oInv);
                    DoubleMatrix2D a4 = oInv.viewSelection(spov, vcomp);
                    DoubleMatrix2D a5 = B.viewSelection(vcomp, all);
                    DoubleMatrix2D Z = algebra.mult(a4, a5);
                    // System.out.println("Z = " + Z);
                    // Build XX
                    DoubleMatrix2D XX = algebra.mult(algebra.mult(Z, S), Z.viewDice());
                    // System.out.println("XX = " + XX);
                    // Build XY
                    DoubleMatrix2D a20 = S.viewSelection(v, all);
                    DoubleMatrix1D YX = algebra.mult(a20, Z.viewDice()).viewRow(0);
                    // System.out.println("YX = " + YX);
                    // Temp
                    DoubleMatrix2D a22 = algebra.inverse(XX);
                    DoubleMatrix1D a23 = algebra.mult(algebra.transpose(a22), YX);
                    // Assign to omega.
                    DoubleMatrix1D a24 = omega.viewSelection(v, spov).viewRow(0);
                    a24.assign(a23);
                    DoubleMatrix1D a25 = omega.viewSelection(spov, v).viewColumn(0);
                    a25.assign(a23);
                    // System.out.println("Omega 2 " + omega);
                    // Variance.
                    double tempVar = S.get(_v, _v) - algebra.mult(a24, YX);
                    // System.out.println("tempVar = " + tempVar);
                    DoubleMatrix2D a27 = omega.viewSelection(v, spov);
                    DoubleMatrix2D a28 = oInv.viewSelection(spov, spov);
                    DoubleMatrix2D a29 = omega.viewSelection(spov, v).copy();
                    DoubleMatrix2D a30 = algebra.mult(a27, a28);
                    DoubleMatrix2D a31 = algebra.mult(a30, a29);
                    omega.set(_v, _v, tempVar + a31.get(0, 0));
                // System.out.println("Omega final " + omega);
                }
            }
        }
        DoubleMatrix2D a32 = omega.copy();
        a32.assign(omegaOld, PlusMult.plusMult(-1));
        double diff1 = algebra.norm1(a32);
        DoubleMatrix2D a33 = B.copy();
        a33.assign(bOld, PlusMult.plusMult(-1));
        double diff2 = algebra.norm1(a32);
        double diff = diff1 + diff2;
        _diff = diff;
        if (diff < tolerance)
            break;
    }
    DoubleMatrix2D a34 = algebra.inverse(B);
    DoubleMatrix2D a35 = algebra.inverse(B.viewDice());
    DoubleMatrix2D sigmahat = algebra.mult(algebra.mult(a34, omega), a35);
    DoubleMatrix2D lambdahat = omega.copy();
    DoubleMatrix2D a36 = lambdahat.viewSelection(ugComp, ugComp);
    a36.assign(factory.make(ugComp.length, ugComp.length, 0.0));
    DoubleMatrix2D omegahat = omega.copy();
    DoubleMatrix2D a37 = omegahat.viewSelection(ug, ug);
    a37.assign(factory.make(ug.length, ug.length, 0.0));
    DoubleMatrix2D bhat = B.copy();
    return new RicfResult(sigmahat, lambdahat, bhat, omegahat, i, _diff, covMatrix);
}
Also used : ICovarianceMatrix(edu.cmu.tetrad.data.ICovarianceMatrix) Node(edu.cmu.tetrad.graph.Node) DoubleFactory2D(cern.colt.matrix.DoubleFactory2D) Endpoint(edu.cmu.tetrad.graph.Endpoint) Algebra(cern.colt.matrix.linalg.Algebra) SemGraph(edu.cmu.tetrad.graph.SemGraph) Graph(edu.cmu.tetrad.graph.Graph) DoubleMatrix2D(cern.colt.matrix.DoubleMatrix2D) DenseDoubleMatrix2D(cern.colt.matrix.impl.DenseDoubleMatrix2D) DoubleMatrix1D(cern.colt.matrix.DoubleMatrix1D) DenseDoubleMatrix1D(cern.colt.matrix.impl.DenseDoubleMatrix1D) DenseDoubleMatrix2D(cern.colt.matrix.impl.DenseDoubleMatrix2D) DenseDoubleMatrix1D(cern.colt.matrix.impl.DenseDoubleMatrix1D)

Aggregations

ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix)30 DataSet (edu.cmu.tetrad.data.DataSet)12 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)10 ArrayList (java.util.ArrayList)9 DataModel (edu.cmu.tetrad.data.DataModel)8 Node (edu.cmu.tetrad.graph.Node)8 Graph (edu.cmu.tetrad.graph.Graph)7 SemPm (edu.cmu.tetrad.sem.SemPm)5 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)5 Test (org.junit.Test)5 Parameters (edu.cmu.tetrad.util.Parameters)4 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)3 SemIm (edu.cmu.tetrad.sem.SemIm)3 DoubleFactory2D (cern.colt.matrix.DoubleFactory2D)2 DoubleMatrix1D (cern.colt.matrix.DoubleMatrix1D)2 DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D)2 DenseDoubleMatrix1D (cern.colt.matrix.impl.DenseDoubleMatrix1D)2 DenseDoubleMatrix2D (cern.colt.matrix.impl.DenseDoubleMatrix2D)2 Algebra (cern.colt.matrix.linalg.Algebra)2 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)2