Search in sources :

Example 26 with RegressionResult

use of edu.cmu.tetrad.regression.RegressionResult in project tetrad by cmu-phil.

the class SemIm method getStandardError.

/**
 * Returns the standard error of the given parameter.
 *
 * <ul>
 *   <li>{@code NaN} if there is no sample covariance matrix, or if the model has more than
 *       {@code maxFreeParams} free parameters.</li>
 *   <li>{@code 0.0} if the parameter is a fixed parameter.</li>
 *   <li>For a free parameter whose two nodes differ, the standard error is taken from a
 *       covariance-based regression of the child on its non-error parents, provided that the
 *       child is not latent, no parent is latent, and the parent really is one of the
 *       regressors.</li>
 *   <li>Otherwise the value is read from {@code standardErrors()} at the parameter's position in
 *       the free-parameter list, or {@code NaN} when that is unavailable.</li>
 * </ul>
 *
 * @param parameter     the parameter whose standard error is requested
 * @param maxFreeParams the largest number of free parameters for which a value is computed
 * @return the standard error, {@code 0.0} for fixed parameters, or {@code NaN} when it cannot be
 *         computed
 * @throws IllegalArgumentException if the parameter is neither a free nor a fixed parameter
 */
public double getStandardError(Parameter parameter, int maxFreeParams) {
    TetradMatrix sampleCovar = getSampleCovar();
    if (sampleCovar == null) {
        return Double.NaN;
    }
    if (!getFreeParameters().contains(parameter)) {
        if (getFixedParameters().contains(parameter)) {
            return 0.0;
        }
        throw new IllegalArgumentException("That is not a parameter of this model: " + parameter);
    }
    if (getNumFreeParams() > maxFreeParams) {
        return Double.NaN;
    }
    if (parameter.getNodeA() != parameter.getNodeB()) {
        Node nodeA = parameter.getNodeA();
        Node nodeB = parameter.getNodeB();
        Node parent;
        Node child;
        Graph graph = getSemPm().getGraph();
        if (graph.isParentOf(nodeA, nodeB)) {
            parent = nodeA;
            child = nodeB;
        } else {
            parent = nodeB;
            child = nodeA;
        }
        // Nodes named "E_..." are treated as error terms: no standard error is reported for them.
        if (child.getName().startsWith("E_")) {
            return Double.NaN;
        }
        CovarianceMatrix cov = new CovarianceMatrix(measuredNodes, sampleCovar, sampleSize);
        Regression regression = new RegressionCovariance(cov);
        List<Node> parents = graph.getParents(child);
        // Iterate over a copy so that error-term parents can be removed from the live list.
        for (Node node : new ArrayList<>(parents)) {
            if (node.getName().startsWith("E_")) {
                parents.remove(node);
            }
        }
        // The parent must actually be a regressor. Without this check indexOf() returns -1 and
        // the lookup below would silently read slot 0 of the result (presumably the intercept --
        // TODO confirm against RegressionResult) instead of this parameter's standard error.
        int parentIndex = parents.indexOf(parent);
        if (parentIndex >= 0 && child.getNodeType() != NodeType.LATENT && !containsLatent(parents)) {
            RegressionResult result = regression.regress(child, parents);
            double[] se = result.getSe();
            // Regressor i is stored at position i + 1.
            return se[parentIndex + 1];
        }
    }
    // General path, also used when the regression shortcut above does not apply.
    if (this.sampleCovarC == null) {
        this.standardErrors = null;
        return Double.NaN;
    }
    int index = getFreeParameters().indexOf(parameter);
    double[] doubles = standardErrors();
    if (doubles == null) {
        return Double.NaN;
    }
    return doubles[index];
}
Also used : Regression(edu.cmu.tetrad.regression.Regression) RegressionCovariance(edu.cmu.tetrad.regression.RegressionCovariance) RegressionResult(edu.cmu.tetrad.regression.RegressionResult)

Example 27 with RegressionResult

use of edu.cmu.tetrad.regression.RegressionResult in project tetrad by cmu-phil.

the class SampleVcpc method search.

