Search in sources :

Example 16 with RegressionDataset

use of edu.cmu.tetrad.regression.RegressionDataset in project tetrad by cmu-phil.

the class SampleVcpc method search.

/**
 * Runs the sample VCPC search and returns the resulting graph.
 * <p>
 * The method proceeds in phases:
 * <ol>
 * <li>Runs the {@code Vcfas} adjacency search and collects its apparent non-adjacencies.</li>
 * <li>If orientation is enabled, applies background knowledge, orients unshielded triples and
 * applies the Meek rules.</li>
 * <li>Enumerates every assignment of the ambiguous triples to "collider" / "non-collider", builds
 * one candidate pattern per assignment, and discards candidates that contradict an assigned
 * collider or contain a directed cycle.</li>
 * <li>Promotes an apparent non-adjacency to a definite one when the independence checks against
 * the boundaries pass in every surviving pattern.</li>
 * <li>Regresses each child on its parent for directed edges and prints the results next to the
 * coefficients stored in {@code semIm}.</li>
 * </ol>
 *
 * @return the searched graph, with its nodes replaced by the variables of {@code dataSet}.
 * @throws NullPointerException if no independence test is available.
 */
public Graph search() {
    // Fail fast: the test is used by the very next statements (logging, Vcfas construction).
    if (getIndependenceTest() == null) {
        throw new NullPointerException("The independence test must not be null.");
    }
    this.logger.log("info", "Starting VCCPC algorithm");
    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");
    this.allTriples = new HashSet<>();
    this.ambiguousTriples = new HashSet<>();
    this.colliderTriples = new HashSet<>();
    this.noncolliderTriples = new HashSet<>();
    Vcfas fas = new Vcfas(getIndependenceTest());
    definitelyNonadjacencies = new HashSet<>();
    markovInAllPatterns = new HashSet<>();
    long startTime = System.currentTimeMillis();
    List<Node> allNodes = getIndependenceTest().getVariables();

    // Phase 1: adjacency search.
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    graph = fas.search();
    apparentlyNonadjacencies = fas.getApparentlyNonadjacencies();

    // Phase 2: optional orientation.
    if (isDoOrientation()) {
        if (verbose) {
            System.out.println("CPC orientation...");
        }
        SearchGraphUtils.pcOrientbk(knowledge, graph, allNodes);
        orientUnshieldedTriples(knowledge, getIndependenceTest(), getDepth());
        MeekRules meekRules = new MeekRules();
        meekRules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
        meekRules.setKnowledge(knowledge);
        meekRules.orientImplied(graph);
    }

    // Phase 3a: enumerate all disambiguations of the ambiguous triples. Each triple gets two
    // choices (0 = collider, 1 = non-collider); every combination becomes one candidate pattern.
    List<Triple> ambiguous = new ArrayList<>(graph.getAmbiguousTriples());
    int[] dims = new int[ambiguous.size()];
    java.util.Arrays.fill(dims, 2);
    List<Graph> patterns = new ArrayList<>();
    // Identity maps: candidates are looked up by object identity, not by graph equality.
    Map<Graph, List<Triple>> newColliders = new IdentityHashMap<>();
    Map<Graph, List<Triple>> newNonColliders = new IdentityHashMap<>();
    CombinationGenerator generator = new CombinationGenerator(dims);
    int[] combination;
    while ((combination = generator.next()) != null) {
        Graph candidate = new EdgeListGraph(graph);
        newColliders.put(candidate, new ArrayList<Triple>());
        newNonColliders.put(candidate, new ArrayList<Triple>());
        for (int k = 0; k < combination.length; k++) {
            Triple triple = ambiguous.get(k);
            candidate.removeAmbiguousTriple(triple.getX(), triple.getY(), triple.getZ());
            if (combination[k] == 0) {
                newColliders.get(candidate).add(triple);
                Node x = triple.getX();
                Node y = triple.getY();
                Node z = triple.getZ();
                candidate.setEndpoint(x, y, Endpoint.ARROW);
                candidate.setEndpoint(z, y, Endpoint.ARROW);
            } else if (combination[k] == 1) {
                newNonColliders.get(candidate).add(triple);
            }
        }
        patterns.add(candidate);
    }

    // Phase 3b: prune candidates. Iterate over a copy because candidates are removed from
    // 'patterns' inside the loop.
    GRAPH: for (Graph pattern : new ArrayList<>(patterns)) {
        List<Triple> colliders = newColliders.get(pattern);
        List<Triple> nonColliders = newNonColliders.get(pattern);
        // Drop the candidate if an edge of an assigned collider points away from the middle node.
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(x) || (pattern.getEdge(y, z).pointsTowards(z))) {
                patterns.remove(pattern);
                continue GRAPH;
            }
        }
        for (Triple triple : colliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            pattern.setEndpoint(x, y, Endpoint.ARROW);
            pattern.setEndpoint(z, y, Endpoint.ARROW);
        }
        // For a non-collider, an arrow into the middle node on one side forces the other edge
        // to be directed out of the middle node.
        for (Triple triple : nonColliders) {
            Node x = triple.getX();
            Node y = triple.getY();
            Node z = triple.getZ();
            if (pattern.getEdge(x, y).pointsTowards(y)) {
                pattern.removeEdge(y, z);
                pattern.addDirectedEdge(y, z);
            }
            if (pattern.getEdge(y, z).pointsTowards(y)) {
                pattern.removeEdge(x, y);
                pattern.addDirectedEdge(y, x);
            }
        }
        // Replace bidirected edges by undirected ones. A snapshot of the edges is iterated so
        // that replacing edges cannot interfere with the iteration.
        for (Edge edge : new ArrayList<>(pattern.getEdges())) {
            Node x = edge.getNode1();
            Node y = edge.getNode2();
            if (Edges.isBidirectedEdge(edge)) {
                pattern.removeEdge(x, y);
                pattern.addUndirectedEdge(x, y);
            }
        }
        MeekRules rules = new MeekRules();
        rules.orientImplied(pattern);
        if (pattern.existsDirectedCycle()) {
            patterns.remove(pattern);
        }
    }

    // Phase 4: an apparent non-adjacency becomes definite only if no surviving pattern yields a
    // failed independence check for it.
    MARKOV: for (Edge edge : apparentlyNonadjacencies.keySet()) {
        Node x = edge.getNode1();
        Node y = edge.getNode2();
        for (Graph pattern : new ArrayList<>(patterns)) {
            List<Node> boundaryX = new ArrayList<>(boundary(x, pattern));
            List<Node> boundaryY = new ArrayList<>(boundary(y, pattern));
            List<Node> futureX = new ArrayList<>(future(x, pattern));
            List<Node> futureY = new ArrayList<>(future(y, pattern));
            if (y == x) {
                continue;
            }
            if (boundaryX.contains(y) || boundaryY.contains(x)) {
                continue;
            }
            IndependenceTest test = independenceTest;
            if (!futureX.contains(y)) {
                if (!test.isIndependent(x, y, boundaryX)) {
                    continue MARKOV;
                }
            }
            if (!futureY.contains(x)) {
                if (!test.isIndependent(y, x, boundaryY)) {
                    continue MARKOV;
                }
            }
        }
        definitelyNonadjacencies.add(edge);
    }
    // Done after the loop above so the key set is not modified while it is being iterated.
    apparentlyNonadjacencies.keySet().removeAll(definitelyNonadjacencies);

    // Phase 5: edge coefficient estimation (Edge Estimation Alg I).
    setSemIm(semIm);
    System.out.println(semIm.getEdgeCoef());
    Regression sampleRegression = new RegressionDataset(dataSet);
    System.out.println(sampleRegression.getGraph());
    graph = GraphUtils.replaceNodes(graph, dataSet.getVariables());
    // A null value in sampleRegress means "unknown"; {0, 0} is recorded for definite non-adjacencies.
    Map<Edge, double[]> sampleRegress = new HashMap<>();
    Map<Edge, Double> edgeCoefs = new HashMap<>();
    ESTIMATION: for (Node z : graph.getNodes()) {
        Set<Edge> adj = getAdj(z, graph);
        for (Edge edge : apparentlyNonadjacencies.keySet()) {
            if (z == edge.getNode1() || z == edge.getNode2()) {
                // z touches an apparent non-adjacency: mark all its adjacencies unknown and
                // move on to the next node.
                for (Edge adjacency : adj) {
                    sampleRegress.put(adjacency, null);
                    recordImCoefficient(adjacency, edgeCoefs);
                }
                continue ESTIMATION;
            }
        }
        for (Edge nonadj : definitelyNonadjacencies) {
            if (nonadj.getNode1() == z || nonadj.getNode2() == z) {
                double[] d = { 0, 0 };
                sampleRegress.put(nonadj, d);
                recordImCoefficient(nonadj, edgeCoefs);
            }
        }
        Set<Edge> parentsOfZ = new HashSet<>();
        Set<Edge> _adj = getAdj(z, graph);
        for (Edge _adjacency : _adj) {
            if (!_adjacency.isDirected()) {
                // An undirected edge at z marks every adjacency of z as unknown.
                for (Edge adjacency : adj) {
                    sampleRegress.put(adjacency, null);
                    recordImCoefficient(adjacency, edgeCoefs);
                }
            }
            if (_adjacency.pointsTowards(z)) {
                parentsOfZ.add(_adjacency);
            }
        }
        for (Edge edge : parentsOfZ) {
            if (edge.pointsTowards(edge.getNode2())) {
                RegressionResult result = sampleRegression.regress(edge.getNode2(), edge.getNode1());
                System.out.println(result);
                double[] d = result.getCoef();
                sampleRegress.put(edge, d);
                recordImCoefficient(edge, edgeCoefs);
            }
        }
    }

    // Phase 6: diagnostics and logging.
    System.out.println("All IM: " + semIm + "Finish");
    System.out.println("Just IM coefs: " + semIm.getEdgeCoef());
    System.out.println("IM Coef Map: " + edgeCoefs);
    System.out.println("Regress Coef Map: " + sampleRegress);
    for (Edge edge : sampleRegress.keySet()) {
        System.out.println(" Sample Regression: " + edge + java.util.Arrays.toString(sampleRegress.get(edge)));
    }
    for (Edge edge : graph.getEdges()) {
        System.out.println("Sample edge: " + java.util.Arrays.toString(sampleRegress.get(edge)));
    }
    System.out.println("Sample VCPC:");
    System.out.println("# of patterns: " + patterns.size());
    long endTime = System.currentTimeMillis();
    this.elapsedTime = endTime - startTime;
    System.out.println("Search Time (seconds):" + (elapsedTime) / 1000 + " s");
    System.out.println("Search Time (milli):" + elapsedTime + " ms");
    System.out.println("# of Apparent Nonadj: " + apparentlyNonadjacencies.size());
    System.out.println("# of Definite Nonadj: " + definitelyNonadjacencies.size());
    TetradLogger.getInstance().log("apparentlyNonadjacencies", "\n Apparent Non-adjacencies" + apparentlyNonadjacencies);
    TetradLogger.getInstance().log("definitelyNonadjacencies", "\n Definite Non-adjacencies" + definitelyNonadjacencies);
    TetradLogger.getInstance().log("patterns", "Disambiguated Patterns: " + patterns);
    TetradLogger.getInstance().log("graph", "\nReturning this graph: " + graph);
    TetradLogger.getInstance().log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    TetradLogger.getInstance().log("info", "Finishing CPC algorithm.");
    logTriples();
    TetradLogger.getInstance().flush();
    return graph;
}

