Search in sources :

Example 1 with WishartSufficientStatistics

use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.

the class ContinuousDataLikelihoodDelegate method calculateLikelihood.

/**
 * Calculate the log likelihood of the current state.
 *
 * <p>Pushes pending branch-length, tip-data and diffusion-model changes to the integrator
 * ({@code cdi}), packs the node operations into the flat {@code operations} array, runs the
 * post-order partials update and sums the per-trait root log likelihoods. As side effects it
 * replaces {@code wishartStatistics} (set to {@code null} when Wishart statistics are not
 * requested) and resets {@code updateDiffusionModel} to {@code false}.
 *
 * @param branchOperations branches whose lengths (scaled by the rate normalization) are refreshed
 * @param nodeOperations   node updates; packed into the operation array in iteration order
 * @param rootNodeNumber   number of the node at which the root likelihood is evaluated
 * @return the log likelihood, summed over all {@code numTraits} entries
 * @throws LikelihoodException declared by the overridden signature; this body contains no
 *                             explicit {@code throw}
 */
@Override
public double calculateLikelihood(List<BranchOperation> branchOperations, List<NodeOperation> nodeOperations, int rootNodeNumber) throws LikelihoodException {
    // TODO Cache branchNormalization
    branchNormalization = rateTransformation.getNormalization();
    int branchUpdateCount = 0;
    for (BranchOperation op : branchOperations) {
        branchUpdateIndices[branchUpdateCount] = op.getBranchNumber();
        branchLengths[branchUpdateCount] = op.getBranchLength() * branchNormalization;
        branchUpdateCount++;
    }
    if (!updateTipData.isEmpty()) {
        if (updateTipData.getFirst() == -1) {
            // A leading -1 requests a refresh of every tip.
            setAllTipData(flip);
            // getFirst() only peeks: without clearing, the -1 sentinel (and anything queued
            // behind it) would survive and force a full tip refresh on every later call.
            // All tips were just rewritten, so every pending request has been satisfied.
            updateTipData.clear();
        } else {
            while (!updateTipData.isEmpty()) {
                int tipIndex = updateTipData.removeFirst();
                setTipData(tipIndex, flip);
            }
        }
    }
    if (updateDiffusionModel) {
        diffusionProcessDelegate.setDiffusionModels(cdi, flip);
    }
    if (branchUpdateCount > 0) {
        diffusionProcessDelegate.updateDiffusionMatrices(cdi, branchUpdateIndices, branchLengths, branchUpdateCount, flip);
    }
    if (flip) {
        // Flip all the buffers to be written to first...
        for (NodeOperation op : nodeOperations) {
            partialBufferHelper.flipOffset(op.getNodeNumber());
        }
    }
    int operationCount = nodeOperations.size();
    int k = 0;
    for (NodeOperation op : nodeOperations) {
        // destination node
        operations[k + 0] = getActiveNodeIndex(op.getNodeNumber());
        // source node 1
        operations[k + 1] = getActiveNodeIndex(op.getLeftChild());
        // source matrix 1
        operations[k + 2] = getActiveMatrixIndex(op.getLeftChild());
        // source node 2
        operations[k + 3] = getActiveNodeIndex(op.getRightChild());
        // source matrix 2
        operations[k + 4] = getActiveMatrixIndex(op.getRightChild());
        k += ContinuousDiffusionIntegrator.OPERATION_TUPLE_SIZE;
    }
    int[] degreesOfFreedom = null;
    double[] outerProducts = null;
    if (computeWishartStatistics) {
        // TODO Abstract this ugliness away
        degreesOfFreedom = new int[numTraits];
        outerProducts = new double[dimTrait * dimTrait * numTraits];
        cdi.setWishartStatistics(degreesOfFreedom, outerProducts);
    }
    cdi.updatePostOrderPartials(operations, operationCount, computeWishartStatistics);
    double[] logLikelihoods = new double[numTraits];
    rootProcessDelegate.calculateRootLogLikelihood(cdi, partialBufferHelper.getOffsetIndex(rootNodeNumber), logLikelihoods, computeWishartStatistics);
    if (computeWishartStatistics) {
        // TODO Abstract this ugliness away
        cdi.getWishartStatistics(degreesOfFreedom, outerProducts);
        wishartStatistics = new WishartSufficientStatistics(degreesOfFreedom, outerProducts);
    } else {
        wishartStatistics = null;
    }
    double logL = 0.0;
    for (double d : logLikelihoods) {
        logL += d;
    }
    updateDiffusionModel = false;
    return logL;
}
Also used : WishartSufficientStatistics(dr.math.distributions.WishartSufficientStatistics)

