Search in sources :

Example 76 with TetradMatrix

use of edu.cmu.tetrad.util.TetradMatrix in project tetrad by cmu-phil.

the class Peter1Score method printMinimalLinearlyDependentSet.

/**
 * Searches for a smallest subset of the given parents whose covariance
 * submatrix cannot be inverted. When one is found, its first member is added
 * to {@code forbidden} and the subset is reported on {@code out}.
 * <p>
 * Candidate subsets are taken from a {@link DepthChoiceGenerator} over the
 * positions of {@code parents}; the generator presumably enumerates choices
 * in order of increasing size, which is what makes the first failing subset a
 * smallest one (verify against the generator's documentation).
 *
 * @param parents indices (into {@code variables}) of the parent variables.
 * @param cov     the covariance matrix the submatrices are selected from.
 * @return true if a linearly dependent subset was found and reported, false
 * if every candidate submatrix could be inverted.
 */
private boolean printMinimalLinearlyDependentSet(int[] parents, ICovarianceMatrix cov) {
    DepthChoiceGenerator gen = new DepthChoiceGenerator(parents.length, parents.length);
    int[] choice;
    while ((choice = gen.next()) != null) {
        // An empty selection has no first element to forbid or report.
        if (choice.length == 0) {
            continue;
        }
        int[] sel = new int[choice.length];
        List<Node> _sel = new ArrayList<>();
        for (int i = 0; i < choice.length; i++) {
            // Map the chosen position back to the variable index. Indexing
            // parents by the loop counter would always test the first
            // choice.length parents, whatever the generator picked.
            sel[i] = parents[choice[i]];
            _sel.add(variables.get(sel[i]));
        }
        TetradMatrix m = cov.getSelection(sel, sel);
        try {
            m.inverse();
        } catch (Exception e2) {
            // Any failure to invert is treated as linear dependence.
            forbidden.add(sel[0]);
            out.println("### Linear dependence among variables: " + _sel);
            out.println("### Removing " + _sel.get(0));
            return true;
        }
    }
    return false;
}
Also used : DepthChoiceGenerator(edu.cmu.tetrad.util.DepthChoiceGenerator) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix)

Example 77 with TetradMatrix

use of edu.cmu.tetrad.util.TetradMatrix in project tetrad by cmu-phil.

the class RegressionCovariance method regress.

/**
 * Regresses the given target on the given regressors, yielding a regression
 * plane, in which coefficients are given for each regressor plus the
 * constant (if means have been specified, that is, for the last), and se,
 * t, and p values are given for each regressor.
 * <p>
 * The computation works from the correlation matrix: standardized
 * coefficients are obtained from the inverse of the regressor correlation
 * matrix and rescaled by the standard deviations in {@code sd}. The
 * intercept, and the se, t and p entries at position 0, are {@code NaN}
 * unless means are available (intercept only). As a side effect the result
 * graph stored in {@code this.graph} is replaced.
 *
 * @param target     The variable being regressed.
 * @param regressors The list of regressors.
 * @return the regression plane.
 * @throws IllegalArgumentException if the target is not one of the variables
 *                                  of the correlation matrix.
 * @throws NullPointerException     if a regressor is not one of the variables
 *                                  of the correlation matrix (type kept for
 *                                  compatibility with existing callers).
 */