/**
 * Stores in {@code edgeCoefs} the coefficient that {@code semIm} holds for the edge's endpoints
 * (node1, node2), or 0.0 when {@code semIm} reports no coefficient for that pair.
 *
 * @param edge      the edge whose coefficient is looked up; also the key written to the map.
 * @param edgeCoefs the map that receives the coefficient.
 */
private void recordImCoefficient(Edge edge, Map<Edge, Double> edgeCoefs) {
    Node a = edge.getNode1();
    Node b = edge.getNode2();
    if (semIm.existsEdgeCoef(a, b)) {
        edgeCoefs.put(edge, semIm.getEdgeCoef(a, b));
    } else {
        edgeCoefs.put(edge, 0.0);
    }
}
Also used : ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) RegressionDataset(edu.cmu.tetrad.regression.RegressionDataset) RegressionResult(edu.cmu.tetrad.regression.RegressionResult) CombinationGenerator(edu.cmu.tetrad.util.CombinationGenerator) Regression(edu.cmu.tetrad.regression.Regression)

Example 17 with RegressionDataset

use of edu.cmu.tetrad.regression.RegressionDataset in project tetrad by cmu-phil.

the class RegressionRunner method execute.

// =================PUBLIC METHODS OVERRIDING ABSTRACT=================//
/**
 * Regresses the selected target on the selected regressors and publishes the resulting graph.
 * <p>
 * If no regressors are selected, no target is selected, or the target is also listed as a
 * regressor, {@code outGraph} is reset to an empty graph and the method returns immediately.
 * Otherwise the regression implementation is chosen from the runtime type of the data model
 * (tabular data set or covariance matrix), and the current {@code outGraph} is passed to
 * {@code setResultGraph}.
 */
