Search in sources :

Example 16 with Regression

use of edu.cmu.tetrad.regression.Regression in project tetrad by cmu-phil.

the class SampleVcpc method search.

/**
 * Runs the sample VCPC search and returns the resulting graph.
 *
 * <p>The adjacency phase uses {@code Vcfas} (a FAS variant that also reports the set of
 * "apparently non-adjacent" edges). The method then proceeds in these steps:
 * <ol>
 *   <li>If orientation is enabled, applies background knowledge, orients unshielded triples and
 *       runs the Meek rules on the graph.</li>
 *   <li>Enumerates every collider / non-collider assignment of the graph's ambiguous triples,
 *       producing one candidate pattern per assignment.</li>
 *   <li>Discards candidate patterns that contradict an existing orientation or contain a
 *       directed cycle after the Meek rules have been applied.</li>
 *   <li>Moves an apparent non-adjacency to the definite non-adjacency set when the independence
 *       checks below succeed in every remaining pattern.</li>
 *   <li>Regresses children on parents in the sample data, records the corresponding
 *       coefficients of {@code semIm}, and prints diagnostics to standard output.</li>
 * </ol>
 *
 * @return the searched graph, with its nodes replaced by the variables of {@code dataSet}
 * @throws NullPointerException if no independence test has been set
 */
