Search in sources :

Example 6 with ColtDataSet

use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.

the class TestColtDataSet method testContinuous.

/**
 * Fills a continuous data set with random values, takes the subset made of columns 2 and 4, and
 * checks that the first-row values of those columns show up as columns 0 and 1 of the subset.
 */
@Test
public final void testContinuous() {
    final int numRows = 10;
    final int numCols = 5;

    List<Node> nodes = new LinkedList<>();
    for (int c = 0; c < numCols; c++) {
        nodes.add(new ContinuousVariable("X" + c));
    }

    DataSet full = new ColtDataSet(numRows, nodes);
    RandomUtil random = RandomUtil.getInstance();
    for (int r = 0; r < numRows; r++) {
        for (int c = 0; c < numCols; c++) {
            full.setDouble(r, c, random.nextDouble());
        }
    }

    // Select the third and fifth variables; they become columns 0 and 1 of the subset.
    List<Node> allVars = full.getVariables();
    List<Node> selected = new LinkedList<>();
    selected.add(allVars.get(2));
    selected.add(allVars.get(4));
    DataSet subset = full.subsetColumns(selected);

    double expectedFirst = full.getDoubleData().getColumn(2).get(0);
    double actualFirst = subset.getDoubleData().getColumn(0).get(0);
    assertEquals(expectedFirst, actualFirst, .001);

    double expectedSecond = full.getDoubleData().getColumn(4).get(0);
    double actualSecond = subset.getDoubleData().getColumn(1).get(0);
    assertEquals(expectedSecond, actualSecond, .001);
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) RandomUtil(edu.cmu.tetrad.util.RandomUtil) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) LinkedList(java.util.LinkedList) Test(org.junit.Test)

Example 7 with ColtDataSet

use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.

the class TestColtDataSet method testDiscreteFromScratch.

/**
 * Builds a discrete data set starting from zero rows and zero variables, adds two discrete
 * variables and three values for each, and checks that a copy made through the copy constructor
 * equals the original and that a stored value can be read back.
 */
@Test
public void testDiscreteFromScratch() {
    // Typed empty list rather than the raw Collections.EMPTY_LIST, which needs an unchecked conversion.
    DataSet dataSet = new ColtDataSet(0, Collections.<Node>emptyList());
    DiscreteVariable x1 = new DiscreteVariable("X1");
    dataSet.addVariable(x1);
    dataSet.setInt(0, 0, 0);
    dataSet.setInt(1, 0, 2);
    dataSet.setInt(2, 0, 1);
    DiscreteVariable x2 = new DiscreteVariable("X2");
    dataSet.addVariable(x2);
    dataSet.setInt(0, 1, 0);
    dataSet.setInt(1, 1, 2);
    dataSet.setInt(2, 1, 1);
    ColtDataSet _dataSet = new ColtDataSet((ColtDataSet) dataSet);
    assertEquals(dataSet, _dataSet);
    // JUnit's contract is assertEquals(expected, actual); the expected literal goes first so a
    // failure message reports the values the right way round.
    assertEquals(2, dataSet.getInt(1, 1));
}
Also used : DiscreteVariable(edu.cmu.tetrad.data.DiscreteVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Test(org.junit.Test)

Example 8 with ColtDataSet

use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.

the class TestColtDataSet method testRemoveRows.

/**
 * Fills a continuous data set with random values, removes rows 1 and 2, and checks that the row
 * count dropped by two and that the value formerly at row 3 is now found at row 1.
 */
@Test
public void testRemoveRows() {
    final int rowCount = 10;
    final int colCount = 5;

    List<Node> nodes = new LinkedList<>();
    for (int c = 0; c < colCount; c++) {
        nodes.add(new ContinuousVariable("X" + c));
    }

    DataSet dataSet = new ColtDataSet(rowCount, nodes);
    RandomUtil random = RandomUtil.getInstance();
    for (int r = 0; r < rowCount; r++) {
        for (int c = 0; c < colCount; c++) {
            dataSet.setDouble(r, c, random.nextDouble());
        }
    }

    // Capture the state that the removal is expected to shift.
    int rowsBefore = dataSet.getNumRows();
    double formerRowThree = dataSet.getDouble(3, 0);

    int[] doomedRows = {1, 2};
    dataSet.removeRows(doomedRows);

    assertEquals(rowsBefore - 2, dataSet.getNumRows());
    assertEquals(formerRowThree, dataSet.getDouble(1, 0), 0.001);
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) RandomUtil(edu.cmu.tetrad.util.RandomUtil) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) LinkedList(java.util.LinkedList) Test(org.junit.Test)