public void execute() {
    boolean nothingToRegress = regressorNames.isEmpty()
            || targetName == null
            || regressorNames.contains(targetName);
    if (nothingToRegress) {
        outGraph = new EdgeListGraph();
        return;
    }

    Object model = getDataModel();
    if (model instanceof DataSet) {
        DataSet tabular = (DataSet) model;
        Regression engine = new RegressionDataset(tabular);
        Node response = tabular.getVariable(targetName);
        List<Node> predictors = new LinkedList<>();
        for (String name : regressorNames) {
            predictors.add(tabular.getVariable(name));
        }
        runRegression(engine, response, predictors);
    } else if (model instanceof ICovarianceMatrix) {
        ICovarianceMatrix cov = (ICovarianceMatrix) model;
        Regression engine = new RegressionCovariance(cov);
        Node response = cov.getVariable(targetName);
        List<Node> predictors = new LinkedList<>();
        for (String name : regressorNames) {
            predictors.add(cov.getVariable(name));
        }
        runRegression(engine, response, predictors);
    }
    // For any other data model type, outGraph is published as it currently stands.
    setResultGraph(outGraph);
}

/**
 * Configures the significance level (parameter "alpha", default 0.001) on the given regression,
 * runs it, and stores the outcome in {@code result} and {@code outGraph}.
 */
