Usage of edu.cmu.tetrad.util.TetradMatrix in the tetrad project (cmu-phil).
Class MVPLikelihood, method multipleRegression.
/**
 * Gaussian log-likelihood of {@code Y} under a least-squares regression on the columns of {@code X}.
 * <p>
 * An intercept-only (mean) model is used instead when {@code X} has at least as many columns as
 * rows, or when the least-squares solve throws. The result is
 * {@code -(n/2) * (ln(2*pi) + ln(sigma2) + 1)} with {@code sigma2 = RSS / n}. A perfect fit
 * ({@code sigma2 == 0}) is reported as 0 rather than +infinity.
 *
 * @param Y the response; must have one entry per row of {@code X}
 * @param X the design matrix with n rows
 * @return the maximized Gaussian log-likelihood
 */
private double multipleRegression(TetradVector Y, TetradMatrix X) {
    int n = X.rows();
    TetradVector r;

    if (X.columns() >= n) {
        // Not enough rows to fit every column of X; use the mean model.
        r = interceptOnlyResiduals(Y, n);
    } else {
        try {
            TetradMatrix Xt = X.transpose();
            TetradMatrix XtX = Xt.times(X);
            r = X.times(XtX.inverse().times(Xt.times(Y))).minus(Y);
        } catch (Exception e) {
            // Deliberate fallback: if the normal-equation solve throws (presumably a singular X'X),
            // score the mean model instead.
            r = interceptOnlyResiduals(Y, n);
        }
    }

    // A sum of squares can never be negative, so only the zero case needs special handling.
    double sigma2 = r.dotProduct(r) / n;
    double lik;
    if (sigma2 == 0) {
        lik = 0;
    } else {
        // 2.0, not 2: integer division would truncate n/2 for odd n.
        lik = -(n / 2.0) * (Math.log(2 * Math.PI) + Math.log(sigma2) + 1);
    }

    if (Double.isInfinite(lik) || Double.isNaN(lik)) {
        System.out.println(lik);
    }
    return lik;
}

/**
 * Residuals of the intercept-only model: entry i is {@code mean(Y) - Y[i]}.
 *
 * @param Y the response vector
 * @param n the number of entries of {@code Y}
 * @return the residual vector
 */
private TetradVector interceptOnlyResiduals(TetradVector Y, int n) {
    TetradVector ones = new TetradVector(n);
    for (int i = 0; i < n; i++) ones.set(i, 1);
    return ones.scalarMult(ones.dotProduct(Y) / (double) n).minus(Y);
}
Usage of edu.cmu.tetrad.util.TetradMatrix in the tetrad project (cmu-phil).
Class MVPLikelihood, method approxMultinomialRegression.
/**
 * Approximate multinomial log-likelihood of the indicator matrix {@code Y} given design {@code X}.
 * <p>
 * Fitted values come from a linear least-squares regression of {@code Y} on {@code X}. The model
 * falls back to per-category column means (intercept only) in two cases: {@code Y} has at least as
 * many columns as there are rows, or {@code X} does. It also falls back when the solve throws.
 * <p>
 * Whenever the regression is attempted and {@code X} has more than one column, each fitted row whose
 * smallest entry is below {@code 1/n} is shrunk linearly toward the uniform value {@code 1/d}. After
 * shrinking, that smallest entry equals exactly {@code 1/n}.
 * <p>
 * The result sums {@code ln(P[i] . Y[i])} over rows. For one-hot rows of {@code Y} (as built by
 * getLik) this is the log of the fitted value at the observed category.
 *
 * @param Y n-by-d response matrix
 * @param X n-row design matrix
 * @return the summed log-likelihood
 */