Example 9 with ColtDataSet

use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.

the class IndTestHsic method isIndependent.

/**
 * Determines whether variable x is independent of variable y given a list of conditioning variables z,
 * by comparing an HSIC statistic computed on the data against the statistics obtained from
 * {@code this.perms} shuffled copies of the data.
 *
 * <p>Side effects: the observed statistic is stored in {@code this.hsic} and the resulting p-value in
 * {@code this.pValue}. The data set is cast to {@code ColtDataSet} for copying, so the call fails with a
 * {@code ClassCastException} if {@code this.dataSet} is of another type.
 *
 * @param y the first variable being compared; this is the variable whose column is shuffled.
 * @param x the second variable being compared.
 * @param z the list of conditioning variables.
 * @return true iff x _||_ y | z, i.e. the computed p-value is greater than {@code this.alpha}.
 */
public boolean isIndependent(Node y, Node x, List<Node> z) {
    int m = sampleSize();
    // choose kernels using median distance heuristic
    // (presumably what setDefaultBw implements -- confirm against the Kernel implementation)
    Kernel xKernel = new KernelGaussian(1);
    Kernel yKernel = new KernelGaussian(1);
    List<Kernel> zKernel = new ArrayList<>();
    yKernel.setDefaultBw(this.dataSet, y);
    xKernel.setDefaultBw(this.dataSet, x);
    if (!z.isEmpty()) {
        // one kernel per conditioning variable, in the same order as z
        for (int i = 0; i < z.size(); i++) {
            Kernel Zi = new KernelGaussian(1);
            Zi.setDefaultBw(this.dataSet, z.get(i));
            zKernel.add(Zi);
        }
    }
    // construct Gram matrices (Kz stays null when there are no conditioning variables)
    TetradMatrix Ky = null;
    TetradMatrix Kx = null;
    TetradMatrix Kz = null;
    // use incomplete Cholesky to approximate
    if (useIncompleteCholesky > 0) {
        Ky = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(yKernel), this.dataSet, Arrays.asList(y), useIncompleteCholesky);
        Kx = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(xKernel), this.dataSet, Arrays.asList(x), useIncompleteCholesky);
        if (!z.isEmpty()) {
            Kz = KernelUtils.incompleteCholeskyGramMatrix(zKernel, this.dataSet, z, useIncompleteCholesky);
        }
    } else // otherwise compute directly
    {
        Ky = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(yKernel), this.dataSet, Arrays.asList(y));
        Kx = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(xKernel), this.dataSet, Arrays.asList(x));
        if (!z.isEmpty()) {
            Kz = KernelUtils.constructCentralizedGramMatrix(zKernel, this.dataSet, z);
        }
    }
    // get Hilbert-Schmidt dependence measure on the unshuffled data
    if (z.isEmpty()) {
        if (useIncompleteCholesky > 0) {
            this.hsic = empiricalHSICincompleteCholesky(Ky, Kx, m);
        } else {
            this.hsic = empiricalHSIC(Ky, Kx, m);
        }
    } else {
        if (useIncompleteCholesky > 0) {
            this.hsic = empiricalHSICincompleteCholesky(Ky, Kx, Kz, m);
        } else {
            this.hsic = empiricalHSIC(Ky, Kx, Kz, m);
        }
    }
    // shuffle data to approximate the null distribution; nullapprox[i] holds the statistic of the i-th shuffle
    double[] nullapprox = new double[this.perms];
    int[] zind = null;
    int ycol = this.dataSet.getColumn(y);
    List<List<Integer>> clusterAssign = null;
    if (!z.isEmpty()) {
        // get clusters for z; the argument m / 3 uses integer division
        KMeans kmeans = KMeans.randomClusters((m / 3));
        // column index in the data set of each conditioning variable
        zind = new int[z.size()];
        for (int j = 0; j < z.size(); j++) {
            zind[j] = dataSet.getColumn(z.get(j));
        }
        kmeans.cluster(dataSet.subsetColumns(z).getDoubleData());
        clusterAssign = kmeans.getClusters();
    }
    for (int i = 0; i < this.perms; i++) {
        // work on a fresh copy so the original data set is never modified
        DataSet shuffleData = new ColtDataSet((ColtDataSet) dataSet);
        // shuffle data
        if (z.isEmpty()) {
            // unconditional case: permute the y column over all m rows
            List<Integer> indicesList = new ArrayList<>();
            for (int j = 0; j < m; j++) {
                indicesList.add(j);
            }
            Collections.shuffle(indicesList);
            for (int j = 0; j < m; j++) {
                double shuffleVal = dataSet.getDouble(indicesList.get(j), ycol);
                shuffleData.setDouble(j, ycol, shuffleVal);
            }
        } else {
            // shuffle data within clusters
            for (int j = 0; j < clusterAssign.size(); j++) {
                // shuffled copy of the cluster's entries; the k-th original entry is written to the k-th shuffled entry
                List<Integer> shuffleCluster = new ArrayList<>(clusterAssign.get(j));
                Collections.shuffle(shuffleCluster);
                for (int k = 0; k < shuffleCluster.size(); k++) {
                    // first swap y;
                    double swapVal = dataSet.getDouble(clusterAssign.get(j).get(k), ycol);
                    shuffleData.setDouble(shuffleCluster.get(k), ycol, swapVal);
                    // now swap z
                    for (int zi = 0; zi < z.size(); zi++) {
                        swapVal = dataSet.getDouble(clusterAssign.get(j).get(k), zind[zi]);
                        shuffleData.setDouble(shuffleCluster.get(k), zind[zi], swapVal);
                    }
                }
            }
        }
        // reset bandwidths for the shuffled y and z columns; the x column is not written above, so Kx is reused as is
        yKernel.setDefaultBw(shuffleData, y);
        for (int j = 0; j < z.size(); j++) {
            zKernel.get(j).setDefaultBw(shuffleData, z.get(j));
        }
        // Gram matrices
        TetradMatrix Kyn = null;
        if (useIncompleteCholesky > 0) {
            Kyn = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(yKernel), shuffleData, Arrays.asList(y), useIncompleteCholesky);
        } else {
            Kyn = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(yKernel), shuffleData, Arrays.asList(y));
        }
        // NOTE(review): Kzn is computed from the shuffled data but never read below; the conditional
        // statistic is evaluated with Kz from the unshuffled data. Confirm whether that is intended.
        TetradMatrix Kzn = null;
        if (!z.isEmpty()) {
            if (useIncompleteCholesky > 0) {
                Kzn = KernelUtils.incompleteCholeskyGramMatrix(zKernel, shuffleData, z, useIncompleteCholesky);
            } else {
                Kzn = KernelUtils.constructCentralizedGramMatrix(zKernel, shuffleData, z);
            }
        }
        // HSIC
        if (z.isEmpty()) {
            if (useIncompleteCholesky > 0) {
                nullapprox[i] = empiricalHSICincompleteCholesky(Kyn, Kx, m);
            } else {
                nullapprox[i] = empiricalHSIC(Kyn, Kx, m);
            }
        } else {
            if (useIncompleteCholesky > 0) {
                nullapprox[i] = empiricalHSICincompleteCholesky(Kyn, Kx, Kz, m);
            } else {
                nullapprox[i] = empiricalHSIC(Kyn, Kx, Kz, m);
            }
        }
    }
    // permutation test to get p-value: fraction of shuffled statistics that exceed the observed one
    double evalCdf = 0.0;
    for (int i = 0; i < this.perms; i++) {
        if (nullapprox[i] <= this.hsic) {
            evalCdf += 1.0;
        }
    }
    evalCdf /= (double) this.perms;
    this.pValue = 1.0 - evalCdf;
    // reject if pvalue <= alpha
    // NOTE(review): the dependence message is logged unconditionally, while the independence message
    // below is only logged when verbose is set -- confirm whether the asymmetry is intended.
    if (this.pValue <= this.alpha) {
        TetradLogger.getInstance().log("dependencies", SearchLogUtils.dependenceFactMsg(x, y, z, getPValue()));
        return false;
    }
    if (verbose) {
        TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFactMsg(x, y, z, getPValue()));
    }
    return true;
}
Also used : KMeans(edu.cmu.tetrad.cluster.KMeans) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) ArrayList(java.util.ArrayList) TetradMatrix(edu.cmu.tetrad.util.TetradMatrix) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) ArrayList(java.util.ArrayList) List(java.util.List) Kernel(edu.cmu.tetrad.search.kernel.Kernel) KernelGaussian(edu.cmu.tetrad.search.kernel.KernelGaussian)

