use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class RSkew method search.
/**
 * Orients the edges of a graph with the RSkew rule, either directly or through a bootstrap wrapper.
 * <p>
 * When the "bootstrapSampleSize" parameter is less than 1, the wrapped {@code algorithm} is run on the
 * data, its result is stored as the initial graph, and that graph is oriented by {@link Lofs2} using
 * {@code Lofs2.Rule.RSkew} on the continuous version of the data. Otherwise a fresh {@code RSkew} around
 * the same {@code algorithm} is handed to a {@link GeneralBootstrapTest}, whose result is returned.
 *
 * @param dataSet    the data to search over; it is cast to {@link DataSet} on the bootstrap path.
 * @param parameters reads "bootstrapSampleSize", "bootstrapEnsemble" (0 = Preserved, 2 = Majority,
 *                   anything else = Highest; defaults to 1) and "verbose".
 * @return the oriented graph, or the graph produced by the bootstrap search.
 * @throws IllegalArgumentException if the wrapped algorithm returns a null graph on the direct path.
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    // Bootstrap path: taken whenever a positive bootstrap sample size is configured.
    if (parameters.getInt("bootstrapSampleSize") >= 1) {
        RSkew bootstrapAlgorithm = new RSkew(algorithm);
        if (initialGraph != null) {
            bootstrapAlgorithm.setInitialGraph(initialGraph);
        }

        DataSet sample = (DataSet) dataSet;
        GeneralBootstrapTest bootstrap =
                new GeneralBootstrapTest(sample, bootstrapAlgorithm, parameters.getInt("bootstrapSampleSize"));

        // Map the integer code onto an ensemble strategy; unknown codes fall back to Highest.
        int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        BootstrapEdgeEnsemble ensemble;
        if (ensembleCode == 0) {
            ensemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            ensemble = BootstrapEdgeEnsemble.Majority;
        } else {
            ensemble = BootstrapEdgeEnsemble.Highest;
        }

        bootstrap.setEdgeEnsemble(ensemble);
        bootstrap.setParameters(parameters);
        bootstrap.setVerbose(parameters.getBoolean("verbose"));
        return bootstrap.search();
    }

    // Direct path: obtain the graph to orient from the wrapped algorithm.
    Graph source = algorithm.search(dataSet, parameters);
    if (source == null) {
        throw new IllegalArgumentException("This RSkew algorithm needs both data and a graph source as inputs; it \n" + "will orient the edges in the input graph using the data");
    }
    initialGraph = source;

    List<DataSet> continuousData = new ArrayList<>();
    continuousData.add(DataUtils.getContinuousDataSet(dataSet));

    Lofs2 orienter = new Lofs2(initialGraph, continuousData);
    orienter.setRule(Lofs2.Rule.RSkew);
    return orienter.orient();
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class Skew method search.
/**
 * Orients the edges of a graph with the Skew rule, either directly or through a bootstrap wrapper.
 * <p>
 * When the "bootstrapSampleSize" parameter is less than 1, the wrapped {@code algorithm} is run on the
 * data, its result is stored as the initial graph, and that graph is oriented by {@link Lofs2} using
 * {@code Lofs2.Rule.Skew} on the continuous version of the data. Otherwise a fresh {@code Skew} around
 * the same {@code algorithm} is handed to a {@link GeneralBootstrapTest}, whose result is returned.
 *
 * @param dataSet    the data to search over; it is cast to {@link DataSet} on the bootstrap path.
 * @param parameters reads "bootstrapSampleSize", "bootstrapEnsemble" (0 = Preserved, 2 = Majority,
 *                   anything else = Highest; defaults to 1) and "verbose".
 * @return the oriented graph, or the graph produced by the bootstrap search.
 * @throws IllegalArgumentException if the wrapped algorithm returns a null graph on the direct path.
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    // Bootstrap path: taken whenever a positive bootstrap sample size is configured.
    if (parameters.getInt("bootstrapSampleSize") >= 1) {
        Skew bootstrapAlgorithm = new Skew(algorithm);
        if (initialGraph != null) {
            bootstrapAlgorithm.setInitialGraph(initialGraph);
        }

        DataSet sample = (DataSet) dataSet;
        GeneralBootstrapTest bootstrap =
                new GeneralBootstrapTest(sample, bootstrapAlgorithm, parameters.getInt("bootstrapSampleSize"));

        // Map the integer code onto an ensemble strategy; unknown codes fall back to Highest.
        int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        BootstrapEdgeEnsemble ensemble;
        if (ensembleCode == 0) {
            ensemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            ensemble = BootstrapEdgeEnsemble.Majority;
        } else {
            ensemble = BootstrapEdgeEnsemble.Highest;
        }

        bootstrap.setEdgeEnsemble(ensemble);
        bootstrap.setParameters(parameters);
        bootstrap.setVerbose(parameters.getBoolean("verbose"));
        return bootstrap.search();
    }

    // Direct path: obtain the graph to orient from the wrapped algorithm.
    Graph source = algorithm.search(dataSet, parameters);
    if (source == null) {
        throw new IllegalArgumentException("This Skew algorithm needs both data and a graph source as inputs; it \n" + "will orient the edges in the input graph using the data");
    }
    initialGraph = source;

    List<DataSet> continuousData = new ArrayList<>();
    continuousData.add(DataUtils.getContinuousDataSet(dataSet));

    Lofs2 orienter = new Lofs2(initialGraph, continuousData);
    orienter.setRule(Lofs2.Rule.Skew);
    return orienter.orient();
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class Tanh method search.
/**
 * Orients the edges of a graph with the Tanh rule, either directly or through a bootstrap wrapper.
 * <p>
 * When the "bootstrapSampleSize" parameter is less than 1, the wrapped {@code algorithm} is run on the
 * data, its result is stored as the initial graph, and that graph is oriented by {@link Lofs2} using
 * {@code Lofs2.Rule.Tanh} on the continuous version of the data. Otherwise a fresh {@code Tanh} around
 * the same {@code algorithm} is handed to a {@link GeneralBootstrapTest}, whose result is returned.
 *
 * @param dataSet    the data to search over; it is cast to {@link DataSet} on the bootstrap path.
 * @param parameters reads "bootstrapSampleSize", "bootstrapEnsemble" (0 = Preserved, 2 = Majority,
 *                   anything else = Highest; defaults to 1) and "verbose".
 * @return the oriented graph, or the graph produced by the bootstrap search.
 * @throws IllegalArgumentException if the wrapped algorithm returns a null graph on the direct path.
 */
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
    // Bootstrap path: taken whenever a positive bootstrap sample size is configured.
    if (parameters.getInt("bootstrapSampleSize") >= 1) {
        Tanh bootstrapAlgorithm = new Tanh(algorithm);
        if (initialGraph != null) {
            bootstrapAlgorithm.setInitialGraph(initialGraph);
        }

        DataSet sample = (DataSet) dataSet;
        GeneralBootstrapTest bootstrap =
                new GeneralBootstrapTest(sample, bootstrapAlgorithm, parameters.getInt("bootstrapSampleSize"));

        // Map the integer code onto an ensemble strategy; unknown codes fall back to Highest.
        int ensembleCode = parameters.getInt("bootstrapEnsemble", 1);
        BootstrapEdgeEnsemble ensemble;
        if (ensembleCode == 0) {
            ensemble = BootstrapEdgeEnsemble.Preserved;
        } else if (ensembleCode == 2) {
            ensemble = BootstrapEdgeEnsemble.Majority;
        } else {
            ensemble = BootstrapEdgeEnsemble.Highest;
        }

        bootstrap.setEdgeEnsemble(ensemble);
        bootstrap.setParameters(parameters);
        bootstrap.setVerbose(parameters.getBoolean("verbose"));
        return bootstrap.search();
    }

    // Direct path: obtain the graph to orient from the wrapped algorithm.
    Graph source = algorithm.search(dataSet, parameters);
    if (source == null) {
        throw new IllegalArgumentException("This Tanh algorithm needs both data and a graph source as inputs; it \n" + "will orient the edges in the input graph using the data");
    }
    initialGraph = source;

    List<DataSet> continuousData = new ArrayList<>();
    continuousData.add(DataUtils.getContinuousDataSet(dataSet));

    Lofs2 orienter = new Lofs2(initialGraph, continuousData);
    orienter.setRule(Lofs2.Rule.Tanh);
    return orienter.orient();
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class IndTestDirichletScore method indTestSubset.
// ==========================PUBLIC METHODS=============================//
/**
 * Builds an independence test restricted to a subset of this test's variables.
 *
 * @param vars the variables to keep; must be non-empty and every entry must be one of this test's
 *             variables.
 * @return a new {@code IndTestDirichletScore} over the matching columns of the data set, created with
 * this test's sample prior and structure prior.
 * @throws IllegalArgumentException if {@code vars} is empty or contains a variable this test does not
 *                                  know.
 */