public RegressionResult regress(Node target, List<Node> regressors) {
    TetradMatrix allCorrelations = correlations.getMatrix();
    List<Node> variables = correlations.getVariables();
    int yIndex = variables.indexOf(target);
    if (yIndex == -1) {
        throw new IllegalArgumentException("Can't find variable " + target + " in this list: " + variables);
    }
    int[] xIndices = new int[regressors.size()];
    for (int i = 0; i < regressors.size(); i++) {
        xIndices[i] = variables.indexOf(regressors.get(i));
        if (xIndices[i] == -1) {
            throw new NullPointerException("Can't find variable " + regressors.get(i) + " in this list: " + variables);
        }
    }
    TetradMatrix rX = allCorrelations.getSelection(xIndices, xIndices);
    TetradMatrix rY = allCorrelations.getSelection(xIndices, new int[] { yIndex });
    // Inverted once; used for the coefficients here and for the standard errors below.
    TetradMatrix rxInv = rX.inverse();
    TetradMatrix bStar = rxInv.times(rY);
    // Rescale the standardized coefficients; slot 0 is reserved for the intercept.
    double sdY = sd.get(yIndex);
    TetradVector b = new TetradVector(bStar.rows() + 1);
    b.set(0, Double.NaN);
    for (int j = 1; j < b.size(); j++) {
        double sdJ = sd.get(xIndices[j - 1]);
        b.set(j, bStar.get(j - 1, 0) * (sdY / sdJ));
    }
    // The intercept can only be recovered when the means are known.
    if (means != null) {
        double b0 = means.get(yIndex);
        for (int i = 0; i < xIndices.length; i++) {
            b0 -= b.get(i + 1) * means.get(xIndices[i]);
        }
        b.set(0, b0);
    }
    // Target first, then the regressors in the order given.
    int[] allIndices = new int[1 + xIndices.length];
    allIndices[0] = yIndex;
    System.arraycopy(xIndices, 0, allIndices, 1, xIndices.length);
    TetradMatrix r = allCorrelations.getSelection(allIndices, allIndices);
    TetradMatrix rInv = r.inverse();
    int n = correlations.getSampleSize();
    int k = regressors.size() + 1;
    // R^2 of the target on the regressors, from the target's diagonal entry of the inverse.
    double vY = rInv.get(0, 0);
    double r2 = 1.0 - (1.0 / vY);
    // Book says n - 1.
    double tss = n * sd.get(yIndex) * sd.get(yIndex);
    double rss = tss * (1.0 - r2);
    double seY = Math.sqrt(rss / (double) (n - k));
    TetradVector sqErr = new TetradVector(allIndices.length);
    TetradVector t = new TetradVector(allIndices.length);
    TetradVector p = new TetradVector(allIndices.length);
    sqErr.set(0, Double.NaN);
    t.set(0, Double.NaN);
    p.set(0, Double.NaN);
    for (int i = 0; i < regressors.size(); i++) {
        // R^2 of regressor i on the remaining regressors.
        double _r2 = 1.0 - (1.0 / rxInv.get(i, i));
        double _tss = n * sd.get(xIndices[i]) * sd.get(xIndices[i]);
        double _se = seY / Math.sqrt(_tss * (1.0 - _r2));
        double _t = b.get(i + 1) / _se;
        // Two-sided p value with n - k degrees of freedom.
        double _p = 2 * (1.0 - ProbUtils.tCdf(Math.abs(_t), n - k));
        sqErr.set(i + 1, _se);
        t.set(i + 1, _t);
        p.set(i + 1, _p);
    }
    // Graph
    this.graph = createGraph(target, allIndices, regressors, p);
    String[] vNames = createVarNamesArray(regressors);
    double[] bArray = b.toArray();
    double[] tArray = t.toArray();
    double[] pArray = p.toArray();
    double[] seArray = sqErr.toArray();
    return new RegressionResult(false, vNames, n, bArray, tArray, pArray, seArray, r2, rss, alpha, null, null);
}
Also used : TetradVector(edu.cmu.tetrad.util.TetradVector) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix)

Example 78 with TetradMatrix

use of edu.cmu.tetrad.util.TetradMatrix in project tetrad by cmu-phil.

the class RegressionDatasetGeneralized method regress.

/**
 * Regresses the target on the given regressors by ordinary least squares,
 * using a design matrix whose first column is the constant 1.
 * <p>
 * The returned result carries the coefficients (intercept first), R squared,
 * the residual sum of squares, the predicted values and the residuals. The
 * t, p and standard-error arrays of the result are empty: this method does
 * not compute them.
 *
 * @param target     The target variable.
 * @param regressors The regressor variables.
 * @return The regression plane, specifying the coefficient of the constant
 * and of each regressor.
 * @throws IllegalArgumentException if the target or a regressor is not one
 *                                  of the variables of this regression.
 */