Example 10 with ColtDataSet

use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.

the class SemEstimatorWrapper method serializableInstance.

// public SemEstimatorWrapper(DataWrapper dataWrapper,
// SemPmWrapper semPmWrapper,
// SemImWrapper semImWrapper,
// Parameters params) {
// if (dataWrapper == null) {
// throw new NullPointerException();
// }
// 
// if (semPmWrapper == null) {
// throw new NullPointerException();
// }
// 
// if (semImWrapper == null) {
// throw new NullPointerException();
// }
// 
// DataSet dataSet =
// (DataSet) dataWrapper.getSelectedDataModel();
// SemPm semPm = semPmWrapper.getSemPm();
// SemIm semIm = semImWrapper.getSemIm();
// 
// this.semEstimator = new SemEstimator(dataSet, semPm, getOptimizer());
// if (!degreesOfFreedomCheck(semPm)) return;
// this.semEstimator.setTrueSemIm(semIm);
// this.semEstimator.setNumRestarts(getParams().getInt("numRestarts", 1));
// this.semEstimator.estimate();
// 
// this.params = params;
// 
// log();
// }
/**
 * Generates a simple exemplar of this class to test serialization: a wrapper built from a
 * ten-row data set over one continuous variable "X" filled with random values, a SEM
 * parametric model over a one-node DAG containing that variable, and default parameters.
 *
 * @see TetradSerializableUtils
 */
