Search in sources:

Example 6 with Regression

Use of edu.cmu.tetrad.regression.Regression in project tetrad by cmu-phil.

Source: class LingamPattern2, method getScore1.

// =============================PRIVATE METHODS=========================//
/**
 * Scores a DAG using residuals pooled over several data sets.
 *
 * <p>For each node, the node is regressed on its parents in {@code dag} separately in each data
 * set. Each data set's residuals are standardized with that data set's own mean and standard
 * deviation. Their absolute values are then concatenated across data sets into one column of a
 * pooled matrix. The score is the sum over nodes of {@code (mean(column) - sqrt(2 / pi))^2}.
 * {@code sqrt(2 / pi)} is the mean absolute value of a standard normal variable. Each p-value is
 * an Anderson-Darling p-value for one pooled column of absolute standardized residuals.
 *
 * @param dag       DAG whose parent sets define the regressions
 * @param data      one data matrix per data set (presumably one column per entry of
 *                  {@code variables} — TODO confirm)
 * @param variables variables looked up by name for each DAG node
 * @return the score together with the per-column Anderson-Darling p-values
 */
private Score getScore1(Graph dag, List<TetradMatrix> data, List<Node> variables) {
    List<Regression> regressions = new ArrayList<>();
    int totalSampleSize = 0;
    for (TetradMatrix _data : data) {
        regressions.add(new RegressionDataset(_data, variables));
        totalSampleSize += _data.rows();
    }
    int numCols = data.get(0).columns();
    List<Node> nodes = dag.getNodes();
    double score = 0.0;
    double[] pValues = new double[nodes.size()];
    TetradMatrix absoluteStandardizedResiduals = new TetradMatrix(totalSampleSize, numCols);
    for (int i = 0; i < nodes.size(); i++) {
        // The target and its parents do not depend on the data set, so resolve them once per node.
        Node _target = nodes.get(i);
        Node target = getVariable(variables, _target.getName());
        List<Node> regressors = new ArrayList<>();
        for (Node _regressor : dag.getParents(_target)) {
            regressors.add(getVariable(variables, _regressor.getName()));
        }
        // Absolute standardized residuals of this node, concatenated over all data sets.
        List<Double> _absoluteStandardizedResiduals = new ArrayList<>();
        for (Regression regression : regressions) {
            RegressionResult result = regression.regress(target, regressors);
            DoubleArrayList residualsColumn = new DoubleArrayList(result.getResiduals().toArray());
            double mean = Descriptive.mean(residualsColumn);
            double std = Descriptive.standardDeviation(Descriptive.variance(residualsColumn.size(), Descriptive.sum(residualsColumn), Descriptive.sumOfSquares(residualsColumn)));
            for (int k = 0; k < residualsColumn.size(); k++) {
                _absoluteStandardizedResiduals.add(Math.abs((residualsColumn.get(k) - mean) / std));
            }
        }
        for (int k = 0; k < _absoluteStandardizedResiduals.size(); k++) {
            absoluteStandardizedResiduals.set(k, i, _absoluteStandardizedResiduals.get(k));
        }
        // Copy the column only AFTER it has been filled. Copying it first would always average a
        // freshly allocated column of zeros, which would make the score independent of the DAG.
        DoubleArrayList absoluteStandardResidualsList = new DoubleArrayList(absoluteStandardizedResiduals.getColumn(i).toArray());
        double _mean = Descriptive.mean(absoluteStandardResidualsList);
        double diff = _mean - Math.sqrt(2.0 / Math.PI);
        score += diff * diff;
    }
    for (int j = 0; j < absoluteStandardizedResiduals.columns(); j++) {
        double[] x = absoluteStandardizedResiduals.getColumn(j).toArray();
        double p = new AndersonDarlingTest(x).getP();
        pValues[j] = p;
    }
    return new Score(score, pValues);
}
Also used: Regression (edu.cmu.tetrad.regression.Regression), DoubleArrayList (cern.colt.list.DoubleArrayList), ArrayList (java.util.ArrayList), TetradMatrix (edu.cmu.tetrad.util.TetradMatrix), RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset), TetradVector (edu.cmu.tetrad.util.TetradVector), AndersonDarlingTest (edu.cmu.tetrad.data.AndersonDarlingTest), RegressionResult (edu.cmu.tetrad.regression.RegressionResult)

