Search in sources:

Example 1 with KernelGaussian

Use of edu.cmu.tetrad.search.kernel.KernelGaussian in project tetrad by cmu-phil.

In class IndTestHsic, method isIndependent.

/**
 * Determines whether variable y is independent of variable x given a list of conditioning variables z,
 * using the Hilbert-Schmidt independence criterion (HSIC) with a permutation test.
 * <p>
 * The observed HSIC statistic is compared against {@code perms} statistics computed on copies of the
 * data in which the y column has been permuted: over all rows when z is empty, or within k-means
 * clusters of the z columns otherwise. The p-value is the fraction of permutation statistics that are
 * strictly greater than the observed statistic; independence is rejected when it is {@code <= alpha}.
 *
 * @param y the first variable being compared (its column is the one that gets permuted).
 * @param x the second variable being compared.
 * @param z the list of conditioning variables; may be empty.
 * @return true iff y _||_ x | z is not rejected.
 */
public boolean isIndependent(Node y, Node x, List<Node> z) {
    int m = sampleSize();
    // choose kernels using median distance heuristic
    Kernel xKernel = new KernelGaussian(1);
    Kernel yKernel = new KernelGaussian(1);
    List<Kernel> zKernel = new ArrayList<>();
    yKernel.setDefaultBw(this.dataSet, y);
    xKernel.setDefaultBw(this.dataSet, x);
    for (Node zNode : z) {
        Kernel zNodeKernel = new KernelGaussian(1);
        zNodeKernel.setDefaultBw(this.dataSet, zNode);
        zKernel.add(zNodeKernel);
    }

    // Gram matrices of the observed data; Kz stays null for an unconditional test.
    TetradMatrix Ky = hsicGramMatrix(Arrays.asList(yKernel), this.dataSet, Arrays.asList(y));
    TetradMatrix Kx = hsicGramMatrix(Arrays.asList(xKernel), this.dataSet, Arrays.asList(x));
    TetradMatrix Kz = z.isEmpty() ? null : hsicGramMatrix(zKernel, this.dataSet, z);

    // observed Hilbert-Schmidt dependence measure
    this.hsic = hsicStatistic(Ky, Kx, Kz, m);

    // For a conditional test, cluster the rows on z so that y is only permuted within clusters.
    int ycol = this.dataSet.getColumn(y);
    List<List<Integer>> clusterAssign = null;
    if (!z.isEmpty()) {
        KMeans kmeans = KMeans.randomClusters((m / 3));
        kmeans.cluster(dataSet.subsetColumns(z).getDoubleData());
        clusterAssign = kmeans.getClusters();
    }

    // approximate the null distribution by permuting y
    double[] nullapprox = new double[this.perms];
    for (int i = 0; i < this.perms; i++) {
        DataSet shuffleData = new ColtDataSet((ColtDataSet) dataSet);
        if (z.isEmpty()) {
            List<Integer> indicesList = new ArrayList<>();
            for (int j = 0; j < m; j++) {
                indicesList.add(j);
            }
            Collections.shuffle(indicesList);
            for (int j = 0; j < m; j++) {
                shuffleData.setDouble(j, ycol, dataSet.getDouble(indicesList.get(j), ycol));
            }
        } else {
            // permute y within each cluster of z
            for (List<Integer> cluster : clusterAssign) {
                List<Integer> shuffleCluster = new ArrayList<>(cluster);
                Collections.shuffle(shuffleCluster);
                for (int k = 0; k < shuffleCluster.size(); k++) {
                    shuffleData.setDouble(shuffleCluster.get(k), ycol, dataSet.getDouble(cluster.get(k), ycol));
                }
            }
        }

        // reset the y bandwidth on the permuted data, then rebuild its Gram matrix
        yKernel.setDefaultBw(shuffleData, y);
        TetradMatrix Kyn = hsicGramMatrix(Arrays.asList(yKernel), shuffleData, Arrays.asList(y));

        // The permuted statistic conditions on the observed Kz (z itself is not permuted).
        // NOTE(review): confirm this is the intended conditional permutation scheme.
        nullapprox[i] = hsicStatistic(Kyn, Kx, Kz, m);
    }

    // permutation test: p-value = fraction of null statistics strictly greater than the observed one
    double evalCdf = 0.0;
    for (double nullStat : nullapprox) {
        if (nullStat <= this.hsic) {
            evalCdf += 1.0;
        }
    }
    evalCdf /= (double) this.perms;
    this.pValue = 1.0 - evalCdf;

    // reject if pvalue <= alpha
    if (this.pValue <= this.alpha) {
        TetradLogger.getInstance().log("dependencies", SearchLogUtils.dependenceFactMsg(x, y, z, getPValue()));
        return false;
    }
    if (verbose) {
        TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFactMsg(x, y, z, getPValue()));
    }
    return true;
}

