Use of dr.math.matrixAlgebra.Matrix in project beast-mcmc by beast-dev.
Class NonPhylogeneticMultivariateTraitLikelihood, method calculateLogLikelihood.
/**
 * Computes the log likelihood of the tip traits without using tree structure.
 * Each non-missing tip contributes through computeInnerProductsForTips (weighted by
 * 1 / length-to-root). The root value is then integrated out against the root prior
 * (mean {@code rootPriorMean}, pseudo-sample weight {@code rootPriorSampleSize}).
 *
 * Side effects: uses {@code tmp2} and {@code Ay} as scratch buffers, overwrites the root slot
 * of {@code meanCache} (inside computeInnerProductsForTips), rebuilds {@code wishartStatistics}
 * when {@code computeWishartStatistics} is set, and resets {@code areStatesRedrawn} to false.
 *
 * @return the log likelihood summed over all {@code numData} data columns
 */
public double calculateLogLikelihood() {
double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());
// Scratch buffer. NOTE(review): tmp2 is also handed to computeInnerProductsForTips below and
// overwritten there; this is safe only because every entry of marginalRoot is rewritten per datum.
double[] marginalRoot = tmp2;
if (computeWishartStatistics) {
wishartStatistics = new WishartSufficientStatistics(dimTrait);
}
// Compute the contribution of each datum at the root
SufficientStatistics stats = computeInnerProductsForTips(traitPrecision, tmp2);
double conditionalSumWeight = stats.sumWeight;
double conditionalProductWeight = stats.productWeight;
double innerProducts = stats.innerProduct;
int nonMissingTips = stats.nonMissingTips;
// Add in prior and integrate: the root prior acts as one extra pseudo-observation
// with weight rootPriorSampleSize.
double sumWeight = conditionalSumWeight + rootPriorSampleSize;
double productWeight = conditionalProductWeight * rootPriorSampleSize;
double rootPrecision = productWeight / sumWeight;
final int rootIndex = treeModel.getRoot().getNumber();
// rootOffset walks the root's per-datum slots in meanCache (dimTrait entries per datum).
int rootOffset = dim * rootIndex;
for (int datum = 0; datum < numData; ++datum) {
// Determine marginal root (scaled) mean
for (int d = 0; d < dimTrait; ++d) {
marginalRoot[d] = conditionalSumWeight * meanCache[rootOffset + d] + rootPriorSampleSize * rootPriorMean[d];
}
// Compute outer product contribution from prior
double yAy1 = computeWeightedAverageAndSumOfSquares(rootPriorMean, Ay, traitPrecision, dimTrait, rootPriorSampleSize);
// TODO Only need to compute once
// (yAy1 depends only on the prior, so it is the same every iteration; hoisting it would
// change the floating-point accumulation order of innerProducts.)
innerProducts += yAy1;
if (DEBUG_NO_TREE) {
System.err.println("OP for root");
System.err.println("Value = " + new Vector(rootPriorMean));
System.err.print("Prec = \n" + new Matrix(traitPrecision));
System.err.println("Weight = " + rootPriorSampleSize + "\n");
}
// Compute outer product differences to complete square
// (marginalRoot is the sumWeight-scaled mean, hence the 1 / sumWeight weight here).
double yAy2 = computeWeightedAverageAndSumOfSquares(marginalRoot, Ay, traitPrecision, dimTrait, 1.0 / sumWeight);
innerProducts -= yAy2;
// Add prior on root contribution
if (computeWishartStatistics) {
final double[] outerProducts = wishartStatistics.getScaleMatrix();
// Weighted outer product of (root mean - prior mean), stored row-major in outerProducts.
final double weight = conditionalSumWeight * rootPriorSampleSize / sumWeight;
for (int i = 0; i < dimTrait; i++) {
final double diffi = meanCache[rootOffset + i] - rootPriorMean[i];
for (int j = 0; j < dimTrait; j++) {
outerProducts[i * dimTrait + j] += diffi * weight * (meanCache[rootOffset + j] - rootPriorMean[j]);
}
}
wishartStatistics.incrementDf(1);
}
rootOffset += dimTrait;
}
if (DEBUG_NO_TREE) {
System.err.println("SumWeight : " + sumWeight);
System.err.println("ProductWeight: " + productWeight);
System.err.println("Total OP : " + innerProducts);
}
// Compute log likelihood: Gaussian normalizing constants for every observed tip and datum,
// plus the 0.5 * log(rootPrecision) term from integrating out the root, minus half the
// accumulated quadratic form.
double logLikelihood = -LOG_SQRT_2_PI * dimTrait * nonMissingTips * numData + 0.5 * logDetTraitPrecision * nonMissingTips * numData + 0.5 * Math.log(rootPrecision) * dimTrait * numData - 0.5 * innerProducts;
if (DEBUG_NO_TREE) {
System.err.println("logLikelihood (final) = " + logLikelihood);
System.err.println("numData = " + numData);
}
// Should redraw internal node states when needed
areStatesRedrawn = false;
return logLikelihood;
}
Use of dr.math.matrixAlgebra.Matrix in project beast-mcmc by beast-dev.
Class NonPhylogeneticMultivariateTraitLikelihood, method computeInnerProductsForTips.
// Useful identity for computing outerproducts for Wishart statistics
// \sum (y_i - \bar{y}) (y_i - \bar{y})^{t} = \sum y_i y_i^{t} - n \bar{y} \bar{y}^t
/**
 * Accumulates the tip-level sufficient statistics used by calculateLogLikelihood.
 * Every tip that is not completely missing gets weight {@code 1 / getLengthToRoot(tip)}.
 * Its values are added, weighted, into the root's slot of {@code meanCache*, and its quadratic
 * form (from computeWeightedAverageAndSumOfSquares) is added to the inner-product total.
 * The root slot is then normalized by the total weight.
 *
 * When {@code computeWishartStatistics} is set, each tip's outer product is added with its weight.
 * The root's outer product is subtracted with weight {@code -sumWeight} and the degrees of freedom
 * are decremented by one, which applies the identity in the comment above.
 *
 * @param traitPrecision trait precision matrix passed through to the quadratic-form helper
 * @param tmpVector      scratch buffer of at least dimTrait entries; overwritten
 * @return sum of tip weights, product of tip weights, total inner product, and number of
 *         tips with positive weight
 */