Example 7 with Regression

Use of edu.cmu.tetrad.regression.Regression in project tetrad by cmu-phil.

Source: class LingamPattern, method getScore.

// =============================PRIVATE METHODS=========================//
/**
 * Scores a DAG by regressing every node on its parents in {@code dag}.
 *
 * <p>For each node, the residuals are standardized and their absolute values are averaged. The
 * score is the sum over nodes of {@code (average - sqrt(2 / pi))^2}. {@code sqrt(2 / pi)} is the
 * mean absolute value of a standard normal variable. The raw (unstandardized) residuals are stored
 * column by column, and an Anderson-Darling p-value is computed for each stored column.
 *
 * @param dag       DAG whose parent sets define the regressions
 * @param data      data matrix (presumably one column per entry of {@code variables} — TODO confirm)
 * @param variables variables looked up by name for each DAG node
 * @return the score together with the per-column Anderson-Darling p-values
 */
private Score getScore(Graph dag, TetradMatrix data, List<Node> variables) {
    final Regression regression = new RegressionDataset(data, variables);
    final List<Node> nodes = dag.getNodes();
    final int nodeCount = nodes.size();
    final double expectedAbsMean = Math.sqrt(2.0 / Math.PI);
    final TetradMatrix residuals = new TetradMatrix(data.rows(), data.columns());
    double score = 0.0;

    for (int col = 0; col < nodeCount; col++) {
        Node dagNode = nodes.get(col);
        List<Node> dagParents = dag.getParents(dagNode);
        Node response = getVariable(variables, dagNode.getName());
        List<Node> predictors = new ArrayList<>();
        for (Node dagParent : dagParents) {
            predictors.add(getVariable(variables, dagParent.getName()));
        }

        TetradVector resid = regression.regress(response, predictors).getResiduals();
        residuals.assignColumn(col, resid);

        DoubleArrayList magnitudes = new DoubleArrayList(resid.toArray());
        double center = Descriptive.mean(magnitudes);
        double spread = Descriptive.standardDeviation(Descriptive.variance(magnitudes.size(), Descriptive.sum(magnitudes), Descriptive.sumOfSquares(magnitudes)));
        int length = magnitudes.size();
        for (int r = 0; r < length; r++) {
            // Standardize, then take the absolute value, in a single step.
            magnitudes.set(r, Math.abs((magnitudes.get(r) - center) / spread));
        }

        double deviation = Descriptive.mean(magnitudes) - expectedAbsMean;
        score += deviation * deviation;
    }

    double[] pValues = new double[nodeCount];
    int columnCount = residuals.columns();
    for (int col = 0; col < columnCount; col++) {
        pValues[col] = new AndersonDarlingTest(residuals.getColumn(col).toArray()).getP();
    }
    return new Score(score, pValues);
}
Also used: Regression (edu.cmu.tetrad.regression.Regression), DoubleArrayList (cern.colt.list.DoubleArrayList), ArrayList (java.util.ArrayList), TetradMatrix (edu.cmu.tetrad.util.TetradMatrix), RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset), TetradVector (edu.cmu.tetrad.util.TetradVector), AndersonDarlingTest (edu.cmu.tetrad.data.AndersonDarlingTest), RegressionResult (edu.cmu.tetrad.regression.RegressionResult)

Example 8 with Regression

Use of edu.cmu.tetrad.regression.Regression in project tetrad by cmu-phil.

Source: class SampleVcpcFast, method search.

/**
 * Runs the VCCPC search.
 *
 * <p>Adjacencies come from {@code Vcfas}, a modified FAS that also reports apparent
 * non-adjacencies. If orientation is enabled, the result is then oriented CPC-style: background
 * knowledge, unshielded triples, then Meek rules.
 *
 * <p>Each collider / non-collider assignment of the remaining ambiguous triples is expanded into a
 * candidate pattern. A candidate is discarded if it contradicts one of its assigned colliders, or
 * if it contains a directed cycle after Meek rules are applied. An apparent non-adjacency is
 * promoted to a definite non-adjacency when no surviving pattern produces a failed independence
 * check for it.
 *
 * <p>Finally, edge coefficients are estimated by sample regression on {@code dataSet} and printed
 * alongside the corresponding coefficients of {@code semIm}.
 *
 * @return the searched graph, with its nodes replaced by the variables of {@code dataSet}
 * @throws NullPointerException if the independence test is null
 */