/**
 * Runs the sample-VCPC search and returns the resulting graph.
 *
 * <p>The method proceeds in these stages:
 * <ol>
 *   <li>Runs a {@code Vcfas} adjacency search with the configured independence test, knowledge
 *       and depth, keeping the "apparent" non-adjacencies it reports.</li>
 *   <li>If orientation is enabled, applies background-knowledge orientation, unshielded-triple
 *       orientation and the Meek rules.</li>
 *   <li>Enumerates every collider/non-collider assignment of the graph's ambiguous triples,
 *       builds one candidate pattern per assignment, and discards candidates that contradict
 *       an existing orientation or contain a directed cycle.</li>
 *   <li>Promotes an apparent non-adjacency to a definite one when no surviving pattern yields a
 *       dependence between its endpoints given the relevant boundary.</li>
 *   <li>Compares sample regression coefficients with the coefficients of {@code semIm} for the
 *       edges around each node.</li>
 * </ol>
 *
 * <p>Side effects: reassigns several fields ({@code graph}, the triple sets, the two
 * non-adjacency collections, {@code elapsedTime}), writes diagnostics to {@code System.out}
 * and logs through {@code TetradLogger}.
 *
 * @return the searched graph with its nodes replaced by the variables of {@code dataSet}
 * @throws NullPointerException if no independence test has been set
 */