public RegressionResult regress(Node target, List<Node> regressors) {
    int n = data.rows();
    int _target = variables.indexOf(target);
    if (_target == -1) {
        throw new IllegalArgumentException("Can't find variable " + target + " in this list: " + variables);
    }
    int[] _regressors = new int[regressors.size()];
    for (int i = 0; i < regressors.size(); i++) {
        _regressors[i] = variables.indexOf(regressors.get(i));
        if (_regressors[i] == -1) {
            throw new IllegalArgumentException("Can't find variable " + regressors.get(i) + " in this list: " + variables);
        }
    }
    // Select every row of the regressor columns.
    int[] rows = new int[n];
    for (int i = 0; i < rows.length; i++) {
        rows[i] = i;
    }
    TetradMatrix xSub = data.getSelection(rows, _regressors);
    // Design matrix: a leading column of ones for the intercept, then the regressors.
    TetradMatrix X = new TetradMatrix(xSub.rows(), xSub.columns() + 1);
    for (int i = 0; i < X.rows(); i++) {
        X.set(i, 0, 1);
        for (int j = 1; j < X.columns(); j++) {
            X.set(i, j, xSub.get(i, j - 1));
        }
    }
    TetradVector y = data.getColumn(_target);
    // Normal equations: b = (X'X)^-1 X' y.
    TetradMatrix Xt = X.transpose();
    TetradMatrix XtX = Xt.times(X);
    TetradMatrix G = XtX.inverse();
    TetradMatrix GXt = G.times(Xt);
    TetradVector b = GXt.times(y);
    TetradVector yPred = X.times(b);
    // NOTE(review): residuals are reported as predicted minus observed, the
    // opposite of the usual sign convention -- confirm callers expect this.
    TetradVector xRes = yPred.minus(y);
    double rss = rss(X, y, b);
    double tss = tss(y);
    double r2 = 1.0 - (rss / tss);
    String[] vNames = new String[regressors.size()];
    for (int i = 0; i < regressors.size(); i++) {
        vNames[i] = regressors.get(i).getName();
    }
    return new RegressionResult(false, vNames, n, b.toArray(), new double[0], new double[0], new double[0], r2, rss, alpha, yPred, xRes);
}
Also used : TetradVector(edu.cmu.tetrad.util.TetradVector) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix)

Example 79 with TetradMatrix

use of edu.cmu.tetrad.util.TetradMatrix in project tetrad by cmu-phil.

the class Comparison2 method compare.

/**
 * Obtains a true graph and (unless running without data) a data set according
 * to the given parameters, runs the configured search algorithm, and returns
 * the search result together with the expected result and the elapsed time.
 * <p>
 * The source of graph and data is chosen as follows:
 * <ul>
 * <li>If {@code params.isDataFromFile()}, a graph file and a data file are
 * looked up in a hard-coded directory and recorded on {@code params}.</li>
 * <li>If {@code params.isNoData()}, a random DAG is generated and searched
 * directly with a d-separation test and a graph score; no data is used.</li>
 * <li>Otherwise, if a data file is set on {@code params}, the data loaded
 * from file is used.</li>
 * <li>Otherwise a random DAG is generated and continuous or discrete data is
 * simulated from it.</li>
 * </ul>
 * Progress messages are printed to standard output. {@code params} is
 * mutated: the graph file, data file and data type may be set on it.
 *
 * @param params the configuration of the comparison.
 * @return the comparison result, holding the result graph, the correct
 * result, the true DAG and the elapsed search time in milliseconds.
 * @throws IllegalArgumentException if a required parameter is not set or is
 *                                  inconsistent with the others.
 * @throws NullPointerException     if the data directory cannot be listed.
 */