public IndependenceTest indTestSubset(List<Node> vars) {
    if (vars.isEmpty()) {
        throw new IllegalArgumentException("Subset may not be empty.");
    }

    // Validate every requested variable before any column index is looked up.
    for (int v = 0; v < vars.size(); v++) {
        boolean known = variables.contains(vars.get(v));
        if (!known) {
            throw new IllegalArgumentException("All vars must be original vars");
        }
    }

    // Translate each variable into its column index, preserving the order of vars.
    int[] columns = new int[vars.size()];
    int next = 0;
    for (Node var : vars) {
        columns[next++] = indexMap.get(var);
    }

    DataSet subset = dataSet.subsetColumns(columns);
    return new IndTestDirichletScore(subset, getSamplePrior(), getStructurePrior());
}
use of edu.cmu.tetrad.data.DataSet in project tetrad by cmu-phil.
the class IndTestHsic method isIndependent.
/**
 * Determines whether variable x is independent of variable y given a list of conditioning variables z,
 * using a permutation test on a Hilbert-Schmidt dependence statistic.
 * <p>
 * The observed statistic is stored in {@code this.hsic}. A null distribution of {@code this.perms}
 * statistics is built by shuffling the y column (within clusters of z when z is non-empty), and the
 * resulting p-value is stored in {@code this.pValue}.
 * <p>
 * Note that the signature takes y first and x second, while the log messages are built with (x, y, z).
 *
 * @param y the first variable being compared; this is the column that gets shuffled.
 * @param x the second variable being compared.
 * @param z the list of conditioning variables; may be empty.
 * @return true iff x _||_ y | z, i.e. iff the permutation p-value is greater than {@code this.alpha}.
 */