// modified FAS into VCFAS; added in definitelyNonadjacencies set of edges.
public Graph search() {
    // Fail fast: the test is needed by everything below, so check it before first use.
    if (getIndependenceTest() == null) {
        throw new NullPointerException("The independence test must not be null.");
    }
    this.logger.log("info", "Starting VCCPC algorithm");
    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");
    this.allTriples = new HashSet<>();
    this.ambiguousTriples = new HashSet<>();
    this.colliderTriples = new HashSet<>();
    this.noncolliderTriples = new HashSet<>();
    Vcfas fas = new Vcfas(getIndependenceTest());
    definitelyNonadjacencies = new HashSet<>();
    markovInAllPatterns = new HashSet<>();
    long startTime = System.currentTimeMillis();
    List<Node> allNodes = getIndependenceTest().getVariables();
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    graph = fas.search();
    apparentlyNonadjacencies = fas.getApparentlyNonadjacencies();
    if (isDoOrientation()) {
        if (verbose) {
            System.out.println("CPC orientation...");
        }
        SearchGraphUtils.pcOrientbk(knowledge, graph, allNodes);
        orientUnshieldedTriples(knowledge, getIndependenceTest(), getDepth());
        MeekRules meekRules = new MeekRules();
        meekRules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
        meekRules.setKnowledge(knowledge);
        meekRules.orientImplied(graph);
    }
    // Local snapshot of the graph's ambiguous triples (distinct from the field of similar name).
    List<Triple> ambiguousList = new ArrayList(graph.getAmbiguousTriples());
    // Every ambiguous triple has two resolutions: 0 = collider, 1 = non-collider.
    int[] dims = new int[ambiguousList.size()];
    for (int i = 0; i < ambiguousList.size(); i++) {
        dims[i] = 2;
    }
    List<Graph> patterns = new ArrayList<>();
    // Keyed by identity so that structurally equal candidate graphs stay distinct.
    Map<Graph, List<Triple>> newColliders = new IdentityHashMap<>();
    Map<Graph, List<Triple>> newNonColliders = new IdentityHashMap<>();
    // Using combination generator to generate a list of combinations of ambiguous triples disambiguated into
    // colliders and non-colliders. The combinations are added as graphs to the list patterns. The graphs are
    // then subject to basic rules to ensure consistent patterns.
    CombinationGenerator generator = new CombinationGenerator(dims);
    int[] combination;
    while ((combination = generator.next()) != null) {
        Graph candidate = new EdgeListGraph(graph);
        newColliders.put(candidate, new ArrayList<Triple>());
        newNonColliders.put(candidate, new ArrayList<Triple>());
        for (int k = 0; k < combination.length; k++) {
            Triple triple = ambiguousList.get(k);
            candidate.removeAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
            if (combination[k] == 0) {
                newColliders.get(candidate).add(triple);
                Node x = triple.getX();
                Node y = triple.getY();
                Node z = triple.getZ();
                candidate.setEndpoint(x, y, Endpoint.ARROW);
                candidate.setEndpoint(z, y, Endpoint.ARROW);
            }
            if (combination[k] == 1) {
                newNonColliders.get(candidate).add(triple);
            }
        }
        patterns.add(candidate);
    }
    // Iterate over a copy because inconsistent candidates are removed from 'patterns' on the way.
    GRAPH: for (Graph pattern : new ArrayList<>(patterns)) {
        List<Triple> colliders = newColliders.get(pattern);
        List<Triple> nonColliders = newNonColliders.get(pattern);
        // A collider x -> y <- z is impossible if either edge already points away from y.
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(x) || (pattern.getEdge(y, z).pointsTowards(z))) {
                patterns.remove(pattern);
                continue GRAPH;
            }
        }
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            pattern.setEndpoint(x, y, Endpoint.ARROW);
            pattern.setEndpoint(z, y, Endpoint.ARROW);
        }
        // For a non-collider, an arrowhead into y on one side forces the other edge out of y.
        for (Triple triple : nonColliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(y)) {
                pattern.removeEdge(y, z);
                pattern.addDirectedEdge(y, z);
            }
            if (pattern.getEdge(y, z).pointsTowards(y)) {
                pattern.removeEdge(x, y);
                pattern.addDirectedEdge(y, x);
            }
        }
        // Bidirected edges are replaced by undirected ones.
        for (Edge edge : pattern.getEdges()) {
            Node x = edge.getNode1();
            Node y = edge.getNode2();
            if (Edges.isBidirectedEdge(edge)) {
                pattern.removeEdge(x, y);
                pattern.addUndirectedEdge(x, y);
            }
        }
        MeekRules rules = new MeekRules();
        rules.orientImplied(pattern);
        if (pattern.existsDirectedCycle()) {
            patterns.remove(pattern);
        }
    }
    // An apparent non-adjacency becomes definite only if no surviving pattern yields a dependence.
    MARKOV: for (Edge edge : apparentlyNonadjacencies.keySet()) {
        Node x = edge.getNode1();
        Node y = edge.getNode2();
        for (Graph pattern : new ArrayList<>(patterns)) {
            List<Node> boundaryX = new ArrayList<>(boundary(x, pattern));
            List<Node> boundaryY = new ArrayList<>(boundary(y, pattern));
            List<Node> futureX = new ArrayList<>(future(x, pattern));
            List<Node> futureY = new ArrayList<>(future(y, pattern));
            if (y == x) {
                continue;
            }
            if (boundaryX.contains(y) || boundaryY.contains(x)) {
                continue;
            }
            IndependenceTest test = independenceTest;
            if (!futureX.contains(y)) {
                if (!test.isIndependent(x, y, boundaryX)) {
                    continue MARKOV;
                }
            }
            if (!futureY.contains(x)) {
                if (!test.isIndependent(y, x, boundaryY)) {
                    continue MARKOV;
                }
            }
        }
        definitelyNonadjacencies.add(edge);
    }
    // Done in a second pass so the key set is not modified while it is being iterated above.
    for (Edge edge : definitelyNonadjacencies) {
        apparentlyNonadjacencies.keySet().remove(edge);
    }
    // NOTE(review): passes the field to its own setter; looks like a no-op unless setSemIm has
    // side effects -- confirm.
    setSemIm(semIm);
    System.out.println(semIm.getEdgeCoef());
    // Edge Estimation Alg I
    Regression sampleRegression = new RegressionDataset(dataSet);
    System.out.println(sampleRegression.getGraph());
    graph = GraphUtils.replaceNodes(graph, dataSet.getVariables());
    // Per edge: the sample regression coefficients (null = unknown) and the coefficient in semIm.
    Map<Edge, double[]> sampleRegress = new HashMap<>();
    Map<Edge, Double> edgeCoefs = new HashMap<>();
    ESTIMATION: for (Node z : graph.getNodes()) {
        Set<Edge> adj = getAdj(z, graph);
        for (Edge edge : apparentlyNonadjacencies.keySet()) {
            if (z == edge.getNode1() || z == edge.getNode2()) {
                for (Edge adjacency : adj) {
                    // return Unknown and go to next Z
                    sampleRegress.put(adjacency, null);
                    Node a = adjacency.getNode1();
                    Node b = adjacency.getNode2();
                    if (semIm.existsEdgeCoef(a, b)) {
                        Double c = semIm.getEdgeCoef(a, b);
                        edgeCoefs.put(adjacency, c);
                    } else {
                        edgeCoefs.put(adjacency, 0.0);
                    }
                }
                continue ESTIMATION;
            }
        }
        for (Edge nonadj : definitelyNonadjacencies) {
            if (nonadj.getNode1() == z || nonadj.getNode2() == z) {
                // return 0 for e
                double[] d = { 0, 0 };
                sampleRegress.put(nonadj, d);
                Node a = nonadj.getNode1();
                Node b = nonadj.getNode2();
                if (semIm.existsEdgeCoef(a, b)) {
                    Double c = semIm.getEdgeCoef(a, b);
                    edgeCoefs.put(nonadj, c);
                } else {
                    edgeCoefs.put(nonadj, 0.0);
                }
            }
        }
        Set<Edge> parentsOfZ = new HashSet<>();
        Set<Edge> _adj = getAdj(z, graph);
        for (Edge _adjacency : _adj) {
            // Any non-directed edge at z marks all of z's adjacencies as unknown.
            if (!_adjacency.isDirected()) {
                for (Edge adjacency : adj) {
                    sampleRegress.put(adjacency, null);
                    Node a = adjacency.getNode1();
                    Node b = adjacency.getNode2();
                    if (semIm.existsEdgeCoef(a, b)) {
                        Double c = semIm.getEdgeCoef(a, b);
                        edgeCoefs.put(adjacency, c);
                    } else {
                        edgeCoefs.put(adjacency, 0.0);
                    }
                }
            }
            if (_adjacency.pointsTowards(z)) {
                parentsOfZ.add(_adjacency);
            }
        }
        // Regress the head of each edge into z on its tail and record the coefficients.
        for (Edge edge : parentsOfZ) {
            if (edge.pointsTowards(edge.getNode2())) {
                RegressionResult result = sampleRegression.regress(edge.getNode2(), edge.getNode1());
                System.out.println(result);
                double[] d = result.getCoef();
                sampleRegress.put(edge, d);
                Node a = edge.getNode1();
                Node b = edge.getNode2();
                if (semIm.existsEdgeCoef(a, b)) {
                    Double c = semIm.getEdgeCoef(a, b);
                    edgeCoefs.put(edge, c);
                } else {
                    edgeCoefs.put(edge, 0.0);
                }
            }
        }
    }
    System.out.println("All IM: " + semIm + "Finish");
    System.out.println("Just IM coefs: " + semIm.getEdgeCoef());
    System.out.println("IM Coef Map: " + edgeCoefs);
    System.out.println("Regress Coef Map: " + sampleRegress);
    for (Edge edge : sampleRegress.keySet()) {
        System.out.println(" Sample Regression: " + edge + java.util.Arrays.toString(sampleRegress.get(edge)));
    }
    for (Edge edge : graph.getEdges()) {
        System.out.println("Sample edge: " + java.util.Arrays.toString(sampleRegress.get(edge)));
    }
    System.out.println("Sample VCPC:");
    System.out.println("# of patterns: " + patterns.size());
    long endTime = System.currentTimeMillis();
    this.elapsedTime = endTime - startTime;
    // Divide by 1000. (not 1000), as the logger line below does, so fractions are not truncated.
    System.out.println("Search Time (seconds):" + (elapsedTime) / 1000. + " s");
    System.out.println("Search Time (milli):" + elapsedTime + " ms");
    System.out.println("# of Apparent Nonadj: " + apparentlyNonadjacencies.size());
    System.out.println("# of Definite Nonadj: " + definitelyNonadjacencies.size());
    TetradLogger.getInstance().log("apparentlyNonadjacencies", "\n Apparent Non-adjacencies" + apparentlyNonadjacencies);
    TetradLogger.getInstance().log("definitelyNonadjacencies", "\n Definite Non-adjacencies" + definitelyNonadjacencies);
    TetradLogger.getInstance().log("patterns", "Disambiguated Patterns: " + patterns);
    TetradLogger.getInstance().log("graph", "\nReturning this graph: " + graph);
    TetradLogger.getInstance().log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    TetradLogger.getInstance().log("info", "Finishing CPC algorithm.");
    logTriples();
    TetradLogger.getInstance().flush();
    return graph;
}
Also used : ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) RegressionDataset(edu.cmu.tetrad.regression.RegressionDataset) RegressionResult(edu.cmu.tetrad.regression.RegressionResult) CombinationGenerator(edu.cmu.tetrad.util.CombinationGenerator) Regression(edu.cmu.tetrad.regression.Regression)