public static ComparisonResult compare(ComparisonParameters params) {
    DataSet dataSet = null;
    Graph trueDag = null;
    IndependenceTest test = null;
    Score score = null;
    ComparisonResult result = new ComparisonResult(params);
    if (params.isDataFromFile()) {
        // Set path to the data directory.
        String path = "/Users/dmalinsky/Documents/research/data/danexamples";
        File dir = new File(path);
        File[] files = dir.listFiles();
        if (files == null) {
            throw new NullPointerException("No files in " + path);
        }
        // First matching graph file supplies the true DAG.
        for (File file : files) {
            if (file.getName().startsWith("graph") && file.getName().contains(String.valueOf(params.getGraphNum())) && file.getName().endsWith(".g.txt")) {
                params.setGraphFile(file.getName());
                trueDag = GraphUtils.loadGraphTxt(file);
                break;
            }
        }
        // First data file matching "<graphNum>-<trial>.dat.txt" supplies the data.
        String trialGraph = String.valueOf(params.getGraphNum()).concat("-").concat(String.valueOf(params.getTrial())).concat(".dat.txt");
        for (File file : files) {
            if (file.getName().startsWith("graph") && file.getName().endsWith(trialGraph)) {
                Path dataFile = Paths.get(path.concat("/").concat(file.getName()));
                Delimiter delimiter = Delimiter.TAB;
                try {
                    TabularDataReader dataReader;
                    if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
                        dataReader = new ContinuousTabularDataFileReader(dataFile.toFile(), delimiter);
                    } else {
                        dataReader = new VerticalDiscreteTabularDataReader(dataFile.toFile(), delimiter);
                    }
                    dataSet = (DataSet) DataConvertUtils.toDataModel(dataReader.readInData());
                } catch (IOException e) {
                    // Best effort, as before: the failure is reported and the
                    // data file name is still recorded below.
                    e.printStackTrace();
                }
                params.setDataFile(file.getName());
                break;
            }
        }
        System.out.println("current graph file = " + params.getGraphFile());
        System.out.println("current data set file = " + params.getDataFile());
    }
    if (params.isNoData()) {
        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < params.getNumVars(); i++) {
            nodes.add(new ContinuousVariable("X" + (i + 1)));
        }
        trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
        // added 5.25.16 for tsFCI
        if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            trueDag = TimeSeriesUtils.graphToLagGraph(trueDag, 2);
            System.out.println("Creating Time Lag Graph : " + trueDag);
        }
        // Search against the graph itself rather than against data.
        test = new IndTestDSep(trueDag);
        score = new GraphScore(trueDag);
        return runConfiguredSearch(params, test, score, trueDag, result, true);
    } else if (params.getDataFile() != null) {
        System.out.println("Using data from file... ");
        if (params.getGraphFile() == null) {
            throw new IllegalArgumentException("True graph file not set.");
        }
        System.out.println("Using graph from file... ");
    } else {
        if (params.getNumVars() == -1) {
            throw new IllegalArgumentException("Number of variables not set.");
        }
        if (params.getNumEdges() == -1) {
            throw new IllegalArgumentException("Number of edges not set.");
        }
        // Checked before branching on the data type; inside the branches the
        // type is already known to be non-null.
        if (params.getDataType() == null) {
            throw new IllegalArgumentException("Data type not set or inferred.");
        }
        if (params.getDataType() == ComparisonParameters.DataType.Continuous) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new ContinuousVariable("X" + (i + 1)));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            // added 6.08.16 for tsFCI
            if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
                trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
                trueDag = TimeSeriesUtils.graphToLagGraph(trueDag, 2);
                System.out.println("Creating Time Lag Graph : " + trueDag);
            }
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
            if (params.getAlgorithm() == ComparisonParameters.Algorithm.TsFCI) {
                // added 6.08.16 for tsFCI
                sim.setCoefRange(0.20, 0.50);
                dataSet = simulateStableTimeSeriesData(sim, params);
            } else {
                // Without this, no data was ever simulated for the other
                // algorithms and the "No data set." check below always fired.
                dataSet = sim.simulateDataFisher(params.getSampleSize());
            }
        } else if (params.getDataType() == ComparisonParameters.DataType.Discrete) {
            List<Node> nodes = new ArrayList<>();
            for (int i = 0; i < params.getNumVars(); i++) {
                nodes.add(new DiscreteVariable("X" + (i + 1), 3));
            }
            trueDag = GraphUtils.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }
            int[] tiers = new int[nodes.size()];
            for (int i = 0; i < nodes.size(); i++) {
                tiers[i] = i;
            }
            BayesPm pm = new BayesPm(trueDag, 3, 3);
            MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
            dataSet = im.simulateData(params.getSampleSize(), false, tiers);
        } else {
            throw new IllegalArgumentException("Unrecognized data type.");
        }
        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
    }
    // Build the independence test, fixing the data type it implies.
    if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.FisherZ) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestFisherZ(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getIndependenceTest() == ComparisonParameters.IndependenceTestType.ChiSquare) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getAlpha())) {
            throw new IllegalArgumentException("Alpha not set.");
        }
        test = new IndTestChiSquare(dataSet, params.getAlpha());
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }
    // Build the score, fixing the data type it implies.
    if (params.getScore() == ScoreType.SemBic) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Continuous) {
            throw new IllegalArgumentException("Data type previously set to something other than continuous.");
        }
        if (Double.isNaN(params.getPenaltyDiscount())) {
            throw new IllegalArgumentException("Penalty discount not set.");
        }
        SemBicScore semBicScore = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet));
        semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
        score = semBicScore;
        params.setDataType(ComparisonParameters.DataType.Continuous);
    } else if (params.getScore() == ScoreType.BDeu) {
        if (params.getDataType() != null && params.getDataType() != ComparisonParameters.DataType.Discrete) {
            throw new IllegalArgumentException("Data type previously set to something other than discrete.");
        }
        if (Double.isNaN(params.getSamplePrior())) {
            throw new IllegalArgumentException("Sample prior not set.");
        }
        if (Double.isNaN(params.getStructurePrior())) {
            throw new IllegalArgumentException("Structure prior not set.");
        }
        BDeuScore bDeuScore = new BDeuScore(dataSet);
        bDeuScore.setSamplePrior(params.getSamplePrior());
        bDeuScore.setStructurePrior(params.getStructurePrior());
        score = bDeuScore;
        params.setDataType(ComparisonParameters.DataType.Discrete);
    }
    return runConfiguredSearch(params, test, score, trueDag, result, false);
}