public static SemEstimatorWrapper serializableInstance() {
    ContinuousVariable node = new ContinuousVariable("X");
    List<Node> nodes = new LinkedList<>();
    nodes.add(node);

    DataSet data = new ColtDataSet(10, nodes);
    for (int row = 0; row < data.getNumRows(); row++) {
        for (int col = 0; col < data.getNumColumns(); col++) {
            data.setDouble(row, col, RandomUtil.getInstance().nextDouble());
        }
    }

    // The graph holds the very same variable object that backs the data column.
    Dag graph = new Dag();
    graph.addNode(node);

    return new SemEstimatorWrapper(data, new SemPm(graph), new Parameters());
}
Also used : ContinuousVariable(edu.cmu.tetrad.data.ContinuousVariable) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) Parameters(edu.cmu.tetrad.util.Parameters) ColtDataSet(edu.cmu.tetrad.data.ColtDataSet) DataSet(edu.cmu.tetrad.data.DataSet) Node(edu.cmu.tetrad.graph.Node) SemPm(edu.cmu.tetrad.sem.SemPm) Dag(edu.cmu.tetrad.graph.Dag) LinkedList(java.util.LinkedList)

Aggregations

ColtDataSet (edu.cmu.tetrad.data.ColtDataSet)28 DataSet (edu.cmu.tetrad.data.DataSet)24 Node (edu.cmu.tetrad.graph.Node)21 ContinuousVariable (edu.cmu.tetrad.data.ContinuousVariable)17 LinkedList (java.util.LinkedList)13 Test (org.junit.Test)12 DiscreteVariable (edu.cmu.tetrad.data.DiscreteVariable)9 RandomUtil (edu.cmu.tetrad.util.RandomUtil)6 Parameters (edu.cmu.tetrad.util.Parameters)3 ArrayList (java.util.ArrayList)3 KernelGaussian (edu.cmu.tetrad.search.kernel.KernelGaussian)2 DecimalFormat (java.text.DecimalFormat)2 ParseException (java.text.ParseException)2 KMeans (edu.cmu.tetrad.cluster.KMeans)1 CellTable (edu.cmu.tetrad.data.CellTable)1 DataModelList (edu.cmu.tetrad.data.DataModelList)1 TimeSeriesData (edu.cmu.tetrad.data.TimeSeriesData)1 Dag (edu.cmu.tetrad.graph.Dag)1 RegressionDataset (edu.cmu.tetrad.regression.RegressionDataset)1 RegressionResult (edu.cmu.tetrad.regression.RegressionResult)1