Example 28 with RegressionResult

use of edu.cmu.tetrad.regression.RegressionResult in project tetrad by cmu-phil.

the class BffBeam method removeZeroEdges.

/**
 * Returns a copy of {@code bestGraph} from which edges with high coefficient p-values have been
 * removed.
 *
 * <p>The copy is scored repeatedly. In each round, every coefficient parameter of the estimated
 * model is tested by regressing the child on its parents in the working copy. If the parent's
 * p-value exceeds {@code getHighPValueAlpha()} the edge is removed, unless the knowledge marks
 * it as required. Rounds continue until one removes nothing.
 *
 * @param bestGraph the graph to prune; it is copied, not modified
 * @return the pruned copy
 */
public Graph removeZeroEdges(Graph bestGraph) {
    boolean changed = true;
    Graph graph = new EdgeListGraph(bestGraph);
    while (changed) {
        changed = false;
        Score score = scoreGraph(graph);
        SemIm estSem = score.getEstimatedSem();
        for (Parameter param : estSem.getSemPm().getParameters()) {
            if (param.getType() != ParamType.COEF) {
                continue;
            }
            Node nodeA = param.getNodeA();
            Node nodeB = param.getNodeB();
            Node parent;
            Node child;
            // Read the direction from the working copy, which is also the graph that is scored
            // and that supplies the parent list below. Using the field graph here could pick
            // the wrong child whenever the two graphs differ.
            if (graph.isParentOf(nodeA, nodeB)) {
                parent = nodeA;
                child = nodeB;
            } else {
                parent = nodeB;
                child = nodeA;
            }
            List<Node> parents = graph.getParents(child);
            int parentIndex = parents.indexOf(parent);
            if (parentIndex < 0) {
                // Not a regressor of the child in the working graph, so there is no p-value for
                // this parameter. Without this check slot 0 of the p-value array would be read.
                continue;
            }
            Regression regression = new RegressionCovariance(cov);
            RegressionResult result = regression.regress(child, parents);
            // Regressor i is stored at position i + 1.
            double p = result.getP()[parentIndex + 1];
            if (p > getHighPValueAlpha()) {
                Edge edge = graph.getEdge(param.getNodeA(), param.getNodeB());
                if (getKnowledge().isRequired(edge.getNode1().getName(), edge.getNode2().getName())) {
                    System.out.println("Not removing " + edge + " because it is required.");
                    TetradLogger.getInstance().log("details", "Not removing " + edge + " because it is required.");
                    continue;
                }
                System.out.println("Removing edge " + edge + " because it has p = " + p);
                TetradLogger.getInstance().log("details", "Removing edge " + edge + " because it has p = " + p);
                graph.removeEdge(edge);
                // The model must be re-estimated before further conclusions are drawn.
                changed = true;
            }
        }
    }
    return graph;
}
Also used : Regression(edu.cmu.tetrad.regression.Regression) RegressionCovariance(edu.cmu.tetrad.regression.RegressionCovariance) RegressionResult(edu.cmu.tetrad.regression.RegressionResult)

