Search in sources:

Example 6 with WishartSufficientStatistics

Use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.

From class IntegratedMultivariateTraitLikelihood, method calculateLogLikelihood.

/**
 * Computes the total log-likelihood of the trait data on the tree.
 *
 * <p>The steps are:
 * <ol>
 *   <li>Rebuilds clamp partials when {@code updateRestrictedNodePartials} is set.</li>
 *   <li>Optionally allocates a fresh {@code WishartSufficientStatistics} accumulator.</li>
 *   <li>Runs a post-order (dynamic-programming) traversal to fill the per-node caches.</li>
 *   <li>For each datum, combines the conditional root mean/precision with the diffusion
 *       precision. When {@code integrateRoot} is set, it also integrates the root out
 *       against the root prior.</li>
 *   <li>Adds the accumulated log remainders.</li>
 * </ol>
 *
 * <p>Side effects: overwrites the scratch buffers {@code tmp2}, {@code Ay} and {@code tmpM}.
 * It also resets {@code areStatesRedrawn} to {@code false}.
 *
 * @return the total log-likelihood
 */
public double calculateLogLikelihood() {
    // Rebuild the clamped-node partials only when they have been flagged as stale.
    if (updateRestrictedNodePartials) {
        if (clampList != null) {
            setupClamps();
        }
        updateRestrictedNodePartials = false;
    }
    double logLikelihood = 0;
    double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
    double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());
    // Alias of the shared scratch buffer tmp2; it is overwritten for every datum below.
    double[] conditionalRootMean = tmp2;
    final boolean computeWishartStatistics = getComputeWishartSufficientStatistics();
    if (computeWishartStatistics) {
        // Fresh accumulator; presumably populated during postOrderTraverse (flag is passed in) — TODO confirm.
        wishartStatistics = new WishartSufficientStatistics(dimTrait);
    }
    // Use dynamic programming to compute conditional likelihoods at each internal node
    postOrderTraverse(treeModel, treeModel.getRoot(), traitPrecision, logDetTraitPrecision, computeWishartStatistics);
    if (DEBUG) {
        System.err.println("mean: " + new Vector(cacheHelper.getMeanCache()));
        System.err.println("correctedMean: " + new Vector(cacheHelper.getCorrectedMeanCache()));
        System.err.println("upre: " + new Vector(upperPrecisionCache));
        System.err.println("lpre: " + new Vector(lowerPrecisionCache));
        System.err.println("cach: " + new Vector(logRemainderDensityCache));
    }
    // Compute the contribution of each datum at the root
    final int rootIndex = treeModel.getRoot().getNumber();
    // Precision scalar of datum conditional on root
    double conditionalRootPrecision = lowerPrecisionCache[rootIndex];
    for (int datum = 0; datum < numData; datum++) {
        double thisLogLikelihood = 0;
        // Get conditional mean of datum conditional on root
        System.arraycopy(cacheHelper.getMeanCache(), rootIndex * dim + datum * dimTrait, conditionalRootMean, 0, dimTrait);
        if (DEBUG) {
            System.err.println("Datum #" + datum);
            System.err.println("root mean: " + new Vector(conditionalRootMean));
            System.err.println("root prec: " + conditionalRootPrecision);
            System.err.println("diffusion prec: " + new Matrix(traitPrecision));
        }
        // B = root prior precision
        // z = root prior mean
        // A = likelihood precision
        // y = likelihood mean
        // y'Ay
        double yAy = computeWeightedAverageAndSumOfSquares(conditionalRootMean, Ay, traitPrecision, dimTrait, // Also fills in Ay
        conditionalRootPrecision);
        // A zero root precision contributes no Gaussian term (log of zero would be -infinity).
        if (conditionalRootPrecision != 0) {
            thisLogLikelihood += -LOG_SQRT_2_PI * dimTrait + 0.5 * (logDetTraitPrecision + dimTrait * Math.log(conditionalRootPrecision) - yAy);
        }
        if (DEBUG) {
            double[][] T = new double[dimTrait][dimTrait];
            for (int i = 0; i < dimTrait; i++) {
                for (int j = 0; j < dimTrait; j++) {
                    T[i][j] = traitPrecision[i][j] * conditionalRootPrecision;
                }
            }
            System.err.println("Conditional root MVN precision = \n" + new Matrix(T));
            System.err.println("Conditional root MVN density = " + MultivariateNormalDistribution.logPdf(conditionalRootMean, new double[dimTrait], T, Math.log(MultivariateNormalDistribution.calculatePrecisionMatrixDeterminate(T)), 1.0));
        }
        if (integrateRoot) {
            // Integrate root trait out against rootPrior
            thisLogLikelihood += integrateLogLikelihoodAtRoot(conditionalRootMean, Ay, tmpM, traitPrecision, // Ay is destroyed
            conditionalRootPrecision);
        }
        if (DEBUG) {
            System.err.println("yAy = " + yAy);
            System.err.println("logLikelihood (before remainders) = " + thisLogLikelihood + " (should match conditional root MVN density when root not integrated out)");
        }
        logLikelihood += thisLogLikelihood;
    }
    logLikelihood += sumLogRemainders();
    if (DEBUG) {
        // Single final report, on stderr like every other debug message in this method.
        System.err.println("logLikelihood (final) = " + logLikelihood);
    }
    if (DEBUG_PNAS) {
        checkLogLikelihood(logLikelihood, sumLogRemainders(), conditionalRootMean, conditionalRootPrecision, traitPrecision);
        // Report nodes whose log-remainder density is implausibly small.
        for (int i = 0; i < logRemainderDensityCache.length; ++i) {
            if (logRemainderDensityCache[i] < -1E10) {
                System.err.println(logRemainderDensityCache[i] + " @ " + i);
            }
        }
    }
    // Should redraw internal node states when needed
    areStatesRedrawn = false;
    return logLikelihood;
}
Also used: WishartSufficientStatistics (dr.math.distributions.WishartSufficientStatistics), SymmetricMatrix (dr.math.matrixAlgebra.SymmetricMatrix), Matrix (dr.math.matrixAlgebra.Matrix), Vector (dr.math.matrixAlgebra.Vector)