/**
 * Builds the Gram matrix of the given kernels over the given variables. Uses the incomplete Cholesky
 * approximation when {@code useIncompleteCholesky > 0}, and the centralized Gram matrix otherwise.
 */
private TetradMatrix hsicGramMatrix(List<Kernel> kernels, DataSet data, List<Node> vars) {
    if (useIncompleteCholesky > 0) {
        return KernelUtils.incompleteCholeskyGramMatrix(kernels, data, vars, useIncompleteCholesky);
    }
    return KernelUtils.constructCentralizedGramMatrix(kernels, data, vars);
}

/**
 * Computes the empirical HSIC of Ky and Kx, conditional on Kz when Kz is non-null. Uses the
 * incomplete-Cholesky variant when {@code useIncompleteCholesky > 0}.
 */
private double hsicStatistic(TetradMatrix Ky, TetradMatrix Kx, TetradMatrix Kz, int m) {
    if (useIncompleteCholesky > 0) {
        return Kz == null
                ? empiricalHSICincompleteCholesky(Ky, Kx, m)
                : empiricalHSICincompleteCholesky(Ky, Kx, Kz, m);
    }
    return Kz == null ? empiricalHSIC(Ky, Kx, m) : empiricalHSIC(Ky, Kx, Kz, m);
}
Also used: KMeans (edu.cmu.tetrad.cluster.KMeans), ColtDataSet (edu.cmu.tetrad.data.ColtDataSet), DataSet (edu.cmu.tetrad.data.DataSet), ArrayList (java.util.ArrayList), TetradMatrix (edu.cmu.tetrad.util.TetradMatrix), List (java.util.List), Kernel (edu.cmu.tetrad.search.kernel.Kernel), KernelGaussian (edu.cmu.tetrad.search.kernel.KernelGaussian)

Example 2 with KernelGaussian

Use of edu.cmu.tetrad.search.kernel.KernelGaussian in project tetrad by cmu-phil.

In class TestKernelGaussian, method testMedianBandwidth.

/**
 * Tests that the bandwidth chosen by the {@code KernelGaussian(DataSet, Node)} constructor is the
 * median distance between points in the sample. For the sample {1, 2, 3, 4, 5} the expected
 * bandwidth is 2.
 */
@Test
public void testMedianBandwidth() {
    Node X = new ContinuousVariable("X");
    DataSet dataset = new ColtDataSet(5, Arrays.asList(X));
    // column 0 holds the values 1, 2, 3, 4, 5
    for (int row = 0; row < 5; row++) {
        dataset.setDouble(row, 0, row + 1);
    }
    KernelGaussian kernel = new KernelGaussian(dataset, X);
    // Compare doubles with a tolerance; on failure, assertEquals reports the actual bandwidth.
    org.junit.Assert.assertEquals(2.0, kernel.getBandwidth(), 1e-9);
}
Also used: ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable), ColtDataSet (edu.cmu.tetrad.data.ColtDataSet), DataSet (edu.cmu.tetrad.data.DataSet), Node (edu.cmu.tetrad.graph.Node), KernelGaussian (edu.cmu.tetrad.search.kernel.KernelGaussian), Test (org.junit.Test)

Aggregations

ColtDataSet (edu.cmu.tetrad.data.ColtDataSet): 2; DataSet (edu.cmu.tetrad.data.DataSet): 2; KernelGaussian (edu.cmu.tetrad.search.kernel.KernelGaussian): 2; KMeans (edu.cmu.tetrad.cluster.KMeans): 1; ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable): 1; Node (edu.cmu.tetrad.graph.Node): 1; Kernel (edu.cmu.tetrad.search.kernel.Kernel): 1; TetradMatrix (edu.cmu.tetrad.util.TetradMatrix): 1; ArrayList (java.util.ArrayList): 1; List (java.util.List): 1; Test (org.junit.Test): 1