use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.
the class TestColtDataSet method testContinuous.
/**
 * Fills a 10 x 5 continuous data set with random values, takes the subset made of
 * the variables at indices 2 and 4, and checks that the first entry of each selected
 * column is carried over into columns 0 and 1 of the subset.
 */
@Test
public final void testContinuous() {
    final int rowCount = 10;
    final int columnCount = 5;

    List<Node> nodes = new LinkedList<>();
    int index = 0;
    while (index < columnCount) {
        nodes.add(new ContinuousVariable("X" + index));
        index++;
    }

    DataSet original = new ColtDataSet(rowCount, nodes);
    RandomUtil random = RandomUtil.getInstance();

    // Row-major fill, one random draw per cell.
    for (int row = 0; row < rowCount; row++) {
        for (int col = 0; col < columnCount; col++) {
            double value = random.nextDouble();
            original.setDouble(row, col, value);
        }
    }

    List<Node> allVariables = original.getVariables();
    List<Node> selected = new LinkedList<>();
    selected.add(allVariables.get(2));
    selected.add(allVariables.get(4));

    DataSet subset = original.subsetColumns(selected);

    // Source column 2 must have become subset column 0.
    double expectedFirst = original.getDoubleData().getColumn(2).get(0);
    double actualFirst = subset.getDoubleData().getColumn(0).get(0);
    assertEquals(expectedFirst, actualFirst, .001);

    // Source column 4 must have become subset column 1.
    double expectedSecond = original.getDoubleData().getColumn(4).get(0);
    double actualSecond = subset.getDoubleData().getColumn(1).get(0);
    assertEquals(expectedSecond, actualSecond, .001);
}
use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.
the class TestColtDataSet method testDiscreteFromScratch.
/**
 * Starts from a data set created with zero rows and no variables, adds two discrete
 * variables one at a time, writes three values into each new column, and then checks
 * that (a) a copy made with the ColtDataSet copy constructor is equal to the source
 * and (b) the value written at row 1, column 1 can be read back.
 */
@Test
public void testDiscreteFromScratch() {
    DataSet dataSet = new ColtDataSet(0, Collections.EMPTY_LIST);

    // First column: added to the empty data set, then populated row by row.
    DiscreteVariable x1 = new DiscreteVariable("X1");
    dataSet.addVariable(x1);
    dataSet.setInt(0, 0, 0);
    dataSet.setInt(1, 0, 2);
    dataSet.setInt(2, 0, 1);

    // Second column: same values, written at column index 1.
    DiscreteVariable x2 = new DiscreteVariable("X2");
    dataSet.addVariable(x2);
    dataSet.setInt(0, 1, 0);
    dataSet.setInt(1, 1, 2);
    dataSet.setInt(2, 1, 1);

    ColtDataSet _dataSet = new ColtDataSet((ColtDataSet) dataSet);
    assertEquals(dataSet, _dataSet);

    // JUnit's contract is assertEquals(expected, actual); the literal is the expected
    // value, so it goes first to keep the failure message the right way round.
    assertEquals(2, dataSet.getInt(1, 1));
}
use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.
the class TestColtDataSet method testRemoveRows.
/**
 * Fills a 10 x 5 continuous data set with random values, removes rows 1 and 2, and
 * checks that the row count dropped by two and that the value previously stored at
 * row 3, column 0 is now found at row 1, column 0.
 */