public Graph search() {
    // Fail fast, before the test is logged or handed to Vcfas.
    if (getIndependenceTest() == null) {
        throw new NullPointerException("Independence test must not be null.");
    }
    this.logger.log("info", "Starting VCCPC algorithm");
    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");
    this.allTriples = new HashSet<>();
    this.ambiguousTriples = new HashSet<>();
    this.colliderTriples = new HashSet<>();
    this.noncolliderTriples = new HashSet<>();
    Vcfas fas = new Vcfas(getIndependenceTest());
    definitelyNonadjacencies = new HashSet<>();
    markovInAllPatterns = new HashSet<>();
    long startTime = System.currentTimeMillis();
    List<Node> allNodes = getIndependenceTest().getVariables();
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    // Adjacency search; Vcfas also records the apparent non-adjacencies read just below.
    graph = fas.search();
    apparentlyNonadjacencies = fas.getApparentlyNonadjacencies();
    if (isDoOrientation()) {
        if (verbose) {
            System.out.println("CPC orientation...");
        }
        SearchGraphUtils.pcOrientbk(knowledge, graph, allNodes);
        orientUnshieldedTriples(knowledge, getIndependenceTest(), getDepth());
        MeekRules meekRules = new MeekRules();
        meekRules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
        meekRules.setKnowledge(knowledge);
        meekRules.orientImplied(graph);
    }
    // Local list of the ambiguous triples left in the graph (distinct from the field
    // this.ambiguousTriples, which is reset above).
    List<Triple> ambiguousList = new ArrayList<>(graph.getAmbiguousTriples());
    // Each ambiguous triple has two choices: 0 = collider, 1 = non-collider.
    int[] dims = new int[ambiguousList.size()];
    Arrays.fill(dims, 2);
    List<Graph> patterns = new ArrayList<>();
    // Keyed by identity because each candidate graph is mutated after it is inserted.
    Map<Graph, List<Triple>> newColliders = new IdentityHashMap<>();
    Map<Graph, List<Triple>> newNonColliders = new IdentityHashMap<>();
    // Use a combination generator to enumerate every way of disambiguating the ambiguous triples
    // into colliders and non-colliders. Each combination becomes a candidate graph in `patterns`;
    // the candidates are then checked against basic consistency rules.
    CombinationGenerator generator = new CombinationGenerator(dims);
    int[] combination;
    while ((combination = generator.next()) != null) {
        Graph _graph = new EdgeListGraph(graph);
        newColliders.put(_graph, new ArrayList<>());
        newNonColliders.put(_graph, new ArrayList<>());
        for (int k = 0; k < combination.length; k++) {
            Triple triple = ambiguousList.get(k);
            _graph.removeAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
            if (combination[k] == 0) {
                newColliders.get(_graph).add(triple);
                Node x = triple.getX();
                Node y = triple.getY();
                Node z = triple.getZ();
                _graph.setEndpoint(x, y, Endpoint.ARROW);
                _graph.setEndpoint(z, y, Endpoint.ARROW);
            }
            if (combination[k] == 1) {
                newNonColliders.get(_graph).add(triple);
            }
        }
        patterns.add(_graph);
    }
    // Discard inconsistent candidates. Iterate over a copy because `patterns` is modified inside.
    GRAPH: for (Graph pattern : new ArrayList<>(patterns)) {
        List<Triple> colliders = newColliders.get(pattern);
        List<Triple> nonColliders = newNonColliders.get(pattern);
        // A collider x -> y <- z is contradicted if either edge points away from y.
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(x) || (pattern.getEdge(y, z).pointsTowards(z))) {
                patterns.remove(pattern);
                continue GRAPH;
            }
        }
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            pattern.setEndpoint(x, y, Endpoint.ARROW);
            pattern.setEndpoint(z, y, Endpoint.ARROW);
        }
        // For a non-collider, an arrow into y forces the other edge to point away from y.
        for (Triple triple : nonColliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(y)) {
                pattern.removeEdge(y, z);
                pattern.addDirectedEdge(y, z);
            }
            if (pattern.getEdge(y, z).pointsTowards(y)) {
                pattern.removeEdge(x, y);
                pattern.addDirectedEdge(y, x);
            }
        }
        // Replace bidirected edges with undirected ones.
        for (Edge edge : pattern.getEdges()) {
            Node x = edge.getNode1();
            Node y = edge.getNode2();
            if (Edges.isBidirectedEdge(edge)) {
                pattern.removeEdge(x, y);
                pattern.addUndirectedEdge(x, y);
            }
        }
        MeekRules rules = new MeekRules();
        rules.orientImplied(pattern);
        if (pattern.existsDirectedCycle()) {
            patterns.remove(pattern);
        }
    }
    // Promote an apparent non-adjacency to a definite one unless some pattern yields a failed
    // independence check for it.
    MARKOV: for (Edge edge : apparentlyNonadjacencies.keySet()) {
        Node x = edge.getNode1();
        Node y = edge.getNode2();
        for (Graph _graph : new ArrayList<>(patterns)) {
            List<Node> boundaryX = new ArrayList<>(boundary(x, _graph));
            List<Node> boundaryY = new ArrayList<>(boundary(y, _graph));
            List<Node> futureX = new ArrayList<>(future(x, _graph));
            List<Node> futureY = new ArrayList<>(future(y, _graph));
            if (y == x) {
                continue;
            }
            if (boundaryX.contains(y) || boundaryY.contains(x)) {
                continue;
            }
            IndependenceTest test = independenceTest;
            if (!futureX.contains(y)) {
                if (!test.isIndependent(x, y, boundaryX)) {
                    continue MARKOV;
                }
            }
            if (!futureY.contains(x)) {
                if (!test.isIndependent(y, x, boundaryY)) {
                    continue MARKOV;
                }
            }
        }
        definitelyNonadjacencies.add(edge);
    }
    for (Edge edge : definitelyNonadjacencies) {
        apparentlyNonadjacencies.remove(edge);
    }
    setSemIm(semIm);
    // Edge estimation, algorithm I.
    Regression sampleRegression = new RegressionDataset(dataSet);
    System.out.println(sampleRegression.getGraph());
    graph = GraphUtils.replaceNodes(graph, dataSet.getVariables());
    Map<Edge, double[]> sampleRegress = new HashMap<>();
    Map<Edge, Double> edgeCoefs = new HashMap<>();
    ESTIMATION: for (Node z : graph.getNodes()) {
        Set<Edge> adj = getAdj(z, graph);
        for (Edge edge : apparentlyNonadjacencies.keySet()) {
            if (z == edge.getNode1() || z == edge.getNode2()) {
                // z touches an apparent non-adjacency: record its adjacencies as unknown (null)
                // and go on to the next z.
                for (Edge adjacency : adj) {
                    sampleRegress.put(adjacency, null);
                    recordImEdgeCoefficient(edgeCoefs, adjacency);
                }
                continue ESTIMATION;
            }
        }
        for (Edge nonadj : definitelyNonadjacencies) {
            if (nonadj.getNode1() == z || nonadj.getNode2() == z) {
                // A definite non-adjacency is recorded with zero coefficients.
                double[] d = { 0, 0 };
                sampleRegress.put(nonadj, d);
                recordImEdgeCoefficient(edgeCoefs, nonadj);
            }
        }
        Set<Edge> parentsOfZ = new HashSet<>();
        Set<Edge> _adj = getAdj(z, graph);
        for (Edge _adjacency : _adj) {
            if (!_adjacency.isDirected()) {
                for (Edge adjacency : adj) {
                    sampleRegress.put(adjacency, null);
                    recordImEdgeCoefficient(edgeCoefs, adjacency);
                }
            }
            if (_adjacency.pointsTowards(z)) {
                parentsOfZ.add(_adjacency);
            }
        }
        for (Edge edge : parentsOfZ) {
            if (edge.pointsTowards(edge.getNode2())) {
                RegressionResult result = sampleRegression.regress(edge.getNode2(), edge.getNode1());
                System.out.println(result);
                double[] d = result.getCoef();
                sampleRegress.put(edge, d);
                recordImEdgeCoefficient(edgeCoefs, edge);
            }
        }
    }
    System.out.println("All IM: " + semIm + "Finish");
    System.out.println("Just IM coefs: " + semIm.getEdgeCoef());
    System.out.println("IM Coef Map: " + edgeCoefs);
    System.out.println("Regress Coef Map: " + sampleRegress);
    for (Edge edge : sampleRegress.keySet()) {
        System.out.println(" Sample Regression: " + edge + Arrays.toString(sampleRegress.get(edge)));
    }
    for (Edge edge : graph.getEdges()) {
        System.out.println("Sample edge: " + Arrays.toString(sampleRegress.get(edge)));
    }
    System.out.println("Sample VCPC:");
    System.out.println("# of patterns: " + patterns.size());
    long endTime = System.currentTimeMillis();
    this.elapsedTime = endTime - startTime;
    // Divide by a double so sub-second precision is kept (matches the logged elapsed time below).
    System.out.println("Search Time (seconds):" + (elapsedTime) / 1000. + " s");
    System.out.println("Search Time (milli):" + elapsedTime + " ms");
    System.out.println("# of Apparent Nonadj: " + apparentlyNonadjacencies.size());
    System.out.println("# of Definite Nonadj: " + definitelyNonadjacencies.size());
    TetradLogger.getInstance().log("apparentlyNonadjacencies", "\n Apparent Non-adjacencies" + apparentlyNonadjacencies);
    TetradLogger.getInstance().log("definitelyNonadjacencies", "\n Definite Non-adjacencies" + definitelyNonadjacencies);
    TetradLogger.getInstance().log("patterns", "Disambiguated Patterns: " + patterns);
    TetradLogger.getInstance().log("graph", "\nReturning this graph: " + graph);
    TetradLogger.getInstance().log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    TetradLogger.getInstance().log("info", "Finishing CPC algorithm.");
    logTriples();
    TetradLogger.getInstance().flush();
    return graph;
}