private SufficientStatistics computeInnerProductsForTips(double[][] traitPrecision, double[] tmpVector) {
// Compute the contribution of each datum at the root
final int rootIndex = treeModel.getRoot().getNumber();
final int meanOffset = dim * rootIndex;
// Zero-out root mean
for (int d = 0; d < dim; ++d) {
meanCache[meanOffset + d] = 0;
}
double innerProducts = 0.0;
// Compute the contribution of each datum at the root
double productWeight = 1.0;
double sumWeight = 0.0;
int nonMissingTips = 0;
for (int i = 0; i < treeModel.getExternalNodeCount(); ++i) {
NodeRef tipNode = treeModel.getExternalNode(i);
final int tipNumber = tipNode.getNumber();
// Stays 0.0 for completely-missing tips, so they are skipped by the weight update below.
double tipWeight = 0.0;
if (!missingTraits.isCompletelyMissing(tipNumber)) {
tipWeight = 1.0 / getLengthToRoot(tipNode);
int tipOffset = dim * tipNumber;
int rootOffset = dim * rootIndex;
// Each node's meanCache slot holds numData consecutive blocks of dimTrait values.
for (int datum = 0; datum < numData; ++datum) {
// Add weighted tip value
for (int d = 0; d < dimTrait; ++d) {
meanCache[rootOffset + d] += tipWeight * meanCache[tipOffset + d];
tmpVector[d] = meanCache[tipOffset + d];
}
// Compute outer product
double yAy = computeWeightedAverageAndSumOfSquares(tmpVector, Ay, traitPrecision, dimTrait, tipWeight);
innerProducts += yAy;
if (DEBUG_NO_TREE) {
System.err.println("OP for " + tipNumber + " = " + yAy);
System.err.println("Value = " + new Vector(tmpVector));
System.err.print("Prec =\n" + new Matrix(traitPrecision));
System.err.println("weight = " + tipWeight + "\n");
}
tipOffset += dimTrait;
rootOffset += dimTrait;
}
if (computeWishartStatistics) {
incrementOuterProducts(tipNumber, tipWeight);
}
}
if (tipWeight > 0.0) {
sumWeight += tipWeight;
productWeight *= tipWeight;
++nonMissingTips;
}
}
lowerPrecisionCache[rootIndex] = sumWeight;
// Presumably divides the accumulated root sum by sumWeight to give the weighted mean
// (normalize is defined elsewhere) - TODO confirm.
normalize(meanCache, meanOffset, dim, sumWeight);
if (computeWishartStatistics) {
// Subtract n * ybar * ybar^t (see identity above) and drop one degree of freedom.
incrementOuterProducts(rootIndex, -sumWeight);
wishartStatistics.incrementDf(-1);
}
return new SufficientStatistics(sumWeight, productWeight, innerProducts, nonMissingTips);
}
Use of dr.math.matrixAlgebra.Matrix in project beast-mcmc by beast-dev.
Class SemiConjugateMultivariateTraitLikelihood, method setRootPrior.
/**
 * Installs the root prior. Caches its mean, its scale matrix (used as the root prior precision)
 * and the log-determinant of that precision, then refreshes the cached prior sum of squares
 * via {@code setRootPriorSumOfSquares()}.
 *
 * @param rootPrior multivariate normal prior on the root trait value
 * @throws IllegalArgumentException if {@code Matrix.determinant()} rejects the prior's precision
 *                                  matrix with an {@code IllegalDimension}
 */
