Use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.
From the class ContinuousDataLikelihoodDelegate, method calculateLikelihood:
/**
 * Calculate the log likelihood of the current state.
 *
 * @return the log likelihood.
 */
@Override
public double calculateLikelihood(List<BranchOperation> branchOperations,
                                  List<NodeOperation> nodeOperations,
                                  int rootNodeNumber) throws LikelihoodException {

    // TODO Cache branchNormalization
    branchNormalization = rateTransformation.getNormalization();

    int branchUpdateCount = 0;
    for (BranchOperation op : branchOperations) {
        branchUpdateIndices[branchUpdateCount] = op.getBranchNumber();
        branchLengths[branchUpdateCount] = op.getBranchLength() * branchNormalization;
        branchUpdateCount++;
    }

    if (!updateTipData.isEmpty()) {
        if (updateTipData.getFirst() == -1) {
            // Update all tips
            setAllTipData(flip);
        } else {
            while (!updateTipData.isEmpty()) {
                int tipIndex = updateTipData.removeFirst();
                setTipData(tipIndex, flip);
            }
        }
    }

    if (updateDiffusionModel) {
        diffusionProcessDelegate.setDiffusionModels(cdi, flip);
    }

    if (branchUpdateCount > 0) {
        diffusionProcessDelegate.updateDiffusionMatrices(cdi, branchUpdateIndices,
                branchLengths, branchUpdateCount, flip);
    }

    if (flip) {
        // Flip all the buffers to be written to first...
        for (NodeOperation op : nodeOperations) {
            partialBufferHelper.flipOffset(op.getNodeNumber());
        }
    }

    int operationCount = nodeOperations.size();
    int k = 0;
    for (NodeOperation op : nodeOperations) {
        int nodeNum = op.getNodeNumber();
        operations[k + 0] = getActiveNodeIndex(nodeNum);              // destination node
        operations[k + 1] = getActiveNodeIndex(op.getLeftChild());    // source node 1
        operations[k + 2] = getActiveMatrixIndex(op.getLeftChild());  // source matrix 1
        operations[k + 3] = getActiveNodeIndex(op.getRightChild());   // source node 2
        operations[k + 4] = getActiveMatrixIndex(op.getRightChild()); // source matrix 2
        k += ContinuousDiffusionIntegrator.OPERATION_TUPLE_SIZE;
    }

    int[] degreesOfFreedom = null;
    double[] outerProducts = null;
    if (computeWishartStatistics) {
        // TODO Abstract this ugliness away
        degreesOfFreedom = new int[numTraits];
        outerProducts = new double[dimTrait * dimTrait * numTraits];
        cdi.setWishartStatistics(degreesOfFreedom, outerProducts);
    }

    cdi.updatePostOrderPartials(operations, operationCount, computeWishartStatistics);

    double[] logLikelihoods = new double[numTraits];
    rootProcessDelegate.calculateRootLogLikelihood(cdi,
            partialBufferHelper.getOffsetIndex(rootNodeNumber),
            logLikelihoods, computeWishartStatistics);

    if (computeWishartStatistics) {
        // TODO Abstract this ugliness away
        cdi.getWishartStatistics(degreesOfFreedom, outerProducts);
        wishartStatistics = new WishartSufficientStatistics(degreesOfFreedom, outerProducts);
    } else {
        wishartStatistics = null;
    }

    double logL = 0.0;
    for (double d : logLikelihoods) {
        logL += d;
    }

    updateDiffusionModel = false;
    return logL;
}
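The five-slot layout written into operations above is easy to get wrong when extending the delegate, so here is a minimal, self-contained sketch of that packing step. The helper name packOperation is hypothetical, and the assumption that ContinuousDiffusionIntegrator.OPERATION_TUPLE_SIZE equals 5 is inferred from the five slots filled per node, not taken from the BEAST sources.

// Illustrative sketch only: mirrors the tuple layout the loop above writes
// (destination partial, left partial, left matrix, right partial, right matrix).
// packOperation is a hypothetical helper; the tuple size of 5 is an assumption.
static int packOperation(int[] operations, int k,
                         int destPartial,
                         int leftPartial, int leftMatrix,
                         int rightPartial, int rightMatrix) {
    operations[k]     = destPartial;   // where the post-order partial is written
    operations[k + 1] = leftPartial;   // source node 1
    operations[k + 2] = leftMatrix;    // source matrix 1
    operations[k + 3] = rightPartial;  // source node 2
    operations[k + 4] = rightMatrix;   // source matrix 2
    return k + 5;                      // next free tuple offset
}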
Use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.
From the class FullyConjugateMultivariateTraitLikelihood, method setup:
private void setup() {

    if (!PostPreKnown) {

        double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
        double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());

        final boolean computeWishartStatistics = getComputeWishartSufficientStatistics();
        if (computeWishartStatistics) {
            wishartStatistics = new WishartSufficientStatistics(dimTrait);
        }

        // Use dynamic programming to compute conditional likelihoods at each internal node
        postOrderTraverse(treeModel, treeModel.getRoot(), traitPrecision,
                logDetTraitPrecision, computeWishartStatistics);
        doPreOrderTraversal(treeModel.getRoot());
    }
    PostPreKnown = true;
}
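setup() follows a lazy-recomputation pattern: the post- and pre-order traversals run only when the PostPreKnown flag has been cleared by a model change. A generic sketch of the pattern, with illustrative names that are not beast-mcmc API:

// Generic dirty-flag caching sketch; field and method names are illustrative.
private boolean postPreKnown = false;

void handleModelChanged() {   // e.g., fired when the diffusion model or tree changes
    postPreKnown = false;     // invalidate the cached traversal results
}

void ensureTraversals() {
    if (!postPreKnown) {
        // ... recompute post-order and pre-order quantities here ...
    }
    postPreKnown = true;
}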
Use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.
From the class NonPhylogeneticMultivariateTraitLikelihood, method calculateLogLikelihood:
public double calculateLogLikelihood() {

    double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
    double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());
    double[] marginalRoot = tmp2;

    if (computeWishartStatistics) {
        wishartStatistics = new WishartSufficientStatistics(dimTrait);
    }

    // Compute the contribution of each datum at the root
    SufficientStatistics stats = computeInnerProductsForTips(traitPrecision, tmp2);

    double conditionalSumWeight = stats.sumWeight;
    double conditionalProductWeight = stats.productWeight;
    double innerProducts = stats.innerProduct;
    int nonMissingTips = stats.nonMissingTips;

    // Add in prior and integrate
    double sumWeight = conditionalSumWeight + rootPriorSampleSize;
    double productWeight = conditionalProductWeight * rootPriorSampleSize;
    double rootPrecision = productWeight / sumWeight;

    final int rootIndex = treeModel.getRoot().getNumber();
    int rootOffset = dim * rootIndex;

    for (int datum = 0; datum < numData; ++datum) {

        // Determine marginal root (scaled) mean
        for (int d = 0; d < dimTrait; ++d) {
            marginalRoot[d] = conditionalSumWeight * meanCache[rootOffset + d]
                    + rootPriorSampleSize * rootPriorMean[d];
        }

        // Compute outer product contribution from prior
        double yAy1 = computeWeightedAverageAndSumOfSquares(rootPriorMean, Ay,
                traitPrecision, dimTrait, rootPriorSampleSize); // TODO Only need to compute once
        innerProducts += yAy1;

        if (DEBUG_NO_TREE) {
            System.err.println("OP for root");
            System.err.println("Value = " + new Vector(rootPriorMean));
            System.err.print("Prec = \n" + new Matrix(traitPrecision));
            System.err.println("Weight = " + rootPriorSampleSize + "\n");
        }

        // Compute outer product differences to complete square
        double yAy2 = computeWeightedAverageAndSumOfSquares(marginalRoot, Ay,
                traitPrecision, dimTrait, 1.0 / sumWeight);
        innerProducts -= yAy2;

        // Add prior on root contribution
        if (computeWishartStatistics) {
            final double[] outerProducts = wishartStatistics.getScaleMatrix();
            final double weight = conditionalSumWeight * rootPriorSampleSize / sumWeight;
            for (int i = 0; i < dimTrait; i++) {
                final double diffi = meanCache[rootOffset + i] - rootPriorMean[i];
                for (int j = 0; j < dimTrait; j++) {
                    outerProducts[i * dimTrait + j] += diffi * weight
                            * (meanCache[rootOffset + j] - rootPriorMean[j]);
                }
            }
            wishartStatistics.incrementDf(1);
        }

        rootOffset += dimTrait;
    }

    if (DEBUG_NO_TREE) {
        System.err.println("SumWeight    : " + sumWeight);
        System.err.println("ProductWeight: " + productWeight);
        System.err.println("Total OP     : " + innerProducts);
    }

    // Compute log likelihood
    double logLikelihood =
            -LOG_SQRT_2_PI * dimTrait * nonMissingTips * numData
            + 0.5 * logDetTraitPrecision * nonMissingTips * numData
            + 0.5 * Math.log(rootPrecision) * dimTrait * numData
            - 0.5 * innerProducts;

    if (DEBUG_NO_TREE) {
        System.err.println("logLikelihood (final) = " + logLikelihood);
        System.err.println("numData = " + numData);
    }

    // Should redraw internal node states when needed
    areStatesRedrawn = false;

    return logLikelihood;
}
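For reference, the closing expression can be written out as a formula. Using p = dimTrait, N = nonMissingTips, K = numData, and assuming the constant LOG_SQRT_2_PI equals log sqrt(2*pi):

\log L = -\tfrac{NpK}{2}\log(2\pi) + \tfrac{NK}{2}\log|\Lambda| + \tfrac{pK}{2}\log\tau_{\mathrm{root}} - \tfrac{1}{2}\,\mathrm{innerProducts}

where |\Lambda| is the determinant of the trait precision matrix and \tau_{\mathrm{root}} = rootPrecision = productWeight / sumWeight. Note that innerProducts carries no factor of K: it is already accumulated over all K data columns inside the loop above.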
Use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.
From the class FullyConjugateMultivariateTraitLikelihood, method getReport:
@Override
public String getReport() {

    StringBuilder sb = new StringBuilder();

    sb.append("Tree:\n");
    sb.append(getId()).append("\t");
    sb.append(treeModel.toString());
    sb.append("\n\n");

    double[][] treeVariance = computeTreeVariance(true);
    double[][] traitPrecision = getDiffusionModel().getPrecisionmatrix();
    Matrix traitVariance = new Matrix(traitPrecision).inverse();

    double[][] jointVariance = KroneckerOperation.product(treeVariance, traitVariance.toComponents());

    sb.append("Tree variance:\n");
    sb.append(new Matrix(treeVariance));
    sb.append(matrixMin(treeVariance)).append("\t")
      .append(matrixMax(treeVariance)).append("\t")
      .append(matrixSum(treeVariance));
    sb.append("\n\n");

    sb.append("Trait variance:\n");
    sb.append(traitVariance);
    sb.append("\n\n");

    // sb.append("Joint variance:\n");
    // sb.append(new Matrix(jointVariance));
    // sb.append("\n\n");

    sb.append("Tree dim: " + treeVariance.length + "\n");
    sb.append("data dim: " + jointVariance.length);
    sb.append("\n\n");

    double[] data = new double[jointVariance.length];
    System.arraycopy(meanCache, 0, data, 0, jointVariance.length);

    if (nodeToClampMap != null) {
        int offset = treeModel.getExternalNodeCount() * getDimTrait();
        for (Map.Entry<NodeRef, RestrictedPartials> clamps : nodeToClampMap.entrySet()) {
            double[] partials = clamps.getValue().getPartials();
            for (int i = 0; i < partials.length; ++i) {
                data[offset] = partials[i];
                ++offset;
            }
        }
    }

    sb.append("Data:\n");
    sb.append(new Vector(data)).append("\n");
    sb.append(data.length).append("\t")
      .append(vectorMin(data)).append("\t")
      .append(vectorMax(data)).append("\t")
      .append(vectorSum(data));
    sb.append(treeModel.getNodeTaxon(treeModel.getExternalNode(0)).getId());
    sb.append("\n\n");

    MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(
            new double[data.length], new Matrix(jointVariance).inverse().toComponents());
    double logDensity = mvn.logPdf(data);
    sb.append("logLikelihood: " + getLogLikelihood() + " == " + logDensity + "\n\n");

    final WishartSufficientStatistics sufficientStatistics = getWishartStatistics();
    final double[] outerProducts = sufficientStatistics.getScaleMatrix();
    sb.append("Outer-products (DP):\n");
    sb.append(new Vector(outerProducts));
    sb.append(sufficientStatistics.getDf() + "\n");

    Matrix treePrecision = new Matrix(treeVariance).inverse();
    final int n = data.length / traitPrecision.length;
    final int p = traitPrecision.length;
    double[][] tmp = new double[n][p];
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < p; ++j) {
            tmp[i][j] = data[i * p + j];
        }
    }
    Matrix y = new Matrix(tmp);

    Matrix S = null;
    try {
        // Using matrix-normal form
        S = y.transpose().product(treePrecision).product(y);
    } catch (IllegalDimension illegalDimension) {
        illegalDimension.printStackTrace();
    }

    sb.append("Outer-products (from tree variance):\n");
    sb.append(S);
    sb.append("\n\n");

    return sb.toString();
}
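The final try block is a consistency check on the dynamic-programming (DP) outer products. It rests on the matrix-normal identity: writing the data as an n x p matrix Y (one row per node covered by the tree covariance V_tree, one column per trait dimension), if vec(Y) is normal with covariance V_tree \otimes V_trait, then the Wishart scale matrix accumulated by DP should agree with

S = Y^{\top} V_{\mathrm{tree}}^{-1} Y

which is exactly what y.transpose().product(treePrecision).product(y) computes. The earlier "logLikelihood:" line performs the analogous check on the density itself, comparing getLogLikelihood() against a direct MultivariateNormalDistribution.logPdf evaluation under the Kronecker-structured joint covariance.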
Use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.
From the class WishartStatisticsWrapper, method restoreState:
@Override
protected void restoreState() {
    traitDataKnown = savedTraitDataKnown;
    outerProductsKnown = savedOuterProductsKnown;

    if (outerProductsKnown) {
        WishartSufficientStatistics tmp = wishartStatistics;
        wishartStatistics = savedWishartStatistics;
        savedWishartStatistics = tmp;
    }
}
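The reference swap in restoreState() only works if the matching storeState() placed a copy of the current statistics into savedWishartStatistics. A minimal sketch of that counterpart, assuming some copy mechanism on WishartSufficientStatistics (the copyFrom call below is hypothetical; the actual beast-mcmc implementation may differ):

@Override
protected void storeState() {
    savedTraitDataKnown = traitDataKnown;
    savedOuterProductsKnown = outerProductsKnown;
    if (outerProductsKnown) {
        // Hypothetical deep copy, so restoreState() can cheaply swap references back
        savedWishartStatistics.copyFrom(wishartStatistics);
    }
}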