/**
 * Stores {@code semIm}'s coefficient for {@code edge} in {@code edgeCoefs}, or 0.0 when
 * {@code semIm} has no coefficient for the edge's node pair.
 */
private void recordImEdgeCoefficient(Map<Edge, Double> edgeCoefs, Edge edge) {
    Node a = edge.getNode1();
    Node b = edge.getNode2();
    if (semIm.existsEdgeCoef(a, b)) {
        Double c = semIm.getEdgeCoef(a, b);
        edgeCoefs.put(edge, c);
    } else {
        edgeCoefs.put(edge, 0.0);
    }
}
Also used: ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap), RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset), RegressionResult (edu.cmu.tetrad.regression.RegressionResult), CombinationGenerator (edu.cmu.tetrad.util.CombinationGenerator), Regression (edu.cmu.tetrad.regression.Regression)

Example 9 with Regression

Use of edu.cmu.tetrad.regression.Regression in project tetrad by cmu-phil.

Source: class LingamPattern2, method getScore2.

/**
 * Scores a DAG using standardized residuals pooled over several data sets.
 *
 * <p>Each node is regressed on its parents in {@code dag} separately in every data set. Each data
 * set's residuals are standardized with that data set's own mean and standard deviation, and
 * stacked into one column of a pooled matrix (data sets in list order). The score is the sum over
 * nodes of {@code (mean|column| - sqrt(2 / pi))^2}. The p-values are Anderson-Darling p-values of
 * the signed standardized columns.
 *
 * @param dag       DAG whose parent sets define the regressions
 * @param data      one data matrix per data set (presumably one column per entry of
 *                  {@code variables} — TODO confirm)
 * @param variables variables looked up by name for each DAG node
 * @return the score together with the per-column Anderson-Darling p-values
 */