Example 2 with WishartSufficientStatistics

use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.

the class FullyConjugateMultivariateTraitLikelihood method setup.

/**
 * Runs the post-order and pre-order traversals of the tree unless {@code PostPreKnown}
 * already marks them as done, then sets {@code PostPreKnown}.
 *
 * <p>When Wishart sufficient statistics are requested, {@code wishartStatistics} is replaced
 * with a fresh instance before the post-order traversal starts.
 */
private void setup() {
    if (PostPreKnown) {
        // Traversal results are already current; nothing to redo.
        return;
    }

    final double[][] precision = diffusionModel.getPrecisionmatrix();
    final double logDetPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());

    final boolean wantWishart = getComputeWishartSufficientStatistics();
    if (wantWishart) {
        wishartStatistics = new WishartSufficientStatistics(dimTrait);
    }

    // Dynamic programming pass: conditional likelihoods at each internal node
    postOrderTraverse(treeModel, treeModel.getRoot(), precision, logDetPrecision, wantWishart);
    doPreOrderTraversal(treeModel.getRoot());

    PostPreKnown = true;
}
Also used : WishartSufficientStatistics(dr.math.distributions.WishartSufficientStatistics)

Example 3 with WishartSufficientStatistics

use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.

the class NonPhylogeneticMultivariateTraitLikelihood method calculateLogLikelihood.

/**
 * Computes the log likelihood from the tip contributions returned by
 * {@code computeInnerProductsForTips} combined with the root prior
 * ({@code rootPriorMean}, {@code rootPriorSampleSize}).
 *
 * <p>Side effects: when {@code computeWishartStatistics} is set, {@code wishartStatistics} is
 * replaced and then accumulated into (scale matrix and degrees of freedom, once per datum);
 * the scratch arrays {@code tmp2} and {@code Ay} are passed to helpers as work buffers;
 * {@code areStatesRedrawn} is reset to {@code false}.
 *
 * @return the log likelihood
 */
