Search in sources:

Example 51 with DataSet

use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.

The method search in class RSkew.

/**
 * Runs the wrapped algorithm and orients the resulting graph with the LOFS RSkew rule, or, when
 * a positive bootstrap sample size is configured, delegates to {@link GeneralBootstrapTest}.
 *
 * @param dataSet    the data; must be a {@code DataSet} when bootstrapping (it is cast).
 * @param parameters reads "bootstrapSampleSize", "bootstrapEnsemble" (default 1) and "verbose".
 * @return the oriented (or bootstrap-aggregated) graph.
 * @throws IllegalArgumentException if the wrapped algorithm returns {@code null}.
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    final int bootstrapSampleSize = parameters.getInt("bootstrapSampleSize");

    if (bootstrapSampleSize >= 1) {
        // Bootstrap mode: hand a fresh RSkew wrapper to GeneralBootstrapTest together with the
        // chosen edge-ensemble rule.
        final RSkew perSampleSearch = new RSkew(algorithm);
        if (initialGraph != null) {
            perSampleSearch.setInitialGraph(initialGraph);
        }
        final GeneralBootstrapTest bootstrap =
                new GeneralBootstrapTest((DataSet) dataSet, perSampleSearch, bootstrapSampleSize);

        // 0 -> Preserved, 2 -> Majority; 1 and any unrecognised code -> Highest.
        final int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        final BootstrapEdgeEnsemble ensemble;
        if (ensembleCode == 0) {
            ensemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            ensemble = BootstrapEdgeEnsemble.Majority;
        } else {
            ensemble = BootstrapEdgeEnsemble.Highest;
        }

        bootstrap.setEdgeEnsemble(ensemble);
        bootstrap.setParameters(parameters);
        bootstrap.setVerbose(parameters.getBoolean("verbose"));
        return bootstrap.search();
    }

    // Single run: the wrapped algorithm supplies the graph whose edges are then oriented.
    final Graph sourceGraph = algorithm.search(dataSet, parameters);
    if (sourceGraph == null) {
        throw new IllegalArgumentException("This RSkew algorithm needs both data and a graph source as inputs; it \n" + "will orient the edges in the input graph using the data");
    }
    initialGraph = sourceGraph;

    final List<DataSet> continuousData = new ArrayList<>();
    continuousData.add(DataUtils.getContinuousDataSet(dataSet));

    final Lofs2 orienter = new Lofs2(initialGraph, continuousData);
    orienter.setRule(Lofs2.Rule.RSkew);
    return orienter.orient();
}
Also used : Lofs2(edu.cmu.tetrad.search.Lofs2) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Graph(edu.cmu.tetrad.graph.Graph) GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) DataSet(edu.cmu.tetrad.data.DataSet) ArrayList(java.util.ArrayList)

Example 52 with DataSet

use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.

The method search in class Skew.

/**
 * Runs the wrapped algorithm and orients the resulting graph with the LOFS Skew rule, or, when
 * a positive bootstrap sample size is configured, delegates to {@link GeneralBootstrapTest}.
 *
 * @param dataSet    the data; must be a {@code DataSet} when bootstrapping (it is cast).
 * @param parameters reads "bootstrapSampleSize", "bootstrapEnsemble" (default 1) and "verbose".
 * @return the oriented (or bootstrap-aggregated) graph.
 * @throws IllegalArgumentException if the wrapped algorithm returns {@code null}.
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    final int bootstrapSampleSize = parameters.getInt("bootstrapSampleSize");

    if (bootstrapSampleSize >= 1) {
        // Bootstrap mode: hand a fresh Skew wrapper to GeneralBootstrapTest together with the
        // chosen edge-ensemble rule.
        final Skew perSampleSearch = new Skew(algorithm);
        if (initialGraph != null) {
            perSampleSearch.setInitialGraph(initialGraph);
        }
        final GeneralBootstrapTest bootstrap =
                new GeneralBootstrapTest((DataSet) dataSet, perSampleSearch, bootstrapSampleSize);

        // 0 -> Preserved, 2 -> Majority; 1 and any unrecognised code -> Highest.
        final int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        final BootstrapEdgeEnsemble ensemble;
        if (ensembleCode == 0) {
            ensemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            ensemble = BootstrapEdgeEnsemble.Majority;
        } else {
            ensemble = BootstrapEdgeEnsemble.Highest;
        }

        bootstrap.setEdgeEnsemble(ensemble);
        bootstrap.setParameters(parameters);
        bootstrap.setVerbose(parameters.getBoolean("verbose"));
        return bootstrap.search();
    }

    // Single run: the wrapped algorithm supplies the graph whose edges are then oriented.
    final Graph sourceGraph = algorithm.search(dataSet, parameters);
    if (sourceGraph == null) {
        throw new IllegalArgumentException("This Skew algorithm needs both data and a graph source as inputs; it \n" + "will orient the edges in the input graph using the data");
    }
    initialGraph = sourceGraph;

    final List<DataSet> continuousData = new ArrayList<>();
    continuousData.add(DataUtils.getContinuousDataSet(dataSet));

    final Lofs2 orienter = new Lofs2(initialGraph, continuousData);
    orienter.setRule(Lofs2.Rule.Skew);
    return orienter.orient();
}
Also used : Lofs2(edu.cmu.tetrad.search.Lofs2) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Graph(edu.cmu.tetrad.graph.Graph) GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) DataSet(edu.cmu.tetrad.data.DataSet) ArrayList(java.util.ArrayList)

Example 53 with DataSet

use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.

The method search in class Tanh.

/**
 * Runs the wrapped algorithm and orients the resulting graph with the LOFS Tanh rule, or, when
 * a positive bootstrap sample size is configured, delegates to {@link GeneralBootstrapTest}.
 *
 * @param dataSet    the data; must be a {@code DataSet} when bootstrapping (it is cast).
 * @param parameters reads "bootstrapSampleSize", "bootstrapEnsemble" (default 1) and "verbose".
 * @return the oriented (or bootstrap-aggregated) graph.
 * @throws IllegalArgumentException if the wrapped algorithm returns {@code null}.
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    final int bootstrapSampleSize = parameters.getInt("bootstrapSampleSize");

    if (bootstrapSampleSize >= 1) {
        // Bootstrap mode: hand a fresh Tanh wrapper to GeneralBootstrapTest together with the
        // chosen edge-ensemble rule.
        final Tanh perSampleSearch = new Tanh(algorithm);
        if (initialGraph != null) {
            perSampleSearch.setInitialGraph(initialGraph);
        }
        final GeneralBootstrapTest bootstrap =
                new GeneralBootstrapTest((DataSet) dataSet, perSampleSearch, bootstrapSampleSize);

        // 0 -> Preserved, 2 -> Majority; 1 and any unrecognised code -> Highest.
        final int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        final BootstrapEdgeEnsemble ensemble;
        if (ensembleCode == 0) {
            ensemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            ensemble = BootstrapEdgeEnsemble.Majority;
        } else {
            ensemble = BootstrapEdgeEnsemble.Highest;
        }

        bootstrap.setEdgeEnsemble(ensemble);
        bootstrap.setParameters(parameters);
        bootstrap.setVerbose(parameters.getBoolean("verbose"));
        return bootstrap.search();
    }

    // Single run: the wrapped algorithm supplies the graph whose edges are then oriented.
    final Graph sourceGraph = algorithm.search(dataSet, parameters);
    if (sourceGraph == null) {
        throw new IllegalArgumentException("This Tanh algorithm needs both data and a graph source as inputs; it \n" + "will orient the edges in the input graph using the data");
    }
    initialGraph = sourceGraph;

    final List<DataSet> continuousData = new ArrayList<>();
    continuousData.add(DataUtils.getContinuousDataSet(dataSet));

    final Lofs2 orienter = new Lofs2(initialGraph, continuousData);
    orienter.setRule(Lofs2.Rule.Tanh);
    return orienter.orient();
}
Also used : Lofs2(edu.cmu.tetrad.search.Lofs2) EdgeListGraph(edu.cmu.tetrad.graph.EdgeListGraph) TakesInitialGraph(edu.cmu.tetrad.algcomparison.utils.TakesInitialGraph) Graph(edu.cmu.tetrad.graph.Graph) GeneralBootstrapTest(edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest) BootstrapEdgeEnsemble(edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble) DataSet(edu.cmu.tetrad.data.DataSet) ArrayList(java.util.ArrayList)

Example 54 with DataSet

use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.

The method indTestSubset in class IndTestDirichletScore.

// ==========================PUBLIC METHODS=============================//
/**
 * Creates a new independence test instance for a subset of the variables.
 *
 * @param vars the variables to keep; must be non-empty, and every entry must be one of this
 *             test's original variables.
 * @return a new {@code IndTestDirichletScore} over the matching columns of the data, reusing this
 *         test's sample prior and structure prior.
 * @throws IllegalArgumentException if {@code vars} is empty or contains a variable that is not
 *                                  one of the original variables.
 */