private void setRootPrior(MultivariateNormalDistribution rootPrior) {
    rootPriorMean = rootPrior.getMean();
    rootPriorPrecision = rootPrior.getScaleMatrix();
    try {
        logRootPriorPrecisionDeterminant = Math.log(new Matrix(rootPriorPrecision).determinant());
    } catch (IllegalDimension illegalDimension) {
        // Fail fast: carrying on would leave logRootPriorPrecisionDeterminant stale (or at its
        // default) and silently corrupt every subsequent root integration.
        throw new IllegalArgumentException(
                "Cannot compute determinant of root prior precision matrix", illegalDimension);
    }
    setRootPriorSumOfSquares();
}
Use of dr.math.matrixAlgebra.Matrix in project beast-mcmc by beast-dev.
Class SemiConjugateMultivariateTraitLikelihood, method computeMarginalRootMeanAndVariance.
/**
 * Combines the tree's information about the root with the root prior.
 * With {@code A = rootPrecision * treePrecision} and {@code B = rootPriorPrecision}, this method
 * builds {@code A + B} in the scratch matrix {@code tmpM}, inverts it, and writes the combined mean
 * {@code (A + B)^{-1}(Ay + Bz)} back into {@code rootMean}.
 *
 * Side effects: the {@code Ay} field ends up holding {@code Ay + Bz}, {@code tmpM} holds
 * {@code A + B}, and {@code rootMean} is overwritten.
 *
 * @param rootMean      input root mean; overwritten with the combined mean
 * @param treePrecision trait precision matrix (scaled by rootPrecision to form A)
 * @param treeVariance  not read by this method
 * @param rootPrecision scalar weight applied to treePrecision
 * @return {@code (A + B)^{-1}}
 */