private double approxMultinomialRegression(TetradMatrix Y, TetradMatrix X) {
    final int n = X.rows();
    final int d = Y.columns();
    final boolean canRegress = d < n && X.columns() < n;

    // P stays null when regression is skipped or throws; the mean model fills it in below.
    TetradMatrix P = null;
    if (canRegress) {
        try {
            TetradMatrix Xt = X.transpose();
            P = X.times(Xt.times(X).inverse().times(Xt.times(Y)));
        } catch (Exception e) {
            // Deliberate fallback: handled by the intercept-only estimate below.
        }
    }

    if (P == null) {
        TetradMatrix ones = new TetradMatrix(n, 1);
        for (int row = 0; row < n; row++) ones.set(row, 0, 1);
        P = ones.times(ones.transpose().times(Y).scalarMult(1 / (double) n));
    }

    if (canRegress) {
        final double uniform = 1 / (double) d;
        final double minAllowed = 1 / (double) n;
        for (int row = 0; row < n; row++) {
            double smallest = 1;
            for (int col = 0; col < d; col++) {
                smallest = Math.min(smallest, P.get(row, col));
            }
            if (X.columns() > 1 && smallest < minAllowed) {
                // Linear shrink toward the uniform value that lifts the row minimum to minAllowed.
                double shrink = (minAllowed - uniform) / (smallest - uniform);
                for (int col = 0; col < d; col++) {
                    P.set(row, col, shrink * P.get(row, col) + uniform * (1 - shrink));
                }
            }
        }
    }

    double lik = 0.0;
    for (int row = 0; row < n; row++) {
        lik += Math.log(P.getRow(row).dotProduct(Y.getRow(row)));
    }

    if (Double.isInfinite(lik) || Double.isNaN(lik)) {
        System.out.println(lik);
    }
    return lik;
}
Usage of edu.cmu.tetrad.util.TetradMatrix in the tetrad project (cmu-phil).
Class MVPLikelihood, method getLik.
/**
 * Log-likelihood of the variable at {@code child_index} given the parent variables {@code parents}.
 * <p>
 * Records are grouped into cells by {@code adTree.getCellLeaves(discrete_parents)}. Within every
 * cell of more than one record:
 * <ul>
 *   <li>each continuous parent is standardized using the cell's mean and standard deviation;</li>
 *   <li>the standardized values are expanded into powers 1..degree, plus an intercept column in the
 *       last position. The degree is {@code fDegree}, or {@code floor(ln r)} when {@code fDegree < 1};</li>
 *   <li>a continuous child is scored with multipleRegression, and a discrete child is one-hot
 *       encoded and scored with approxMultinomialRegression.</li>
 * </ul>
 * Cells holding a single record contribute nothing.
 *
 * @param child_index index of the child in {@code variables}
 * @param parents     parent indices. They index {@code discreteVariables} when the child is discrete
 *                    and {@code discretize} is set, and {@code variables} otherwise.
 * @return the sum of the per-cell log-likelihoods
 */
public double getLik(int child_index, int[] parents) {
    double lik = 0;
    Node c = variables.get(child_index);
    List<ContinuousVariable> continuous_parents = new ArrayList<>();
    List<DiscreteVariable> discrete_parents = new ArrayList<>();

    if (c instanceof DiscreteVariable && discretize) {
        // Discrete child with discretization enabled: every parent comes from discreteVariables.
        for (int p : parents) {
            Node parent = discreteVariables.get(p);
            discrete_parents.add((DiscreteVariable) parent);
        }
    } else {
        for (int p : parents) {
            Node parent = variables.get(p);
            if (parent instanceof ContinuousVariable) {
                continuous_parents.add((ContinuousVariable) parent);
            } else {
                discrete_parents.add((DiscreteVariable) parent);
            }
        }
    }

    int p = continuous_parents.size();
    List<List<Integer>> cells = adTree.getCellLeaves(discrete_parents);
    int[] continuousCols = new int[p];
    for (int j = 0; j < p; j++) continuousCols[j] = nodesHash.get(continuous_parents.get(j));

    for (List<Integer> cell : cells) {
        int r = cell.size();
        if (r <= 1) {
            continue;
        }

        // Per-cell mean and standard deviation of each continuous parent. Two passes are used:
        // the one-pass form E[x^2] - E[x]^2 can round to a tiny negative value, and its sqrt is NaN.
        double[] mean = new double[p];
        double[] sd = new double[p];
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < r; j++) {
                mean[i] += continuousData[continuousCols[i]][cell.get(j)];
            }
            mean[i] /= r;
            for (int j = 0; j < r; j++) {
                double dev = continuousData[continuousCols[i]][cell.get(j)] - mean[i];
                sd[i] += dev * dev;
            }
            sd[i] = Math.sqrt(sd[i] / r);
        }

        int degree = fDegree;
        if (fDegree < 1) {
            degree = (int) Math.floor(Math.log(r));
        }

        // Column layout: power (d+1) of parent j goes to column p*d + j; the intercept is the last column.
        TetradMatrix subset = new TetradMatrix(r, p * degree + 1);
        for (int i = 0; i < r; i++) {
            subset.set(i, p * degree, 1);
            for (int j = 0; j < p; j++) {
                // A parent that is constant within the cell has sd == 0. Map it to 0 rather than
                // dividing by zero (which gives NaN/Infinity). The resulting all-zero column is
                // left to the regression helpers, which fall back to the mean model if the solve throws.
                double z = sd[j] == 0
                        ? 0.0
                        : (continuousData[continuousCols[j]][cell.get(i)] - mean[j]) / sd[j];
                for (int d = 0; d < degree; d++) {
                    subset.set(i, p * d + j, Math.pow(z, d + 1));
                }
            }
        }

        if (c instanceof ContinuousVariable) {
            TetradVector target = new TetradVector(r);
            for (int i = 0; i < r; i++) {
                target.set(i, continuousData[child_index][cell.get(i)]);
            }
            lik += multipleRegression(target, subset);
        } else {
            // One-hot encode the child's category for each record in the cell.
            TetradMatrix target = new TetradMatrix(r, ((DiscreteVariable) c).getNumCategories());
            for (int i = 0; i < r; i++) {
                target.set(i, discreteData[child_index][cell.get(i)], 1);
            }
            lik += approxMultinomialRegression(target, subset);
        }
    }
    return lik;
}
Usage of edu.cmu.tetrad.util.TetradMatrix in the tetrad project (cmu-phil).
Class Ling, method zerolessDiagonalPermutations.
/**
 * Pairs every zeroless-diagonal row permutation of {@code ica_W} with the correspondingly permuted,
 * transposed matrix.
 * <p>
 * The permutations are enumerated by nRookColumnAssignments on the transpose of {@code ica_W}.
 * When {@code approximateZeros} is set, pruneEdgesByResampling is first run on the data set's
 * double data, and zero rows/columns are then removed from {@code ica_W} via removeZeroRowsAndCols.
 *
 * @param ica_W            the ICA unmixing matrix
 * @param approximateZeros whether to prune and remove zero rows/columns first
 * @param vars             the variables, passed to removeZeroRowsAndCols
 * @param dataSet          the data, used only when {@code approximateZeros} is set
 * @return one (permutation, permuted-and-transposed W) pair per assignment
 */