private Score getScore2(Graph dag, List<TetradMatrix> data, List<Node> variables) {
    List<Regression> regressions = new ArrayList<>();
    for (TetradMatrix dataset : data) {
        regressions.add(new RegressionDataset(dataset, variables));
    }
    int rowTotal = 0;
    for (TetradMatrix dataset : data) {
        rowTotal += dataset.rows();
    }
    final int columnTotal = data.get(0).columns();
    final List<Node> nodes = dag.getNodes();
    final int nodeCount = nodes.size();
    final TetradMatrix residuals = new TetradMatrix(rowTotal, columnTotal);

    // Fill column `col` with the standardized residuals of node `col`, data set after data set.
    for (int col = 0; col < nodeCount; col++) {
        Node dagNode = nodes.get(col);
        List<Node> dagParents = dag.getParents(dagNode);
        Node response = getVariable(variables, dagNode.getName());
        List<Node> predictors = new ArrayList<>();
        for (Node dagParent : dagParents) {
            predictors.add(getVariable(variables, dagParent.getName()));
        }
        int row = 0;
        for (Regression regression : regressions) {
            TetradVector resid = regression.regress(response, predictors).getResiduals();
            DoubleArrayList chunk = new DoubleArrayList(resid.toArray());
            double center = Descriptive.mean(chunk);
            double spread = Descriptive.standardDeviation(Descriptive.variance(chunk.size(), Descriptive.sum(chunk), Descriptive.sumOfSquares(chunk)));
            for (int r = 0; r < chunk.size(); r++) {
                residuals.set(row++, col, (chunk.get(r) - center) / spread);
            }
        }
    }

    final double expectedAbsMean = Math.sqrt(2.0 / Math.PI);
    double score = 0.0;
    for (int col = 0; col < nodeCount; col++) {
        DoubleArrayList magnitudes = new DoubleArrayList(residuals.getColumn(col).toArray());
        for (int r = 0; r < magnitudes.size(); r++) {
            magnitudes.set(r, Math.abs(magnitudes.get(r)));
        }
        double deviation = Descriptive.mean(magnitudes) - expectedAbsMean;
        score += deviation * deviation;
    }

    double[] pValues = new double[nodeCount];
    int columnCount = residuals.columns();
    for (int col = 0; col < columnCount; col++) {
        pValues[col] = new AndersonDarlingTest(residuals.getColumn(col).toArray()).getP();
    }
    return new Score(score, pValues);
}
Also used: Regression (edu.cmu.tetrad.regression.Regression), DoubleArrayList (cern.colt.list.DoubleArrayList), ArrayList (java.util.ArrayList), TetradMatrix (edu.cmu.tetrad.util.TetradMatrix), RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset), TetradVector (edu.cmu.tetrad.util.TetradVector), AndersonDarlingTest (edu.cmu.tetrad.data.AndersonDarlingTest), RegressionResult (edu.cmu.tetrad.regression.RegressionResult)