protected double[][] computeMarginalRootMeanAndVariance(double[] rootMean, double[][] treePrecision, double[][] treeVariance, double rootPrecision) {
    // Populates the Ay field from rootMean; the returned quadratic form is not needed here.
    computeWeightedAverageAndSumOfSquares(rootMean, Ay, treePrecision, dimTrait, rootPrecision);

    final double[][] precisionSum = tmpM;
    for (int row = 0; row < dimTrait; row++) {
        // Ay is filled with sum, and original value is destroyed
        Ay[row] += Bz[row];
        final double[] treeRow = treePrecision[row];
        final double[] priorRow = rootPriorPrecision[row];
        final double[] sumRow = precisionSum[row];
        for (int col = 0; col < dimTrait; col++) {
            sumRow[col] = treeRow[col] * rootPrecision + priorRow[col];
        }
    }

    final double[][] inverse = new Matrix(precisionSum).inverse().toComponents();

    // Expected value: (A + B)^{-1}(Ay + Bz)
    for (int row = 0; row < dimTrait; row++) {
        final double[] invRow = inverse[row];
        rootMean[row] = 0.0;
        for (int col = 0; col < dimTrait; col++) {
            rootMean[row] += invRow[col] * Ay[col];
        }
    }
    return inverse;
}
Use of dr.math.matrixAlgebra.Matrix in project beast-mcmc by beast-dev.
Class SemiConjugateMultivariateTraitLikelihood, method integrateLogLikelihoodAtRoot.
/**
 * Integrates the root trait out of the likelihood against the root prior and returns
 * {@code 0.5 * (log|B| - log|A + B| - zBz + (Ay + Bz)' (A + B)^{-1} (Ay + Bz))},
 * where {@code A = rootPrecision * treePrecision}, {@code B = rootPriorPrecision}, and
 * {@code Bz}, {@code zBz} and {@code logRootPriorPrecisionDeterminant} are cached fields.
 *
 * @param y             not read by this method
 * @param Ay            presumably holds A*y on entry (filled by the caller - TODO confirm);
 *                      overwritten with Ay + Bz
 * @param AplusB        scratch matrix; when dimTrait > 1 it is overwritten with A + B
 * @param treePrecision trait precision matrix (scaled by rootPrecision to form A)
 * @param rootPrecision scalar weight applied to treePrecision
 * @return the log of the root-integrated contribution
 * @throws IllegalStateException if {@code Matrix.determinant()} reports an {@code IllegalDimension}
 */
protected double integrateLogLikelihoodAtRoot(double[] y, double[] Ay, double[][] AplusB, double[][] treePrecision, double rootPrecision) {
    // No default value: the compiler guarantees every path assigns it. A silent 0 here would
    // turn into log(0) = -Infinity and make the returned log-likelihood +Infinity.
    double detAplusB;
    double square = 0;
    if (dimTrait > 1) {
        for (int i = 0; i < dimTrait; i++) {
            // Ay is filled with sum, and original value is destroyed
            Ay[i] += Bz[i];
            for (int j = 0; j < dimTrait; j++) {
                AplusB[i][j] = treePrecision[i][j] * rootPrecision + rootPriorPrecision[i][j];
            }
        }
        Matrix mat = new Matrix(AplusB);
        try {
            detAplusB = mat.determinant();
        } catch (IllegalDimension illegalDimension) {
            // Fail fast rather than return a bogus (+Infinity) log-likelihood to the sampler.
            throw new IllegalStateException(
                    "Cannot compute determinant of (A + B) at the root", illegalDimension);
        }
        double[][] invAplusB = mat.inverse().toComponents();
        // Quadratic form (Ay + Bz)' (A + B)^{-1} (Ay + Bz)
        for (int i = 0; i < dimTrait; i++) {
            for (int j = 0; j < dimTrait; j++) {
                square += Ay[i] * invAplusB[i][j] * Ay[j];
            }
        }
    } else {
        // 1D is very simple: A + B is a scalar, so it is its own determinant
        detAplusB = treePrecision[0][0] * rootPrecision + rootPriorPrecision[0][0];
        Ay[0] += Bz[0];
        square = Ay[0] * Ay[0] / detAplusB;
    }
    double retValue = 0.5 * (logRootPriorPrecisionDeterminant - Math.log(detAplusB) - zBz + square);
    if (DEBUG) {
        System.err.println("(Ay+Bz)(A+B)^{-1}(Ay+Bz) = " + square);
        System.err.println("density = " + retValue);
        System.err.println("zBz = " + zBz);
    }
    return retValue;
}
Aggregations