public double calculateLogLikelihood() {
    double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
    double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());
    // NOTE: marginalRoot is an alias of tmp2, not a copy; tmp2 is also handed to
    // computeInnerProductsForTips below, so both uses share the same storage.
    double[] marginalRoot = tmp2;
    if (computeWishartStatistics) {
        // Start from fresh statistics on every evaluation
        wishartStatistics = new WishartSufficientStatistics(dimTrait);
    }
    // Compute the contribution of each datum at the root
    SufficientStatistics stats = computeInnerProductsForTips(traitPrecision, tmp2);
    double conditionalSumWeight = stats.sumWeight;
    double conditionalProductWeight = stats.productWeight;
    double innerProducts = stats.innerProduct;
    int nonMissingTips = stats.nonMissingTips;
    // Add in prior and integrate
    double sumWeight = conditionalSumWeight + rootPriorSampleSize;
    double productWeight = conditionalProductWeight * rootPriorSampleSize;
    double rootPrecision = productWeight / sumWeight;
    final int rootIndex = treeModel.getRoot().getNumber();
    // Start of the root's entries in meanCache; advanced by dimTrait after each datum.
    // NOTE(review): presumably dim == dimTrait * numData — confirm against the enclosing class.
    int rootOffset = dim * rootIndex;
    for (int datum = 0; datum < numData; ++datum) {
        // Determine marginal root (scaled) mean
        for (int d = 0; d < dimTrait; ++d) {
            marginalRoot[d] = conditionalSumWeight * meanCache[rootOffset + d] + rootPriorSampleSize * rootPriorMean[d];
        }
        // Compute outer product contribution from prior
        double yAy1 = computeWeightedAverageAndSumOfSquares(rootPriorMean, Ay, traitPrecision, dimTrait, rootPriorSampleSize);
        // TODO Only need to compute once
        // (none of the arguments above depend on the loop variable 'datum')
        innerProducts += yAy1;
        if (DEBUG_NO_TREE) {
            System.err.println("OP for root");
            System.err.println("Value  = " + new Vector(rootPriorMean));
            System.err.print("Prec   = \n" + new Matrix(traitPrecision));
            System.err.println("Weight = " + rootPriorSampleSize + "\n");
        }
        // Compute outer product differences to complete square
        double yAy2 = computeWeightedAverageAndSumOfSquares(marginalRoot, Ay, traitPrecision, dimTrait, 1.0 / sumWeight);
        innerProducts -= yAy2;
        // Add prior on root contribution
        if (computeWishartStatistics) {
            // getScaleMatrix() is updated in place below, indexed as a flattened
            // dimTrait x dimTrait array (i * dimTrait + j).
            final double[] outerProducts = wishartStatistics.getScaleMatrix();
            final double weight = conditionalSumWeight * rootPriorSampleSize / sumWeight;
            for (int i = 0; i < dimTrait; i++) {
                // Difference between the cached root mean and the prior mean in dimension i
                final double diffi = meanCache[rootOffset + i] - rootPriorMean[i];
                for (int j = 0; j < dimTrait; j++) {
                    outerProducts[i * dimTrait + j] += diffi * weight * (meanCache[rootOffset + j] - rootPriorMean[j]);
                }
            }
            wishartStatistics.incrementDf(1);
        }
        // Move on to the next datum's root entries in meanCache
        rootOffset += dimTrait;
    }
    if (DEBUG_NO_TREE) {
        System.err.println("SumWeight    : " + sumWeight);
        System.err.println("ProductWeight: " + productWeight);
        System.err.println("Total OP     : " + innerProducts);
    }
    // Compute log likelihood: normalizing-constant term, log-determinant term,
    // root-precision term, minus half the accumulated inner products
    double logLikelihood = -LOG_SQRT_2_PI * dimTrait * nonMissingTips * numData + 0.5 * logDetTraitPrecision * nonMissingTips * numData + 0.5 * Math.log(rootPrecision) * dimTrait * numData - 0.5 * innerProducts;
    if (DEBUG_NO_TREE) {
        System.err.println("logLikelihood (final) = " + logLikelihood);
        System.err.println("numData = " + numData);
    }
    // Should redraw internal node states when needed
    areStatesRedrawn = false;
    return logLikelihood;
}
Also used : WishartSufficientStatistics(dr.math.distributions.WishartSufficientStatistics) WishartSufficientStatistics(dr.math.distributions.WishartSufficientStatistics) Matrix(dr.math.matrixAlgebra.Matrix) Vector(dr.math.matrixAlgebra.Vector)

Example 4 with WishartSufficientStatistics

use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.

the class FullyConjugateMultivariateTraitLikelihood method getReport.

/**
 * Builds a plain-text diagnostic report for this likelihood.
 *
 * <p>The report lists the tree, the tree and trait variance matrices, the data vector
 * (taken from {@code meanCache}, with clamped partials written in after the external-node
 * entries when {@code nodeToClampMap} is set), the log likelihood next to a multivariate
 * normal log density computed from the joint variance, and the Wishart outer products both
 * from the sufficient statistics and recomputed from the tree precision.
 *
 * @return the assembled report text
 */