Example 7 with WishartSufficientStatistics

Use of dr.math.distributions.WishartSufficientStatistics in project beast-mcmc by beast-dev.

From class PrecisionMatrixGibbsOperator, method incrementOuterProduct.

/**
 * Loads the Wishart sufficient statistics from {@code integratedLikelihood} into {@code S}.
 *
 * <p>The flat scale-matrix array is unpacked row by row into {@code S}. Each row holds
 * {@code S.length} entries, with row {@code r} starting at offset {@code r * S.length}.
 * The degrees of freedom are stored in {@code numberObservations}.
 *
 * <p>If {@code debugModel} is set, the statistics from both providers are printed to stderr
 * and the JVM is terminated via {@code System.exit(-1)}.
 *
 * @param S destination matrix; its row count determines how many values are copied per row
 * @param integratedLikelihood source of the Wishart sufficient statistics
 */
private void incrementOuterProduct(double[][] S, ConjugateWishartStatisticsProvider integratedLikelihood) {
    final WishartSufficientStatistics stats = integratedLikelihood.getWishartStatistics();
    final double[] flatScale = stats.getScaleMatrix();
    final double degreesOfFreedom = stats.getDf();

    if (DEBUG) {
        System.err.println("OP df = " + degreesOfFreedom);
        System.err.println("OP    = " + new Vector(flatScale));
    }

    if (debugModel != null) {
        // Diagnostic path: compare against a reference model, then abort.
        final ConjugateWishartStatisticsProvider reference = (ConjugateWishartStatisticsProvider) debugModel;
        final WishartSufficientStatistics referenceStats = reference.getWishartStatistics();
        System.err.println(degreesOfFreedom + " ?= " + referenceStats.getDf());
        System.err.println(new Vector(flatScale));
        System.err.println("");
        System.err.println(new Vector(referenceStats.getScaleMatrix()));
        System.exit(-1);
    }

    // Row width equals the number of rows: S is treated as square.
    final int width = S.length;
    int offset = 0;
    for (double[] targetRow : S) {
        System.arraycopy(flatScale, offset, targetRow, 0, width);
        offset += width;
    }

    numberObservations = degreesOfFreedom;
}
Also used: WishartSufficientStatistics (dr.math.distributions.WishartSufficientStatistics), ConjugateWishartStatisticsProvider (dr.math.interfaces.ConjugateWishartStatisticsProvider), Vector (dr.math.matrixAlgebra.Vector)

Aggregations

WishartSufficientStatistics (dr.math.distributions.WishartSufficientStatistics): 7; Vector (dr.math.matrixAlgebra.Vector): 4; Matrix (dr.math.matrixAlgebra.Matrix): 3; NodeRef (dr.evolution.tree.NodeRef): 1; MultivariateNormalDistribution (dr.math.distributions.MultivariateNormalDistribution): 1; ConjugateWishartStatisticsProvider (dr.math.interfaces.ConjugateWishartStatisticsProvider): 1; IllegalDimension (dr.math.matrixAlgebra.IllegalDimension): 1; SymmetricMatrix (dr.math.matrixAlgebra.SymmetricMatrix): 1