Example 29 with RegressionResult

use of edu.cmu.tetrad.regression.RegressionResult in project tetrad by cmu-phil.

the class TestRegression method testTabular.

/**
 * Regression guard: regresses the first variable of the test data on all remaining variables
 * and checks that the first five coefficients still match previously observed values to
 * within 0.01.
 */
@Test
public void testTabular() {
    setUp();
    RandomUtil.getInstance().setSeed(3848283L);

    List<Node> variables = data.getVariables();
    Node target = variables.get(0);
    // Every variable after the first acts as a regressor.
    List<Node> regressors = new ArrayList<>(variables.subList(1, variables.size()));

    Regression regression = new RegressionDataset(data);
    RegressionResult result = regression.regress(target, regressors);
    double[] coeffs = result.getCoef();

    double[] expected = { .08, -.05, .035, 0.019, -.003 };
    for (int i = 0; i < expected.length; i++) {
        assertEquals(expected[i], coeffs[i], 0.01);
    }
}
Also used : RegressionDataset(edu.cmu.tetrad.regression.RegressionDataset) ArrayList(java.util.ArrayList) Regression(edu.cmu.tetrad.regression.Regression) RegressionResult(edu.cmu.tetrad.regression.RegressionResult) Test(org.junit.Test)

Example 30 with RegressionResult