public Graph search() {
    // Fail fast, before the test is logged or handed to the adjacency search.
    if (getIndependenceTest() == null) {
        throw new NullPointerException("The independence test must not be null.");
    }
    this.logger.log("info", "Starting VCCPC algorithm");
    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");
    this.allTriples = new HashSet<>();
    this.ambiguousTriples = new HashSet<>();
    this.colliderTriples = new HashSet<>();
    this.noncolliderTriples = new HashSet<>();
    // Modified FAS (VCFAS): also collects the apparently non-adjacent edges.
    Vcfas fas = new Vcfas(getIndependenceTest());
    definitelyNonadjacencies = new HashSet<>();
    markovInAllPatterns = new HashSet<>();
    long startTime = System.currentTimeMillis();
    List<Node> allNodes = getIndependenceTest().getVariables();
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    graph = fas.search();
    apparentlyNonadjacencies = fas.getApparentlyNonadjacencies();
    if (isDoOrientation()) {
        if (verbose) {
            System.out.println("CPC orientation...");
        }
        SearchGraphUtils.pcOrientbk(knowledge, graph, allNodes);
        orientUnshieldedTriples(knowledge, getIndependenceTest(), getDepth());
        MeekRules meekRules = new MeekRules();
        meekRules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
        meekRules.setKnowledge(knowledge);
        meekRules.orientImplied(graph);
    }

    // Named differently from the ambiguousTriples field, which was reset above.
    List<Triple> unresolvedTriples = new ArrayList<>(graph.getAmbiguousTriples());
    // Each ambiguous triple has two resolutions: 0 = collider, 1 = non-collider.
    int[] dims = new int[unresolvedTriples.size()];
    java.util.Arrays.fill(dims, 2);
    List<Graph> patterns = new ArrayList<>();
    // Identity maps: each candidate is keyed by its own instance, not by equals().
    Map<Graph, List<Triple>> newColliders = new IdentityHashMap<>();
    Map<Graph, List<Triple>> newNonColliders = new IdentityHashMap<>();
    // Use the combination generator to enumerate every way of disambiguating the ambiguous triples
    // into colliders and non-colliders. Each combination becomes a candidate graph in patterns; the
    // candidates are then subjected to basic rules to ensure consistent patterns.
    CombinationGenerator generator = new CombinationGenerator(dims);
    int[] combination;
    while ((combination = generator.next()) != null) {
        Graph candidate = new EdgeListGraph(graph);
        List<Triple> candidateColliders = new ArrayList<>();
        List<Triple> candidateNonColliders = new ArrayList<>();
        newColliders.put(candidate, candidateColliders);
        newNonColliders.put(candidate, candidateNonColliders);
        for (int k = 0; k < combination.length; k++) {
            Triple triple = unresolvedTriples.get(k);
            candidate.removeAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
            if (combination[k] == 0) {
                candidateColliders.add(triple);
                candidate.setEndpoint(triple.getX(), triple.getY(), Endpoint.ARROW);
                candidate.setEndpoint(triple.getZ(), triple.getY(), Endpoint.ARROW);
            } else if (combination[k] == 1) {
                candidateNonColliders.add(triple);
            }
        }
        patterns.add(candidate);
    }

    // Iterate over a snapshot so inconsistent candidates can be removed from patterns.
    GRAPH: for (Graph pattern : new ArrayList<>(patterns)) {
        List<Triple> colliders = newColliders.get(pattern);
        List<Triple> nonColliders = newNonColliders.get(pattern);
        // A collider x -> y <- z is impossible if either edge already points away from y.
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(x) || (pattern.getEdge(y, z).pointsTowards(z))) {
                patterns.remove(pattern);
                continue GRAPH;
            }
        }
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            pattern.setEndpoint(x, y, Endpoint.ARROW);
            pattern.setEndpoint(z, y, Endpoint.ARROW);
        }
        // For a non-collider, an arrow into y on one side forces the other edge out of y.
        for (Triple triple : nonColliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(y)) {
                pattern.removeEdge(y, z);
                pattern.addDirectedEdge(y, z);
            }
            if (pattern.getEdge(y, z).pointsTowards(y)) {
                pattern.removeEdge(x, y);
                pattern.addDirectedEdge(y, x);
            }
        }
        // Bidirected edges are replaced by undirected ones.
        for (Edge edge : pattern.getEdges()) {
            Node x = edge.getNode1();
            Node y = edge.getNode2();
            if (Edges.isBidirectedEdge(edge)) {
                pattern.removeEdge(x, y);
                pattern.addUndirectedEdge(x, y);
            }
        }
        MeekRules rules = new MeekRules();
        rules.orientImplied(pattern);
        if (pattern.existsDirectedCycle()) {
            patterns.remove(pattern);
        }
    }

    // An apparent non-adjacency becomes definite only if no remaining pattern yields a failed
    // independence check for it.
    MARKOV: for (Edge edge : apparentlyNonadjacencies.keySet()) {
        Node x = edge.getNode1();
        Node y = edge.getNode2();
        for (Graph pattern : new ArrayList<>(patterns)) {
            List<Node> boundaryX = new ArrayList<>(boundary(x, pattern));
            List<Node> boundaryY = new ArrayList<>(boundary(y, pattern));
            List<Node> futureX = new ArrayList<>(future(x, pattern));
            List<Node> futureY = new ArrayList<>(future(y, pattern));
            if (y == x) {
                continue;
            }
            if (boundaryX.contains(y) || boundaryY.contains(x)) {
                continue;
            }
            if (!futureX.contains(y) && !independenceTest.isIndependent(x, y, boundaryX)) {
                continue MARKOV;
            }
            if (!futureY.contains(x) && !independenceTest.isIndependent(y, x, boundaryY)) {
                continue MARKOV;
            }
        }
        definitelyNonadjacencies.add(edge);
    }
    // Removal is deferred until the iteration over the key set above has finished.
    apparentlyNonadjacencies.keySet().removeAll(definitelyNonadjacencies);

    setSemIm(semIm);
    System.out.println(semIm.getEdgeCoef());

    // Edge Estimation Alg I
    Regression sampleRegression = new RegressionDataset(dataSet);
    System.out.println(sampleRegression.getGraph());
    graph = GraphUtils.replaceNodes(graph, dataSet.getVariables());
    // Sample coefficients per edge; a null value means "unknown".
    Map<Edge, double[]> sampleRegress = new HashMap<>();
    // Coefficients taken from semIm per edge (0.0 when semIm has no such coefficient).
    Map<Edge, Double> edgeCoefs = new HashMap<>();
    ESTIMATION: for (Node z : graph.getNodes()) {
        Set<Edge> adj = getAdj(z, graph);
        // NOTE(review): nodes are compared by identity here; after replaceNodes above the graph's
        // nodes may be different instances from those in the non-adjacency edges -- confirm.
        for (Edge edge : apparentlyNonadjacencies.keySet()) {
            if (z == edge.getNode1() || z == edge.getNode2()) {
                // z touches an unresolved non-adjacency: report unknown and go to the next z.
                markAllUnknown(adj, sampleRegress, edgeCoefs);
                continue ESTIMATION;
            }
        }
        for (Edge nonadj : definitelyNonadjacencies) {
            if (nonadj.getNode1() == z || nonadj.getNode2() == z) {
                // A definite non-adjacency is reported with zero coefficients.
                double[] d = { 0, 0 };
                sampleRegress.put(nonadj, d);
                recordTrueCoefficient(edgeCoefs, nonadj);
            }
        }
        Set<Edge> parentsOfZ = new HashSet<>();
        boolean hasUndirectedEdge = false;
        for (Edge adjacency : adj) {
            if (!adjacency.isDirected()) {
                hasUndirectedEdge = true;
            }
            if (adjacency.pointsTowards(z)) {
                parentsOfZ.add(adjacency);
            }
        }
        if (hasUndirectedEdge) {
            // Any undirected edge at z marks all of z's edges as unknown; edges into z are
            // overwritten by the regression below.
            markAllUnknown(adj, sampleRegress, edgeCoefs);
        }
        for (Edge edge : parentsOfZ) {
            if (edge.pointsTowards(edge.getNode2())) {
                RegressionResult result = sampleRegression.regress(edge.getNode2(), edge.getNode1());
                System.out.println(result);
                sampleRegress.put(edge, result.getCoef());
                recordTrueCoefficient(edgeCoefs, edge);
            }
        }
    }

    System.out.println("All IM: " + semIm + "Finish");
    System.out.println("Just IM coefs: " + semIm.getEdgeCoef());
    System.out.println("IM Coef Map: " + edgeCoefs);
    System.out.println("Regress Coef Map: " + sampleRegress);
    for (Edge edge : sampleRegress.keySet()) {
        System.out.println(" Sample Regression: " + edge + java.util.Arrays.toString(sampleRegress.get(edge)));
    }
    for (Edge edge : graph.getEdges()) {
        System.out.println("Sample edge: " + java.util.Arrays.toString(sampleRegress.get(edge)));
    }
    System.out.println("Sample VCPC:");
    System.out.println("# of patterns: " + patterns.size());
    long endTime = System.currentTimeMillis();
    this.elapsedTime = endTime - startTime;
    System.out.println("Search Time (seconds):" + (elapsedTime) / 1000 + " s");
    System.out.println("Search Time (milli):" + elapsedTime + " ms");
    System.out.println("# of Apparent Nonadj: " + apparentlyNonadjacencies.size());
    System.out.println("# of Definite Nonadj: " + definitelyNonadjacencies.size());
    TetradLogger.getInstance().log("apparentlyNonadjacencies", "\n Apparent Non-adjacencies" + apparentlyNonadjacencies);
    TetradLogger.getInstance().log("definitelyNonadjacencies", "\n Definite Non-adjacencies" + definitelyNonadjacencies);
    TetradLogger.getInstance().log("patterns", "Disambiguated Patterns: " + patterns);
    TetradLogger.getInstance().log("graph", "\nReturning this graph: " + graph);
    TetradLogger.getInstance().log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    TetradLogger.getInstance().log("info", "Finishing CPC algorithm.");
    logTriples();
    TetradLogger.getInstance().flush();
    return graph;
}