/**
 * Simulates time series data for tsFCI, retrying up to five times until the
 * lag-one transition matrix derived from the simulation's coefficient matrix
 * has all eigenvalues smaller than one in modulus. If no attempt is stable,
 * the coefficient range is narrowed to [0.15, 0.3] and data is simulated
 * once more without a further stability check.
 */
private static DataSet simulateStableTimeSeriesData(LargeScaleSimulation sim, ComparisonParameters params) {
    int tierSize = params.getNumVars();
    // sub indexes the first tier of variables, sub2 the second tier.
    int[] sub = new int[tierSize];
    int[] sub2 = new int[tierSize];
    for (int i = 0; i < tierSize; i++) {
        sub[i] = i;
        sub2[i] = tierSize + i;
    }
    DataSet dataSet;
    boolean isStableTetradMatrix;
    int attempt = 1;
    do {
        dataSet = sim.simulateDataFisher(params.getSampleSize());
        TetradMatrix coefMat = new TetradMatrix(sim.getCoefficientMatrix());
        TetradMatrix B = coefMat.getSelection(sub, sub);
        TetradMatrix Gamma1 = coefMat.getSelection(sub2, sub);
        TetradMatrix Gamma0 = TetradMatrix.identity(tierSize).minus(B);
        TetradMatrix A1 = Gamma0.inverse().times(Gamma1);
        isStableTetradMatrix = TimeSeriesUtils.allEigenvaluesAreSmallerThanOneInModulus(A1);
        System.out.println("isStableTetradMatrix? : " + isStableTetradMatrix);
        attempt++;
    } while ((!isStableTetradMatrix) && attempt <= 5);
    if (!isStableTetradMatrix) {
        System.out.println("%%%%%%%%%% WARNING %%%%%%%% not a stable coefficient matrix, forcing coefs to [0.15,0.3]");
        System.out.println("Made " + (attempt - 1) + " attempts to get stable matrix.");
        sim.setCoefRange(0.15, 0.3);
        dataSet = sim.simulateDataFisher(params.getSampleSize());
    } else {
        System.out.println("Coefficient matrix is stable.");
    }
    return dataSet;
}