public IndependenceTest indTestSubset(List<Node> vars) {
    if (vars.isEmpty()) {
        throw new IllegalArgumentException("Subset may not be empty.");
    }

    // Validate every requested variable before doing any index lookups.
    for (Node candidate : vars) {
        if (!variables.contains(candidate)) {
            throw new IllegalArgumentException("All vars must be original vars");
        }
    }

    final int[] columnIndices = new int[vars.size()];
    int next = 0;
    for (Node kept : vars) {
        columnIndices[next++] = indexMap.get(kept);
    }

    final DataSet subset = dataSet.subsetColumns(columnIndices);
    return new IndTestDirichletScore(subset, getSamplePrior(), getStructurePrior());
}
Also used : DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node)

Example 55 with DataSet

use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.

The method isIndependent in class IndTestHsic.

/**
 * Determines whether two variables are independent given a list of conditioning variables, using
 * a permutation test on the Hilbert-Schmidt independence criterion (HSIC).
 *
 * <p>Note the parameter order: the first argument is named {@code y} and the second {@code x}.
 * The values of {@code y} are the ones permuted to approximate the null distribution. When
 * {@code z} is non-empty, they are permuted only within k-means clusters computed on the
 * {@code z} columns.
 *
 * <p>Sets {@code this.hsic} to the observed statistic. Sets {@code this.pValue} to the fraction
 * of permuted statistics that strictly exceed it.
 *
 * @param y the first variable being compared (its values are permuted).
 * @param x the second variable being compared.
 * @param z the list of conditioning variables; may be empty.
 * @return true iff the p-value is greater than {@code alpha}.
 */