public boolean isIndependent(Node y, Node x, List<Node> z) {
int m = sampleSize();
// choose kernels using median distance heuristic
// (the bandwidth passed to the constructor is replaced by setDefaultBw below)
Kernel xKernel = new KernelGaussian(1);
Kernel yKernel = new KernelGaussian(1);
List<Kernel> zKernel = new ArrayList<>();
yKernel.setDefaultBw(this.dataSet, y);
xKernel.setDefaultBw(this.dataSet, x);
// one kernel per conditioning variable, in the same order as z
if (!z.isEmpty()) {
for (int i = 0; i < z.size(); i++) {
Kernel Zi = new KernelGaussian(1);
Zi.setDefaultBw(this.dataSet, z.get(i));
zKernel.add(Zi);
}
}
// construct Gram matrices
TetradMatrix Ky = null;
TetradMatrix Kx = null;
// Kz stays null when z is empty; it is only read on the conditional branches below
TetradMatrix Kz = null;
// use incomplete Cholesky to approximate
if (useIncompleteCholesky > 0) {
Ky = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(yKernel), this.dataSet, Arrays.asList(y), useIncompleteCholesky);
Kx = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(xKernel), this.dataSet, Arrays.asList(x), useIncompleteCholesky);
if (!z.isEmpty()) {
Kz = KernelUtils.incompleteCholeskyGramMatrix(zKernel, this.dataSet, z, useIncompleteCholesky);
}
} else // otherwise compute directly
{
Ky = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(yKernel), this.dataSet, Arrays.asList(y));
Kx = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(xKernel), this.dataSet, Arrays.asList(x));
if (!z.isEmpty()) {
Kz = KernelUtils.constructCentralizedGramMatrix(zKernel, this.dataSet, z);
}
}
// get Hilbert-Schmidt dependence measure (observed statistic, stored on the instance)
if (z.isEmpty()) {
if (useIncompleteCholesky > 0) {
this.hsic = empiricalHSICincompleteCholesky(Ky, Kx, m);
} else {
this.hsic = empiricalHSIC(Ky, Kx, m);
}
} else {
if (useIncompleteCholesky > 0) {
this.hsic = empiricalHSICincompleteCholesky(Ky, Kx, Kz, m);
} else {
this.hsic = empiricalHSIC(Ky, Kx, Kz, m);
}
}
// shuffle data to approximate the null distribution
double[] nullapprox = new double[this.perms];
// column indices of the conditioning variables; only populated when z is non-empty
int[] zind = null;
int ycol = this.dataSet.getColumn(y);
// row indices grouped by cluster; only populated when z is non-empty
List<List<Integer>> clusterAssign = null;
if (!z.isEmpty()) {
// get clusters for z
// NOTE(review): m / 3 is integer division; presumably the number of clusters requested - confirm
KMeans kmeans = KMeans.randomClusters((m / 3));
zind = new int[z.size()];
for (int j = 0; j < z.size(); j++) {
zind[j] = dataSet.getColumn(z.get(j));
}
kmeans.cluster(dataSet.subsetColumns(z).getDoubleData());
clusterAssign = kmeans.getClusters();
}
for (int i = 0; i < this.perms; i++) {
// each permutation writes into a new ColtDataSet built from the original (presumably a copy);
// the cast requires dataSet to actually be a ColtDataSet
DataSet shuffleData = new ColtDataSet((ColtDataSet) dataSet);
// shuffle data
if (z.isEmpty()) {
// unconditional case: permute the y column across all m rows
List<Integer> indicesList = new ArrayList<>();
for (int j = 0; j < m; j++) {
indicesList.add(j);
}
Collections.shuffle(indicesList);
for (int j = 0; j < m; j++) {
double shuffleVal = dataSet.getDouble(indicesList.get(j), ycol);
shuffleData.setDouble(j, ycol, shuffleVal);
}
} else {
// shuffle data within clusters
for (int j = 0; j < clusterAssign.size(); j++) {
// shuffled copy of the cluster's row indices; the k-th original row is written to the k-th shuffled row
List<Integer> shuffleCluster = new ArrayList<>(clusterAssign.get(j));
Collections.shuffle(shuffleCluster);
for (int k = 0; k < shuffleCluster.size(); k++) {
// first swap y;
double swapVal = dataSet.getDouble(clusterAssign.get(j).get(k), ycol);
shuffleData.setDouble(shuffleCluster.get(k), ycol, swapVal);
// now swap z (the z values move together with y to the same target row)
for (int zi = 0; zi < z.size(); zi++) {
swapVal = dataSet.getDouble(clusterAssign.get(j).get(k), zind[zi]);
shuffleData.setDouble(shuffleCluster.get(k), zind[zi], swapVal);
}
}
}
}
// reset bandwidths (this mutates yKernel and the z kernels for the shuffled data)
yKernel.setDefaultBw(shuffleData, y);
for (int j = 0; j < z.size(); j++) {
zKernel.get(j).setDefaultBw(shuffleData, z.get(j));
}
// Gram matrices
TetradMatrix Kyn = null;
if (useIncompleteCholesky > 0) {
Kyn = KernelUtils.incompleteCholeskyGramMatrix(Arrays.asList(yKernel), shuffleData, Arrays.asList(y), useIncompleteCholesky);
} else {
Kyn = KernelUtils.constructCentralizedGramMatrix(Arrays.asList(yKernel), shuffleData, Arrays.asList(y));
}
// NOTE(review): Kzn is computed here but never read; the null statistics below use Kz from the
// unshuffled data. Confirm whether Kzn was meant to be passed instead, or whether this work can be dropped.
TetradMatrix Kzn = null;
if (!z.isEmpty()) {
if (useIncompleteCholesky > 0) {
Kzn = KernelUtils.incompleteCholeskyGramMatrix(zKernel, shuffleData, z, useIncompleteCholesky);
} else {
Kzn = KernelUtils.constructCentralizedGramMatrix(zKernel, shuffleData, z);
}
}
// HSIC of the shuffled y against the original x (and original z when conditioning)
if (z.isEmpty()) {
if (useIncompleteCholesky > 0) {
nullapprox[i] = empiricalHSICincompleteCholesky(Kyn, Kx, m);
} else {
nullapprox[i] = empiricalHSIC(Kyn, Kx, m);
}
} else {
if (useIncompleteCholesky > 0) {
nullapprox[i] = empiricalHSICincompleteCholesky(Kyn, Kx, Kz, m);
} else {
nullapprox[i] = empiricalHSIC(Kyn, Kx, Kz, m);
}
}
}
// permutation test to get p-value:
// evalCdf is the fraction of null statistics that do not exceed the observed one
double evalCdf = 0.0;
for (int i = 0; i < this.perms; i++) {
if (nullapprox[i] <= this.hsic) {
evalCdf += 1.0;
}
}
evalCdf /= (double) this.perms;
this.pValue = 1.0 - evalCdf;
// reject if pvalue <= alpha
// (the dependence message is logged regardless of the verbose flag)
if (this.pValue <= this.alpha) {
TetradLogger.getInstance().log("dependencies", SearchLogUtils.dependenceFactMsg(x, y, z, getPValue()));
return false;
}
// the independence message is only logged when verbose is set
if (verbose) {
TetradLogger.getInstance().log("independencies", SearchLogUtils.independenceFactMsg(x, y, z, getPValue()));
}
return true;
}
Aggregations