@Override
public String getReport() {
    final StringBuilder report = new StringBuilder();

    // --- Tree ---
    report.append("Tree:\n")
            .append(getId()).append("\t")
            .append(treeModel.toString())
            .append("\n\n");

    // --- Variance components ---
    final double[][] treeVariance = computeTreeVariance(true);
    final double[][] traitPrecision = getDiffusionModel().getPrecisionmatrix();
    final Matrix traitVariance = new Matrix(traitPrecision).inverse();
    final double[][] jointVariance = KroneckerOperation.product(treeVariance, traitVariance.toComponents());

    report.append("Tree variance:\n")
            .append(new Matrix(treeVariance))
            .append(matrixMin(treeVariance)).append("\t")
            .append(matrixMax(treeVariance)).append("\t")
            .append(matrixSum(treeVariance))
            .append("\n\n");

    report.append("Trait variance:\n")
            .append(traitVariance)
            .append("\n\n");

    report.append("Tree dim: ").append(treeVariance.length).append("\n");
    report.append("data dim: ").append(jointVariance.length);
    report.append("\n\n");

    // --- Data vector ---
    final int dataLength = jointVariance.length;
    final double[] data = new double[dataLength];
    System.arraycopy(meanCache, 0, data, 0, dataLength);

    if (nodeToClampMap != null) {
        // Clamped partials are written consecutively, starting right after the external nodes
        int cursor = treeModel.getExternalNodeCount() * getDimTrait();
        for (Map.Entry<NodeRef, RestrictedPartials> entry : nodeToClampMap.entrySet()) {
            for (double value : entry.getValue().getPartials()) {
                data[cursor++] = value;
            }
        }
    }

    report.append("Data:\n")
            .append(new Vector(data)).append("\n");
    report.append(data.length).append("\t")
            .append(vectorMin(data)).append("\t")
            .append(vectorMax(data)).append("\t")
            .append(vectorSum(data));
    report.append(treeModel.getNodeTaxon(treeModel.getExternalNode(0)).getId());
    report.append("\n\n");

    // --- Log likelihood versus direct MVN density ---
    final double[] zeroMean = new double[data.length];
    final double[][] jointPrecision = new Matrix(jointVariance).inverse().toComponents();
    final MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(zeroMean, jointPrecision);
    final double logDensity = mvn.logPdf(data);
    report.append("logLikelihood: ").append(getLogLikelihood())
            .append(" == ").append(logDensity)
            .append("\n\n");

    // --- Outer products from the sufficient statistics ---
    final WishartSufficientStatistics stats = getWishartStatistics();
    final double[] outerProducts = stats.getScaleMatrix();
    report.append("Outer-products (DP):\n")
            .append(new Vector(outerProducts))
            .append(stats.getDf()).append("\n");

    // --- Outer products recomputed from the tree precision ---
    final Matrix treePrecision = new Matrix(treeVariance).inverse();
    final int cols = traitPrecision.length;
    final int rows = data.length / cols;
    // Reshape the flat data vector into rows of 'cols' entries each
    final double[][] reshaped = new double[rows][cols];
    for (int r = 0; r < rows; ++r) {
        System.arraycopy(data, r * cols, reshaped[r], 0, cols);
    }
    final Matrix y = new Matrix(reshaped);

    Matrix outer = null;
    try {
        // Using Matrix-Normal form
        outer = y.transpose().product(treePrecision).product(y);
    } catch (IllegalDimension illegalDimension) {
        illegalDimension.printStackTrace();
    }
    report.append("Outer-products (from tree variance:\n")
            .append(outer)
            .append("\n\n");

    return report.toString();
}
Also used : IllegalDimension(dr.math.matrixAlgebra.IllegalDimension) MultivariateNormalDistribution(dr.math.distributions.MultivariateNormalDistribution) WishartSufficientStatistics(dr.math.distributions.WishartSufficientStatistics) NodeRef(dr.evolution.tree.NodeRef) Matrix(dr.math.matrixAlgebra.Matrix) Vector(dr.math.matrixAlgebra.Vector)

Example 5 with WishartSufficientStatistics

use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.

the class WishartStatisticsWrapper method restoreState.

/**
 * Restores the saved "known" flags and, when the restored state has known outer products,
 * swaps the current and saved Wishart statistics references.
 */
@Override
protected void restoreState() {
    traitDataKnown = savedTraitDataKnown;
    outerProductsKnown = savedOuterProductsKnown;

    if (!outerProductsKnown) {
        // Statistics are not valid in the restored state; leave both references untouched.
        return;
    }

    // Swap rather than copy: the two objects simply trade roles.
    final WishartSufficientStatistics current = wishartStatistics;
    wishartStatistics = savedWishartStatistics;
    savedWishartStatistics = current;
}
Also used : WishartSufficientStatistics(dr.math.distributions.WishartSufficientStatistics)

Aggregations

WishartSufficientStatistics (dr.math.distributions.WishartSufficientStatistics): 7, Vector (dr.math.matrixAlgebra.Vector): 4, Matrix (dr.math.matrixAlgebra.Matrix): 3, NodeRef (dr.evolution.tree.NodeRef): 1, MultivariateNormalDistribution (dr.math.distributions.MultivariateNormalDistribution): 1, ConjugateWishartStatisticsProvider (dr.math.interfaces.ConjugateWishartStatisticsProvider): 1, IllegalDimension (dr.math.matrixAlgebra.IllegalDimension): 1, SymmetricMatrix (dr.math.matrixAlgebra.SymmetricMatrix): 1