/**
 * Stores, for the given edge, the coefficient that {@code semIm} holds between the edge's two
 * nodes, or 0.0 when {@code semIm} reports no such coefficient.
 */
private void recordTrueCoefficient(Map<Edge, Double> edgeCoefs, Edge edge) {
    Node a = edge.getNode1();
    Node b = edge.getNode2();
    if (semIm.existsEdgeCoef(a, b)) {
        edgeCoefs.put(edge, semIm.getEdgeCoef(a, b));
    } else {
        edgeCoefs.put(edge, 0.0);
    }
}

/**
 * Marks every given edge as having an unknown (null) sample regression and records the
 * corresponding {@code semIm} coefficient for it.
 */
private void markAllUnknown(Set<Edge> edges, Map<Edge, double[]> sampleRegress, Map<Edge, Double> edgeCoefs) {
    for (Edge edge : edges) {
        sampleRegress.put(edge, null);
        recordTrueCoefficient(edgeCoefs, edge);
    }
}
Also used : ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) RegressionDataset(edu.cmu.tetrad.regression.RegressionDataset) RegressionResult(edu.cmu.tetrad.regression.RegressionResult) CombinationGenerator(edu.cmu.tetrad.util.CombinationGenerator) Regression(edu.cmu.tetrad.regression.Regression)