use of edu.cmu.tetrad.regression.RegressionResult in project tetrad by cmu-phil.

the class IndTestMixedMultipleTTest method dependencePvalsLinear.

/**
 * Regresses {@code x} on the expanded variables of {@code y} and of every node in {@code z},
 * restricted to the rows returned by {@code getNonMissingRows}, and returns one p-value per
 * original node in the order {@code y, z[0], z[1], ...}.
 *
 * <p>A node that expands to a single variable keeps that variable's p-value. A node that
 * expands to {@code k} variables gets the combined value
 * {@code 1 - ChiSquared(2k).cdf(-2 * sum(ln p_j))} (the form of Fisher's combined probability
 * test), or {@code 0.0} when the log-sum is negative infinity.
 *
 * @param x the regression target
 * @param y the first predictor node
 * @param z the remaining predictor nodes
 * @return the p-values, or {@code null} if the regression throws
 * @throws IllegalArgumentException if any node is not a key of {@code variablesPerNode}
 */
private double[] dependencePvalsLinear(Node x, Node y, List<Node> z) {
    if (!variablesPerNode.containsKey(x)) {
        throw new IllegalArgumentException("Unrecogized node: " + x);
    }
    if (!variablesPerNode.containsKey(y)) {
        throw new IllegalArgumentException("Unrecogized node: " + y);
    }
    for (Node candidate : z) {
        if (!variablesPerNode.containsKey(candidate)) {
            throw new IllegalArgumentException("Unrecogized node: " + candidate);
        }
    }

    // Predictors in reporting order: y first, then the z nodes.
    List<Node> predictors = new ArrayList<>();
    predictors.add(y);
    predictors.addAll(z);

    // The same predictors, each replaced by its expanded variables, in the same order.
    List<Node> expanded = new ArrayList<>();
    for (Node predictor : predictors) {
        expanded.addAll(variablesPerNode.get(predictor));
    }

    int[] usableRows = getNonMissingRows(x, y, z);
    regression.setRows(usableRows);

    RegressionResult result;
    try {
        result = regression.regress(x, expanded);
    } catch (Exception e) {
        return null;
    }

    double[] coefPvals = result.getP();
    double[] combined = new double[predictors.size()];
    // Position 0 of the p-value array is the intercept, so start reading at 1.
    int cursor = 1;
    for (int i = 0; i < combined.length; i++) {
        int width = variablesPerNode.get(predictors.get(i)).size();
        if (width == 1) {
            combined[i] = coefPvals[cursor++];
            continue;
        }
        double logSum = 0;
        for (int j = 0; j < width; j++) {
            logSum += Math.log(coefPvals[cursor++]);
        }
        if (logSum == Double.NEGATIVE_INFINITY) {
            // At least one p-value was exactly zero.
            combined[i] = 0.0;
        } else {
            combined[i] = 1.0 - new ChiSquaredDistribution(2 * width).cumulativeProbability(-2 * logSum);
        }
    }
    return combined;
}
Also used : ChiSquaredDistribution(org.apache.commons.math3.distribution.ChiSquaredDistribution) Node(edu.cmu.tetrad.graph.Node) RegressionResult(edu.cmu.tetrad.regression.RegressionResult)

Aggregations

RegressionResult (edu.cmu.tetrad.regression.RegressionResult)33 Regression (edu.cmu.tetrad.regression.Regression)17 RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset)16 ArrayList (java.util.ArrayList)15 DoubleArrayList (cern.colt.list.DoubleArrayList)11 TetradVector (edu.cmu.tetrad.util.TetradVector)9 AndersonDarlingTest (edu.cmu.tetrad.data.AndersonDarlingTest)8 Node (edu.cmu.tetrad.graph.Node)8 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)8 RegressionCovariance (edu.cmu.tetrad.regression.RegressionCovariance)3 CombinationGenerator (edu.cmu.tetrad.util.CombinationGenerator)2 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)2 Test (org.junit.Test)2 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)1 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)1 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)1 GeneralAndersonDarlingTest (edu.cmu.tetrad.data.GeneralAndersonDarlingTest)1 ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix)1 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)1 Graph (edu.cmu.tetrad.graph.Graph)1