/**
 * Runs the algorithm selected in {@code params} with the given test and
 * score, and records the result graph, the correct result for the true DAG,
 * the elapsed time in milliseconds and the true DAG on {@code result}.
 * When {@code printTsFciGraphs} is set, the correct and found graphs of a
 * tsFCI run are also printed.
 */
private static ComparisonResult runConfiguredSearch(ComparisonParameters params, IndependenceTest test, Score score, Graph trueDag, ComparisonResult result, boolean printTsFciGraphs) {
    ComparisonParameters.Algorithm algorithm = params.getAlgorithm();
    if (algorithm == null) {
        throw new IllegalArgumentException("Algorithm not set.");
    }
    long time1 = System.currentTimeMillis();
    if (algorithm == ComparisonParameters.Algorithm.PC) {
        requireIndependenceTest(test);
        Pc search = new Pc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (algorithm == ComparisonParameters.Algorithm.CPC) {
        requireIndependenceTest(test);
        Cpc search = new Cpc(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (algorithm == ComparisonParameters.Algorithm.PCLocal) {
        requireIndependenceTest(test);
        PcLocal search = new PcLocal(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (algorithm == ComparisonParameters.Algorithm.PCStableMax) {
        requireIndependenceTest(test);
        PcStableMax search = new PcStableMax(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (algorithm == ComparisonParameters.Algorithm.FGES) {
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        Fges search = new Fges(score);
        result.setResultGraph(search.search());
        result.setCorrectResult(SearchGraphUtils.patternForDag(new EdgeListGraph(trueDag)));
    } else if (algorithm == ComparisonParameters.Algorithm.FCI) {
        requireIndependenceTest(test);
        Fci search = new Fci(test);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (algorithm == ComparisonParameters.Algorithm.GFCI) {
        requireIndependenceTest(test);
        GFci search = new GFci(test, score);
        result.setResultGraph(search.search());
        result.setCorrectResult(new DagToPag(trueDag).convert());
    } else if (algorithm == ComparisonParameters.Algorithm.TsFCI) {
        requireIndependenceTest(test);
        TsFci search = new TsFci(test);
        IKnowledge knowledge = getKnowledge(trueDag);
        search.setKnowledge(knowledge);
        result.setResultGraph(search.search());
        result.setCorrectResult(new TsDagToPag(trueDag).convert());
        if (printTsFciGraphs) {
            System.out.println("Correct result for trial = " + result.getCorrectResult());
            System.out.println("Search result for trial = " + result.getResultGraph());
        }
    } else {
        throw new IllegalArgumentException("Unrecognized algorithm.");
    }
    long time2 = System.currentTimeMillis();
    long elapsed = time2 - time1;
    result.setElapsed(elapsed);
    result.setTrueDag(trueDag);
    return result;
}

/**
 * Fails with the standard message when no independence test is available.
 */
private static void requireIndependenceTest(IndependenceTest test) {
    if (test == null) {
        throw new IllegalArgumentException("Test not set.");
    }
}
Also used : MlBayesIm(edu.cmu.tetrad.bayes.MlBayesIm) Node(edu.cmu.tetrad.graph.Node) ArrayList(java.util.ArrayList) ArrayList(java.util.ArrayList) List(java.util.List) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) Graph(edu.cmu.tetrad.graph.Graph) TabularDataReader(edu.pitt.dbmi.data.reader.tabular.TabularDataReader) VerticalDiscreteTabularDataReader(edu.pitt.dbmi.data.reader.tabular.VerticalDiscreteTabularDataReader) LargeScaleSimulation(edu.cmu.tetrad.sem.LargeScaleSimulation) VerticalDiscreteTabularDataReader(edu.pitt.dbmi.data.reader.tabular.VerticalDiscreteTabularDataReader) Path(java.nio.file.Path) Delimiter(edu.pitt.dbmi.data.Delimiter) ContinuousTabularDataFileReader(edu.pitt.dbmi.data.reader.tabular.ContinuousTabularDataFileReader) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) BayesPm(edu.cmu.tetrad.bayes.BayesPm)

Example 80 with TetradMatrix

use of edu.cmu.tetrad.util.TetradMatrix in project tetrad by cmu-phil.

the class RegressionUtils method residuals.

/**
 * Regresses every variable of the data set on its parents in the given graph
 * and returns a data set, over the same variables, whose columns are the
 * residuals of those regressions.
 * <p>
 * Variables are matched between data set and graph by name. The regressors
 * used for each variable are printed to standard output.
 *
 * @param dataSet the data to regress.
 * @param graph   the graph supplying the parents of each variable.
 * @return a continuous data set holding one residual column per variable.
 * @throws IllegalArgumentException if a data variable has no node of the
 *                                  same name in the graph, or a parent in the
 *                                  graph has no variable of the same name in
 *                                  the data.
 */
public static DataSet residuals(DataSet dataSet, Graph graph) {
    Regression regression = new RegressionDataset(dataSet);
    TetradMatrix residuals = new TetradMatrix(dataSet.getNumRows(), dataSet.getNumColumns());
    for (int i = 0; i < dataSet.getNumColumns(); i++) {
        Node target = dataSet.getVariable(i);
        Node _target = graph.getNode(target.getName());
        if (_target == null) {
            throw new IllegalArgumentException("Data variable not in graph: " + target);
        }
        Set<Node> _regressors = new HashSet<>(graph.getParents(_target));
        System.out.println("For " + target + " regressors are " + _regressors);
        // Translate the graph's parent nodes into the data set's own variables.
        List<Node> regressors = new LinkedList<>();
        for (Node node : _regressors) {
            Node regressor = dataSet.getVariable(node.getName());
            // Presumably the by-name lookup yields null for an unknown name;
            // fail here rather than pass a null regressor to the regression.
            if (regressor == null) {
                throw new IllegalArgumentException("Graph parent not in data: " + node);
            }
            regressors.add(regressor);
        }
        RegressionResult result = regression.regress(target, regressors);
        TetradVector residualsColumn = result.getResiduals();
        residuals.assignColumn(i, residualsColumn);
    }
    return ColtDataSet.makeContinuousData(dataSet.getVariables(), residuals);
}
Also used : TetradVector(edu.cmu.tetrad.util.TetradVector) Node(edu.cmu.tetrad.graph.Node) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) LinkedList(java.util.LinkedList) HashSet(java.util.HashSet)

Aggregations

TetradMatrix (edu.cmu.tetrad.util.TetradMatrix)161 TetradVector (edu.cmu.tetrad.util.TetradVector)46 ArrayList (java.util.ArrayList)43 Node (edu.cmu.tetrad.graph.Node)41 List (java.util.List)12 CovarianceMatrix (edu.cmu.tetrad.data.CovarianceMatrix)10 DepthChoiceGenerator (edu.cmu.tetrad.util.DepthChoiceGenerator)9 SingularMatrixException (org.apache.commons.math3.linear.SingularMatrixException)9 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)8 RegressionResult (edu.cmu.tetrad.regression.RegressionResult)8 Test (org.junit.Test)8 Regression (edu.cmu.tetrad.regression.Regression)7 RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset)7 SemIm (edu.cmu.tetrad.sem.SemIm)7 Graph (edu.cmu.tetrad.graph.Graph)6 SemPm (edu.cmu.tetrad.sem.SemPm)6 Vector (java.util.Vector)6 DoubleArrayList (cern.colt.list.DoubleArrayList)5 DataSet (edu.cmu.tetrad.data.DataSet)5 ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix)5