Example 17 with Regression

use of edu.cmu.tetrad.regression.Regression in project tetrad by cmu-phil.

the class RegressionRunner method execute.

// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Runs the regression of the target variable on the selected regressors and publishes the
 * resulting graph via {@code setResultGraph}.
 *
 * <p>When there are no regressors, no target, or the target is itself listed as a regressor,
 * {@code outGraph} is set to an empty graph and the method returns without running a regression.
 * Tabular data sets and covariance matrices are supported; for any other data model no
 * regression is run and the current {@code outGraph} is published unchanged.
 */
public void execute() {
    // Degenerate selections cannot be regressed.
    if (regressorNames.isEmpty() || targetName == null || regressorNames.contains(targetName)) {
        outGraph = new EdgeListGraph();
        return;
    }

    final Regression regression;
    final Node target;
    final List<Node> regressors = new LinkedList<>();

    if (getDataModel() instanceof DataSet) {
        final DataSet tabular = (DataSet) getDataModel();
        regression = new RegressionDataset(tabular);
        target = tabular.getVariable(targetName);
        for (String name : regressorNames) {
            regressors.add(tabular.getVariable(name));
        }
    } else if (getDataModel() instanceof ICovarianceMatrix) {
        final ICovarianceMatrix covariances = (ICovarianceMatrix) getDataModel();
        regression = new RegressionCovariance(covariances);
        target = covariances.getVariable(targetName);
        for (String name : regressorNames) {
            regressors.add(covariances.getVariable(name));
        }
    } else {
        // Unsupported data model: publish whatever outGraph currently holds.
        setResultGraph(outGraph);
        return;
    }

    // Significance level, read from the parameters with a fallback of 0.001.
    regression.setAlpha(params.getDouble("alpha", 0.001));
    result = regression.regress(target, regressors);
    outGraph = regression.getGraph();
    setResultGraph(outGraph);
}
Also used : RegressionDataset(edu.cmu.tetrad.regression.RegressionDataset) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Regression(edu.cmu.tetrad.regression.Regression) RegressionCovariance(edu.cmu.tetrad.regression.RegressionCovariance)

Example 18 with Regression

use of edu.cmu.tetrad.regression.Regression in project tetrad by cmu-phil.

the class BffBeam method removeZeroEdges.

/**
 * Returns a copy of {@code bestGraph} from which edges with a high regression p-value have been
 * removed, repeating until a full pass removes nothing.
 *
 * <p>On each pass the working copy is rescored, and for every coefficient parameter of the
 * estimated SEM the child is regressed on its parents in the working copy. An edge is removed when
 * its p-value exceeds {@code getHighPValueAlpha()}, unless knowledge marks the edge as required.
 *
 * @param bestGraph the graph to prune; it is copied, not modified
 * @return the pruned copy
 */