public boolean isIndependent(Node y, Node x, List<Node> z) {
    int m = sampleSize();

    // choose kernels using median distance heuristic
    Kernel xKernel = new KernelGaussian(1);
    Kernel yKernel = new KernelGaussian(1);
    List<Kernel> zKernel = new ArrayList<>();
    yKernel.setDefaultBw(this.dataSet, y);
    xKernel.setDefaultBw(this.dataSet, x);
    for (Node zNode : z) {
        Kernel zi = new KernelGaussian(1);
        zi.setDefaultBw(this.dataSet, zNode);
        zKernel.add(zi);
    }

    // construct Gram matrices: incomplete Cholesky approximation when requested, else direct
    TetradMatrix Ky;
    TetradMatrix Kx;
    TetradMatrix Kz = null;
    if (useIncompleteCholesky > 0) {
        Ky = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(yKernel), this.dataSet, Arrays.asList(y), useIncompleteCholesky);
        Kx = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(xKernel), this.dataSet, Arrays.asList(x), useIncompleteCholesky);
        if (!z.isEmpty()) {
            Kz = KernelUtils.incompleteCholeskyGramMatrix(zKernel, this.dataSet, z, useIncompleteCholesky);
        }
    } else {
        Ky = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(yKernel), this.dataSet, Arrays.asList(y));
        Kx = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(xKernel), this.dataSet, Arrays.asList(x));
        if (!z.isEmpty()) {
            Kz = KernelUtils.constructCentralizedGramMatrix(zKernel, this.dataSet, z);
        }
    }

    // get Hilbert-Schmidt dependence measure
    if (z.isEmpty()) {
        if (useIncompleteCholesky > 0) {
            this.hsic = empiricalHSICincompleteCholesky(Ky, Kx, m);
        } else {
            this.hsic = empiricalHSIC(Ky, Kx, m);
        }
    } else {
        if (useIncompleteCholesky > 0) {
            this.hsic = empiricalHSICincompleteCholesky(Ky, Kx, Kz, m);
        } else {
            this.hsic = empiricalHSIC(Ky, Kx, Kz, m);
        }
    }

    // approximate the null distribution by permuting y
    double[] nullapprox = new double[this.perms];
    int ycol = this.dataSet.getColumn(y);
    List<List<Integer>> clusterAssign = null;
    if (!z.isEmpty()) {
        // cluster the rows on the z columns; y is then permuted only within each cluster
        KMeans kmeans = KMeans.randomClusters((m / 3));
        kmeans.cluster(dataSet.subsetColumns(z).getDoubleData());
        clusterAssign = kmeans.getClusters();
    }

    for (int i = 0; i < this.perms; i++) {
        DataSet shuffleData = new ColtDataSet((ColtDataSet) dataSet);
        if (z.isEmpty()) {
            // permute y over all rows
            List<Integer> indicesList = new ArrayList<>();
            for (int j = 0; j < m; j++) {
                indicesList.add(j);
            }
            Collections.shuffle(indicesList);
            for (int j = 0; j < m; j++) {
                double shuffleVal = dataSet.getDouble(indicesList.get(j), ycol);
                shuffleData.setDouble(j, ycol, shuffleVal);
            }
        } else {
            // Permute y within clusters. The z columns are left untouched. The null statistic
            // below pairs the permuted y with the ORIGINAL Kx and Kz, so permuted z values (and
            // a Gram matrix built from them) would never be read.
            for (List<Integer> cluster : clusterAssign) {
                List<Integer> shuffleCluster = new ArrayList<>(cluster);
                Collections.shuffle(shuffleCluster);
                for (int k = 0; k < shuffleCluster.size(); k++) {
                    double swapVal = dataSet.getDouble(cluster.get(k), ycol);
                    shuffleData.setDouble(shuffleCluster.get(k), ycol, swapVal);
                }
            }
        }

        // reset the y bandwidth on the permuted data and rebuild its Gram matrix
        yKernel.setDefaultBw(shuffleData, y);
        TetradMatrix Kyn;
        if (useIncompleteCholesky > 0) {
            Kyn = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(yKernel), shuffleData, Arrays.asList(y), useIncompleteCholesky);
        } else {
            Kyn = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(yKernel), shuffleData, Arrays.asList(y));
        }

        // HSIC of the permuted y against the original x (and z) Gram matrices
        if (z.isEmpty()) {
            if (useIncompleteCholesky > 0) {
                nullapprox[i] = empiricalHSICincompleteCholesky(Kyn, Kx, m);
            } else {
                nullapprox[i] = empiricalHSIC(Kyn, Kx, m);
            }
        } else {
            if (useIncompleteCholesky > 0) {
                nullapprox[i] = empiricalHSICincompleteCholesky(Kyn, Kx, Kz, m);
            } else {
                nullapprox[i] = empiricalHSIC(Kyn, Kx, Kz, m);
            }
        }
    }

    // permutation test to get p-value: fraction of null statistics strictly above the observed one
    double evalCdf = 0.0;
    for (int i = 0; i < this.perms; i++) {
        if (nullapprox[i] <= this.hsic) {
            evalCdf += 1.0;
        }
    }
    evalCdf /= (double) this.perms;
    this.pValue = 1.0 - evalCdf;

    // reject if pvalue <= alpha
    if (this.pValue <= this.alpha) {
        TetradLogger.getInstance().log("dependencies", SearchLogUtils.dependenceFactMsg(x, y, z, getPValue()));
        return false;
    }
    if (verbose) {
        TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFactMsg(x, y, z, getPValue()));
    }
    return true;
}
Also used : KMeans(edu.cmu.tetrad.cluster.KMeans) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) ArrayList(java.util.ArrayList) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) List(java.util.List) Kernel(edu.cmu.tetrad.search.kernel.Kernel) KernelGaussian(edu.cmu.tetrad.search.kernel.KernelGaussian)

Aggregations

DataSet (edu.cmu.tetrad.data.DataSet)216 Test (org.junit.Test)65 Graph (edu.cmu.tetrad.graph.Graph)64 Node (edu.cmu.tetrad.graph.Node)60 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)48 ArrayList (java.util.ArrayList)45 ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)36 GeneralBootstrapTest (edu.pitt.dbmi.algo.bootstrap.GeneralBootstrapTest)32 EdgeListGraph (edu.cmu.tetrad.graph.EdgeListGraph)29 SemIm (edu.cmu.tetrad.sem.SemIm)28 SemPm (edu.cmu.tetrad.sem.SemPm)28 BootstrapEdgeEnsemble (edu.pitt.dbmi.algo.bootstrap.BootstrapEdgeEnsemble)26 DataModel (edu.cmu.tetrad.data.DataModel)22 Parameters (edu.cmu.tetrad.util.Parameters)22 DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)20 File (java.io.File)16 ParseException (java.text.ParseException)16 LinkedList (java.util.LinkedList)14 ICovarianceMatrix (edu.cmu.tetrad.data.ICovarianceMatrix)13 DMSearch (edu.cmu.tetrad.search.DMSearch)10