private void runRegression(Regression engine, Node response, List<Node> predictors) {
    double alpha = params.getDouble("alpha", 0.001);
    engine.setAlpha(alpha);
    result = engine.regress(response, predictors);
    outGraph = engine.getGraph();
}
Also used : RegressionDataset(edu.cmu.tetrad.regression.RegressionDataset) Node(edu.cmu.tetrad.graph.Node) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Regression(edu.cmu.tetrad.regression.Regression) RegressionCovariance(edu.cmu.tetrad.regression.RegressionCovariance)

Example 18 with RegressionDataset

use of edu.cmu.tetrad.regression.RegressionDataset in project tetrad by cmu-phil.

the class TestRegression method testTabular.

/**
 * Regression guard for the tabular regression: regresses the first variable of the test data on
 * all remaining variables and checks the first five coefficients against previously recorded
 * values (tolerance 0.01).
 */
@Test
public void testTabular() {
    setUp();
    RandomUtil.getInstance().setSeed(3848283L);

    List<Node> variables = data.getVariables();
    Node response = variables.get(0);
    // Every variable after the first one serves as a regressor.
    List<Node> predictors = new ArrayList<>(variables.subList(1, variables.size()));

    Regression regression = new RegressionDataset(data);
    RegressionResult fit = regression.regress(response, predictors);
    double[] actual = fit.getCoef();

    double[] expected = { .08, -.05, .035, 0.019, -.003 };
    for (int i = 0; i < expected.length; i++) {
        assertEquals(expected[i], actual[i], 0.01);
    }
}
Also used : RegressionDataset(edu.cmu.tetrad.regression.RegressionDataset) ArrayList(java.util.ArrayList) Regression(edu.cmu.tetrad.regression.Regression) RegressionResult(edu.cmu.tetrad.regression.RegressionResult) Test(org.junit.Test)

Example 19 with RegressionDataset

use of edu.cmu.tetrad.regression.RegressionDataset in project tetrad by cmu-phil.

the class IndTestRegressionAD method isIndependent.