public Graph removeZeroEdges(Graph bestGraph) {
    boolean changed = true;
    // Named so that it no longer shadows the graph field: direction, parents and removal must all
    // refer to this working copy.
    Graph pruned = new EdgeListGraph(bestGraph);
    while (changed) {
        changed = false;
        Score score = scoreGraph(pruned);
        SemIm estSem = score.getEstimatedSem();
        for (Parameter param : estSem.getSemPm().getParameters()) {
            if (param.getType() != ParamType.COEF) {
                continue;
            }
            Node nodeA = param.getNodeA();
            Node nodeB = param.getNodeB();
            Node parent;
            Node child;
            // Direction is read from the working copy, the same graph the parents come from.
            if (pruned.isParentOf(nodeA, nodeB)) {
                parent = nodeA;
                child = nodeB;
            } else {
                parent = nodeB;
                child = nodeA;
            }
            Regression regression = new RegressionCovariance(cov);
            List<Node> parents = pruned.getParents(child);
            int parentIndex = parents.indexOf(parent);
            if (parentIndex < 0) {
                // Not a parent in the working copy: index -1 would silently select entry 0 of
                // the p-value array, which belongs to a different term.
                continue;
            }
            RegressionResult result = regression.regress(child, parents);
            // Offset by one; entry 0 presumably belongs to the intercept -- TODO confirm.
            double p = result.getP()[parentIndex + 1];
            if (p > getHighPValueAlpha()) {
                Edge edge = pruned.getEdge(param.getNodeA(), param.getNodeB());
                if (edge == null) {
                    continue;
                }
                if (getKnowledge().isRequired(edge.getNode1().getName(), edge.getNode2().getName())) {
                    System.out.println("Not removing " + edge + " because it is required.");
                    TetradLogger.getInstance().log("details", "Not removing " + edge + " because it is required.");
                    continue;
                }
                System.out.println("Removing edge " + edge + " because it has p = " + p);
                TetradLogger.getInstance().log("details", "Removing edge " + edge + " because it has p = " + p);
                pruned.removeEdge(edge);
                changed = true;
            }
        }
    }
    return pruned;
}
Also used : Regression(edu.cmu.tetrad.regression.Regression) RegressionCovariance(edu.cmu.tetrad.regression.RegressionCovariance) RegressionResult(edu.cmu.tetrad.regression.RegressionResult)

Example 19 with Regression

use of edu.cmu.tetrad.regression.Regression in project tetrad by cmu-phil.

the class TestRegression method testTabular.

/**
 * Regression guard: regresses the first variable of the fixture data on all remaining variables
 * and checks the first five coefficients against previously observed values (tolerance 0.01).
 */
@Test
public void testTabular() {
    setUp();
    RandomUtil.getInstance().setSeed(3848283L);

    List<Node> variables = data.getVariables();
    Node response = variables.get(0);
    // Every variable after the first serves as a predictor.
    List<Node> predictors = new ArrayList<>(variables.subList(1, variables.size()));

    Regression fitter = new RegressionDataset(data);
    RegressionResult fit = fitter.regress(response, predictors);

    double[] actual = fit.getCoef();
    double[] expected = { .08, -.05, .035, 0.019, -.003 };
    for (int i = 0; i < expected.length; i++) {
        assertEquals(expected[i], actual[i], 0.01);
    }
}
Also used : RegressionDataset(edu.cmu.tetrad.regression.RegressionDataset) ArrayList(java.util.ArrayList) Regression(edu.cmu.tetrad.regression.Regression) RegressionResult(edu.cmu.tetrad.regression.RegressionResult) Test(org.junit.Test)

Aggregations

Regression (edu.cmu.tetrad.regression.Regression)19 RegressionResult (edu.cmu.tetrad.regression.RegressionResult)17 RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset)16 ArrayList (java.util.ArrayList)10 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)7 TetradVector (edu.cmu.tetrad.util.TetradVector)7 RegressionCovariance (edu.cmu.tetrad.regression.RegressionCovariance)4 DoubleArrayList (cern.colt.list.DoubleArrayList)3 AndersonDarlingTest (edu.cmu.tetrad.data.AndersonDarlingTest)3 Node (edu.cmu.tetrad.graph.Node)3 CombinationGenerator (edu.cmu.tetrad.util.CombinationGenerator)2 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)2 OLSMultipleLinearRegression (org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression)2 Test (org.junit.Test)2 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)1 ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix)1 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)1