Example 10 with Regression

Use of edu.cmu.tetrad.regression.Regression in project tetrad by cmu-phil.

Source: class Lofs2, method getRegressions.

// ==========================PRIVATE=======================================//
/**
 * Returns one {@link Regression} per data set, building the list lazily on first use.
 *
 * <p>The first call also sets {@code variables} to the variables of the first data set. Later
 * calls return the cached list unchanged.
 *
 * @return the cached regressions, in the same order as {@code dataSets}
 */
private List<Regression> getRegressions() {
    if (this.regressions != null) {
        return this.regressions;
    }
    this.variables = dataSets.get(0).getVariables();
    List<Regression> created = new ArrayList<>();
    for (DataSet each : dataSets) {
        created.add(new RegressionDataset(each));
    }
    this.regressions = created;
    return this.regressions;
}
Also used: RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset), OLSMultipleLinearRegression (org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression), Regression (edu.cmu.tetrad.regression.Regression)

Aggregations

Regression (edu.cmu.tetrad.regression.Regression): 19; RegressionResult (edu.cmu.tetrad.regression.RegressionResult): 17; RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset): 16; ArrayList (java.util.ArrayList): 10; TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 7; TetradVector (edu.cmu.tetrad.util.TetradVector): 7; RegressionCovariance (edu.cmu.tetrad.regression.RegressionCovariance): 4; DoubleArrayList (cern.colt.list.DoubleArrayList): 3; AndersonDarlingTest (edu.cmu.tetrad.data.AndersonDarlingTest): 3; Node (edu.cmu.tetrad.graph.Node): 3; CombinationGenerator (edu.cmu.tetrad.util.CombinationGenerator): 2; ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap): 2; OLSMultipleLinearRegression (org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression): 2; Test (org.junit.Test): 2; CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix): 1; ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix): 1; EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph): 1