/**
 * Determines whether variable x is independent of variable y given a list of conditioning variables z.
 * <p>
 * x is regressed on y and z. The residuals of that regression are compared, with a general
 * Anderson-Darling test, against a zero-mean normal distribution whose standard deviation is that
 * of the residuals reported by {@code getResidualsWithoutFirstRegressor()}. The variables are
 * judged independent when the test's p-value exceeds {@code alpha}.
 *
 * @param xVar  the one variable being compared.
 * @param yVar  the second variable being compared.
 * @param zList the list of conditioning variables.
 * @return true iff x _||_ y | z.
 * @throws NullPointerException if {@code zList} or any of its elements is null.
 * @throws RuntimeException     propagated unchanged from the underlying regression or test (the
 *                              original documentation names matrix singularity as one cause).
 */
public boolean isIndependent(Node xVar, Node yVar, List<Node> zList) {
    if (zList == null) {
        throw new NullPointerException();
    }
    for (Node node : zList) {
        if (node == null) {
            throw new NullPointerException();
        }
    }

    // y goes first so that getResidualsWithoutFirstRegressor() below leaves out y.
    List<Node> regressors = new ArrayList<>();
    regressors.add(dataSet.getVariable(yVar.getName()));
    for (Node zVar : zList) {
        regressors.add(dataSet.getVariable(zVar.getName()));
    }
    RegressionDataset regression = new RegressionDataset(dataSet);
    RegressionResult result = regression.regress(xVar, regressors);
    TetradVector v1 = result.getResiduals();
    TetradVector v2 = regression.getResidualsWithoutFirstRegressor();

    List<Double> d1 = new ArrayList<>(v1.size());
    for (int i = 0; i < v1.size(); i++) {
        d1.add(v1.get(i));
    }
    double[] f2 = new double[v2.size()];
    for (int i = 0; i < v2.size(); i++) {
        f2[i] = v2.get(i);
    }
    double sd = StatUtils.sd(f2);
    RealDistribution c2 = new NormalDistribution(0, sd);
    GeneralAndersonDarlingTest test = new GeneralAndersonDarlingTest(d1, c2);
    double aSquaredStar = test.getASquaredStar();
    double p = test.getP();
    System.out.println("A squared star = " + aSquaredStar + " p = " + p);
    boolean independent = p > alpha;
    // NOTE(review): the stored p-value is 1 - tanh(A*^2), not the p used for the decision above;
    // kept as in the original - confirm this is intended.
    this.pvalue = 1 - tanh(aSquaredStar);
    if (verbose) {
        // Report the p-value that actually drove the decision instead of a hard-coded 0.
        if (independent) {
            TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFactMsg(xVar, yVar, zList, p));
        } else {
            TetradLogger.getInstance().log("dependencies", SearchLogUtils.dependenceFactMsg(xVar, yVar, zList, p));
        }
    }
    return independent;
}
Also used : Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) RealDistribution(org.apache.commons.math3.distribution.RealDistribution) RegressionDataset(edu.cmu.tetrad.regression.RegressionDataset) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) GeneralAndersonDarlingTest(edu.cmu.tetrad.data.GeneralAndersonDarlingTest) RegressionResult(edu.cmu.tetrad.regression.RegressionResult)

Aggregations

RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset)19 Regression (edu.cmu.tetrad.regression.Regression)16 RegressionResult (edu.cmu.tetrad.regression.RegressionResult)16 ArrayList (java.util.ArrayList)10 TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)7 TetradVector (edu.cmu.tetrad.util.TetradVector)7 Node (edu.cmu.tetrad.graph.Node)4 DoubleArrayList (cern.colt.list.DoubleArrayList)3 AndersonDarlingTest (edu.cmu.tetrad.data.AndersonDarlingTest)3 CombinationGenerator (edu.cmu.tetrad.util.CombinationGenerator)2 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)2 OLSMultipleLinearRegression (org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression)2 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)1 GeneralAndersonDarlingTest (edu.cmu.tetrad.data.GeneralAndersonDarlingTest)1 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)1 RegressionCovariance (edu.cmu.tetrad.regression.RegressionCovariance)1 Parameters (edu.cmu.tetrad.util.Parameters)1 Ellipse2D (java.awt.geom.Ellipse2D)1 Vector (java.util.Vector)1 NormalDistribution (org.apache.commons.math3.distribution.NormalDistribution)1