private List<PermutationMatrixPair> zerolessDiagonalPermutations(TetradMatrix ica_W, boolean approximateZeros, List<Node> vars, DataSet dataSet) {
    TetradMatrix w = ica_W;
    if (approximateZeros) {
        // NOTE(review): any value returned by pruneEdgesByResampling is not used here -- confirm intended.
        pruneEdgesByResampling(dataSet.getDoubleData());
        w = removeZeroRowsAndCols(w, vars);
    }

    // Column assignments of W^T correspond to row permutations of W.
    TetradMatrix wTransposed = w.transpose();
    List<List<Integer>> assignments = nRookColumnAssignments(wTransposed, makeAllRows(wTransposed.rows()));

    List<PermutationMatrixPair> pairs = new Vector<>();
    for (List<Integer> assignment : assignments) {
        pairs.add(new PermutationMatrixPair(assignment, permuteRows(w, assignment).transpose()));
    }
    return pairs;
}
Usage of edu.cmu.tetrad.util.TetradMatrix in the tetrad project (cmu-phil).
Class Ling, method findCandidateModel.
/**
 * Builds the candidate models derived from the unmixing matrix {@code matrixW}.
 * <p>
 * Candidates come from zerolessDiagonalPermutation applied to {@code matrixW} (with the
 * {@code dataSet} field). For each candidate, the permuted W is normalized with
 * LingUtils.normalizeDiagonal and converted into a Bhat matrix (B~ = I - W~). The candidate is then
 * checked with allEigenvaluesAreSmallerThanOneInModulus, and its graph, stability flag and Bhat
 * are recorded. The stable graphs are logged first, then the unstable ones.
 *
 * @param variables        the model variables
 * @param matrixW          the unmixing matrix
 * @param approximateZeros passed through to zerolessDiagonalPermutation
 * @return the recorded graphs, stability flags and Bhat matrices
 */
private StoredGraphs findCandidateModel(List<Node> variables, TetradMatrix matrixW, boolean approximateZeros) {
    TetradLogger logger = TetradLogger.getInstance();
    StoredGraphs gs = new StoredGraphs();

    System.out.println("Calculating zeroless diagonal permutations...");
    logger.log("lingDetails", "Calculating zeroless diagonal permutations.");
    List<PermutationMatrixPair> candidates = zerolessDiagonalPermutation(matrixW, approximateZeros, variables, dataSet);
    System.out.println("Calculated zeroless diagonal permutations.");

    for (PermutationMatrixPair candidate : candidates) {
        logger.log("lingDetails", "" + candidate);
        System.out.println(candidate);

        // B~ = I - W~, where W~ is the permuted W normalized to a unit diagonal.
        TetradMatrix normalizedW = LingUtils.normalizeDiagonal(candidate.getMatrixW());
        candidate.setMatrixBhat(computeBhatTetradMatrix(normalizedW, variables));

        TetradMatrix bhatData = candidate.getMatrixBhat().getDoubleData();
        boolean stable = allEigenvaluesAreSmallerThanOneInModulus(new TetradMatrix(bhatData.toArray()));

        gs.addGraph(new GraphWithParameters(candidate.getMatrixBhat()).getGraph());
        gs.addStable(stable);
        gs.addData(candidate.getMatrixBhat());
    }

    logGraphsWithStability(gs, true, "stableGraphs", "Stable Graphs:");
    logGraphsWithStability(gs, false, "unstableGraphs", "Unstable Graphs:");
    return gs;
}

/**
 * Logs {@code header} under {@code event}, then logs every graph in {@code gs} whose stability flag
 * equals {@code stable}. When {@code event} is active in the logger config, each graph's stored data
 * is also logged under "wMatrices".
 */
private void logGraphsWithStability(StoredGraphs gs, boolean stable, String event, String header) {
    TetradLogger logger = TetradLogger.getInstance();
    logger.log(event, header);
    for (int i = 0; i < gs.getNumGraphs(); i++) {
        if (gs.isStable(i) != stable) {
            continue;
        }
        logger.log(event, "" + gs.getGraph(i));
        if (logger.getLoggerConfig() != null && logger.getLoggerConfig().isEventActive(event)) {
            logger.log("wMatrices", "" + gs.getData(i));
        }
    }
}
Aggregations