@Test
public void testRemoveRows() {
    final int rowCount = 10;
    final int columnCount = 5;

    List<Node> nodes = new LinkedList<>();
    int index = 0;
    while (index < columnCount) {
        nodes.add(new ContinuousVariable("X" + index));
        index++;
    }

    DataSet data = new ColtDataSet(rowCount, nodes);
    RandomUtil random = RandomUtil.getInstance();

    // Row-major fill, one random draw per cell.
    for (int row = 0; row < rowCount; row++) {
        for (int col = 0; col < columnCount; col++) {
            double value = random.nextDouble();
            data.setDouble(row, col, value);
        }
    }

    // Snapshot taken before the removal so it can be compared afterwards.
    int rowsBefore = data.getNumRows();
    double valueAtRowThree = data.getDouble(3, 0);

    int[] rowsToRemove = {1, 2};
    data.removeRows(rowsToRemove);

    int expectedRows = rowsBefore - 2;
    assertEquals(expectedRows, data.getNumRows());
    assertEquals(valueAtRowThree, data.getDouble(1, 0), 0.001);
}
use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.
the class IndTestHsic method isIndependent.
/**
* Determines whether variable x is independent of variable y given a list of conditioning variables z,
* by comparing the empirical HSIC statistic of the data against a null distribution obtained from
* {@code this.perms} shuffled copies of the data set.
* <p>
* Side effects: the observed statistic is stored in {@code this.hsic} and the resulting p-value in
* {@code this.pValue}; the outcome is reported through {@link TetradLogger}.
* <p>
* Note that the parameter order of this method is (y, x, z).
*
* @param y the first argument; this is the variable whose column is shuffled to build the null distribution.
* @param x the second argument; its Gram matrix is computed once and reused for every permutation.
* @param z the list of conditioning variables; may be empty, in which case the unconditional statistic is used.
* @return true iff x _||_ y | z, i.e. false when the p-value is less than or equal to {@code this.alpha}.
*/
public boolean isIndependent(Node y, Node x, List<Node> z) {
int m = sampleSize();
// choose kernels using median distance heuristic
Kernel xKernel = new KernelGaussian(1);
Kernel yKernel = new KernelGaussian(1);
List<Kernel> zKernel = new ArrayList<>();
yKernel.setDefaultBw(this.dataSet, y);
xKernel.setDefaultBw(this.dataSet, x);
// one Gaussian kernel per conditioning variable, in the same order as z
if (!z.isEmpty()) {
for (int i = 0; i < z.size(); i++) {
Kernel Zi = new KernelGaussian(1);
Zi.setDefaultBw(this.dataSet, z.get(i));
zKernel.add(Zi);
}
}
// construct Gram matrices
TetradMatrix Ky = null;
TetradMatrix Kx = null;
// Kz stays null when z is empty; it is only read in the branches guarded by !z.isEmpty()
TetradMatrix Kz = null;
// use incomplete Cholesky to approximate
if (useIncompleteCholesky > 0) {
Ky = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(yKernel), this.dataSet, Arrays.asList(y), useIncompleteCholesky);
Kx = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(xKernel), this.dataSet, Arrays.asList(x), useIncompleteCholesky);
if (!z.isEmpty()) {
Kz = KernelUtils.incompleteCholeskyGramMatrix(zKernel, this.dataSet, z, useIncompleteCholesky);
}
} else // otherwise compute directly
{
Ky = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(yKernel), this.dataSet, Arrays.asList(y));
Kx = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(xKernel), this.dataSet, Arrays.asList(x));
if (!z.isEmpty()) {
Kz = KernelUtils.constructCentralizedGramMatrix(zKernel, this.dataSet, z);
}
}
// get Hilbert-Schmidt dependence measure
// (four variants: with/without conditioning set, with/without incomplete Cholesky)
if (z.isEmpty()) {
if (useIncompleteCholesky > 0) {
this.hsic = empiricalHSICincompleteCholesky(Ky, Kx, m);
} else {
this.hsic = empiricalHSIC(Ky, Kx, m);
}
} else {
if (useIncompleteCholesky > 0) {
this.hsic = empiricalHSICincompleteCholesky(Ky, Kx, Kz, m);
} else {
this.hsic = empiricalHSIC(Ky, Kx, Kz, m);
}
}
// shuffle data to approximate the null distribution
// nullapprox[i] holds the statistic computed on the i-th shuffled copy
double[] nullapprox = new double[this.perms];
// column indices of the conditioning variables; only populated when z is non-empty
int[] zind = null;
int ycol = this.dataSet.getColumn(y);
// row indices grouped by cluster; only populated when z is non-empty
List<List<Integer>> clusterAssign = null;
if (!z.isEmpty()) {
// get clusters for z
// NOTE(review): the argument m / 3 is presumably the number of clusters requested - confirm against KMeans
KMeans kmeans = KMeans.randomClusters((m / 3));
zind = new int[z.size()];
for (int j = 0; j < z.size(); j++) {
zind[j] = dataSet.getColumn(z.get(j));
}
kmeans.cluster(dataSet.subsetColumns(z).getDoubleData());
clusterAssign = kmeans.getClusters();
}
for (int i = 0; i < this.perms; i++) {
// fresh copy per permutation so the original data set is never modified;
// the cast fails with a ClassCastException if dataSet is not a ColtDataSet
DataSet shuffleData = new ColtDataSet((ColtDataSet) dataSet);
// shuffle data
if (z.isEmpty()) {
// unconditional case: permute the y column across all m rows
List<Integer> indicesList = new ArrayList<>();
for (int j = 0; j < m; j++) {
indicesList.add(j);
}
Collections.shuffle(indicesList);
for (int j = 0; j < m; j++) {
double shuffleVal = dataSet.getDouble(indicesList.get(j), ycol);
shuffleData.setDouble(j, ycol, shuffleVal);
}
} else {
// shuffle data within clusters
for (int j = 0; j < clusterAssign.size(); j++) {
// shuffled copy of the cluster's row indices; the k-th original row is written to the k-th shuffled row
List<Integer> shuffleCluster = new ArrayList<>(clusterAssign.get(j));
Collections.shuffle(shuffleCluster);
for (int k = 0; k < shuffleCluster.size(); k++) {
// first swap y;
double swapVal = dataSet.getDouble(clusterAssign.get(j).get(k), ycol);
shuffleData.setDouble(shuffleCluster.get(k), ycol, swapVal);
// now swap z
// (the z values travel with y, using the same source row and destination row)
for (int zi = 0; zi < z.size(); zi++) {
swapVal = dataSet.getDouble(clusterAssign.get(j).get(k), zind[zi]);
shuffleData.setDouble(shuffleCluster.get(k), zind[zi], swapVal);
}
}
}
}
// reset bandwidths
// (yKernel and the z kernels are re-tuned on the shuffled copy; xKernel is left untouched)
yKernel.setDefaultBw(shuffleData, y);
for (int j = 0; j < z.size(); j++) {
zKernel.get(j).setDefaultBw(shuffleData, z.get(j));
}
// Gram matrices
TetradMatrix Kyn = null;
if (useIncompleteCholesky > 0) {
Kyn = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(yKernel), shuffleData, Arrays.asList(y), useIncompleteCholesky);
} else {
Kyn = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(yKernel), shuffleData, Arrays.asList(y));
}
// NOTE(review): Kzn is computed here but never read below - the conditional statistic
// is evaluated with the unshuffled Kz. Confirm whether Kzn was meant to be passed instead.
TetradMatrix Kzn = null;
if (!z.isEmpty()) {
if (useIncompleteCholesky > 0) {
Kzn = KernelUtils.incompleteCholeskyGramMatrix(zKernel, shuffleData, z, useIncompleteCholesky);
} else {
Kzn = KernelUtils.constructCentralizedGramMatrix(zKernel, shuffleData, z);
}
}
// HSIC
// (statistic of the shuffled y against the original x, and the original Kz when conditioning)
if (z.isEmpty()) {
if (useIncompleteCholesky > 0) {
nullapprox[i] = empiricalHSICincompleteCholesky(Kyn, Kx, m);
} else {
nullapprox[i] = empiricalHSIC(Kyn, Kx, m);
}
} else {
if (useIncompleteCholesky > 0) {
nullapprox[i] = empiricalHSICincompleteCholesky(Kyn, Kx, Kz, m);
} else {
nullapprox[i] = empiricalHSIC(Kyn, Kx, Kz, m);
}
}
}
// permutation test to get p-value
// evalCdf = fraction of null statistics that do not exceed the observed one
double evalCdf = 0.0;
for (int i = 0; i < this.perms; i++) {
if (nullapprox[i] <= this.hsic) {
evalCdf += 1.0;
}
}
// NOTE(review): when this.perms is 0 this is 0.0 / 0.0, so pValue becomes NaN and the
// comparison with alpha below is false - confirm perms is always positive.
evalCdf /= (double) this.perms;
this.pValue = 1.0 - evalCdf;
// reject if pvalue <= alpha
// NOTE(review): the dependence message is logged unconditionally, while the independence
// message below is only logged when verbose is set - confirm this asymmetry is intended.
if (this.pValue <= this.alpha) {
TetradLogger.getInstance().log("dependencies", SearchLogUtils.dependenceFactMsg(x, y, z, getPValue()));
return false;
}
if (verbose) {
TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFactMsg(x, y, z, getPValue()));
}
return true;
}
use of edu.cmu.tetrad.data.ColtDataSet in project tetrad by cmu-phil.
the class SemEstimatorWrapper method serializableInstance.
// public SemEstimatorWrapper(DataWrapper dataWrapper,
// SemPmWrapper semPmWrapper,
// SemImWrapper semImWrapper,
// Parameters params) {
// if (dataWrapper == null) {
// throw new NullPointerException();
// }
//
// if (semPmWrapper == null) {
// throw new NullPointerException();
// }
//
// if (semImWrapper == null) {
// throw new NullPointerException();
// }
//
// DataSet dataSet =
// (DataSet) dataWrapper.getSelectedDataModel();
// SemPm semPm = semPmWrapper.getSemPm();
// SemIm semIm = semImWrapper.getSemIm();
//
// this.semEstimator = new SemEstimator(dataSet, semPm, getOptimizer());
// if (!degreesOfFreedomCheck(semPm)) return;
// this.semEstimator.setTrueSemIm(semIm);
// this.semEstimator.setNumRestarts(getParams().getInt("numRestarts", 1));
// this.semEstimator.estimate();
//
// this.params = params;
//
// log();
// }
/**
 * Builds a minimal instance of this class for serialization tests: a 10-row data set
 * over a single continuous variable "X" filled with random values, a one-node DAG over
 * the same variable, the SemPm derived from that DAG, and default parameters.
 *
 * @return a simple exemplar wrapper built from the data set, the SEM PM and the parameters.
 * @see TetradSerializableUtils
 */
public static SemEstimatorWrapper serializableInstance() {
    ContinuousVariable x = new ContinuousVariable("X");

    List<Node> nodes = new LinkedList<>();
    nodes.add(x);

    DataSet data = new ColtDataSet(10, nodes);

    // Fill every cell with a fresh random draw, row by row.
    int row = 0;
    while (row < data.getNumRows()) {
        int col = 0;
        while (col < data.getNumColumns()) {
            double value = RandomUtil.getInstance().nextDouble();
            data.setDouble(row, col, value);
            col++;
        }
        row++;
    }

    // The graph contains only the node that backs the single data column.
    Dag graph = new Dag();
    graph.addNode(x);

    SemPm semPm = new SemPm(graph);
    Parameters parameters = new Parameters();

    return new SemEstimatorWrapper(data, semPm, parameters);
}
Aggregations