Search in sources:

Example 1 with Matrix

Use of dr.math.matrixAlgebra.Matrix in the beast-mcmc project by beast-dev.

Class DebugableIntegratedMultivariateTraitLikelihood, method computeTreeVariance.

/**
 * Builds the tip-by-tip tree variance matrix.
 *
 * Each diagonal entry is the rescaled length from that external node to the root. Each
 * off-diagonal entry (i, j) is the rescaled root length of the MRCA of tips i and j. The
 * matrix is filled symmetrically, then passed through removeMissingTipsInTreeVariance(...)
 * before being returned.
 *
 * @return the (possibly trimmed) symmetric tree variance matrix
 */
public double[][] computeTreeVariance() {
    final int nTips = treeModel.getExternalNodeCount();
    double[][] treeVariance = new double[nTips][nTips];

    for (int row = 0; row < nTips; row++) {
        // Diagonal: full rescaled path from this tip to the root.
        treeVariance[row][row] = getRescaledLengthToRoot(treeModel.getExternalNode(row));

        for (int col = row + 1; col < nTips; col++) {
            // Shared history of the pair = rescaled root length of their MRCA.
            final double shared = getRescaledLengthToRoot(findMRCA(row, col));
            // Write both triangles at once so the result is symmetric.
            treeVariance[row][col] = shared;
            treeVariance[col][row] = shared;
        }
    }

    if (DEBUG) {
        System.err.println("");
        System.err.println("New tree conditional variance:\n" + new Matrix(treeVariance));
    }

    // Automatically prune missing tips
    treeVariance = removeMissingTipsInTreeVariance(treeVariance);

    if (DEBUG) {
        System.err.println("");
        System.err.println("New tree (trimmed) conditional variance:\n" + new Matrix(treeVariance));
    }
    return treeVariance;
}
Also used: NodeRef (dr.evolution.tree.NodeRef), SymmetricMatrix (dr.math.matrixAlgebra.SymmetricMatrix), Matrix (dr.math.matrixAlgebra.Matrix)

Example 2 with Matrix

Use of dr.math.matrixAlgebra.Matrix in the beast-mcmc project by beast-dev.

Class DensityMap, method addTree.

/**
 * Adds one point to the density map for every non-root branch of {@code tree} that spans
 * {@code sampleTime}.
 *
 * If the tree carries a {@code MultivariateDiffusionModel.PRECISION_TREE_ATTRIBUTE}, that
 * symmetric precision is parsed and inverted, and the result is forwarded to addPoint as the
 * variance. Otherwise {@code null} is forwarded. Branches whose endpoint attributes transform
 * to {@code null} are skipped.
 *
 * @param tree           the tree to scan
 * @param sampleTime     the height at which branches are sliced
 * @param attributeName1 node attribute used for the first (X) coordinate
 * @param attributeName2 node attribute used for the second (Y) coordinate
 */
public void addTree(Tree tree, double sampleTime, String attributeName1, String attributeName2) {
    checkCalibration();

    final Object[] precisionAttribute = (Object[]) tree.getAttribute(MultivariateDiffusionModel.PRECISION_TREE_ATTRIBUTE);
    final double[][] variance = (precisionAttribute == null)
            ? null
            : new Matrix(MatrixParameter.parseFromSymmetricDoubleArray(precisionAttribute).getParameterAsMatrix()).inverse().toComponents();

    for (int n = 0; n < tree.getNodeCount(); n++) {
        final NodeRef child = tree.getNode(n);
        if (child == tree.getRoot()) {
            continue; // the root has no parent branch
        }

        final NodeRef parent = tree.getParent(child);
        final double childHeight = tree.getNodeHeight(child);
        final double parentHeight = tree.getNodeHeight(parent);
        // Kept as a negated conjunction (not De Morgan'd) so NaN heights behave as before.
        if (!(childHeight <= sampleTime && parentHeight >= sampleTime)) {
            continue;
        }

        final Double childX = transform((Double) tree.getNodeAttribute(child, attributeName1));
        final Double childY = transform((Double) tree.getNodeAttribute(child, attributeName2));
        final Double parentX = transform((Double) tree.getNodeAttribute(parent, attributeName1));
        final Double parentY = transform((Double) tree.getNodeAttribute(parent, attributeName2));
        if (childX == null || childY == null || parentX == null || parentY == null) {
            continue;
        }
        addPoint(sampleTime, childHeight, parentHeight, childX, childY, parentX, parentY, variance);
    }
}
Also used: NodeRef (dr.evolution.tree.NodeRef), Matrix (dr.math.matrixAlgebra.Matrix)

Example 3 with Matrix

Use of dr.math.matrixAlgebra.Matrix in the beast-mcmc project by beast-dev.

Class FullyConjugateMultivariateTraitLikelihood, method calculateAscertainmentCorrection.

/**
 * Computes the log-likelihood ascertainment correction for a single taxon.
 *
 * For each datum, the taxon's cached mean (from meanCache) is copied into ascertainedData.
 * It is then scored under a zero-mean normal whose precision is the diffusion precision
 * scaled by (1 / rescaledLengthToRoot + rootPriorSampleSize). Only the univariate case
 * (dimTrait == 1) is implemented.
 *
 * @param taxonIndex index of the taxon's node in treeModel
 * @return the summed log-likelihood over all numData data
 * @throws RuntimeException if dimTrait > 1 and there is at least one datum
 */
protected double calculateAscertainmentCorrection(int taxonIndex) {
    final NodeRef tip = treeModel.getNode(taxonIndex);
    final int nodeIndex = tip.getNumber();

    if (ascertainedData == null) {
        // Assumes that ascertained data are fixed
        ascertainedData = new double[dimTrait];
    }

    // diffusionModel.diffusionPrecisionMatrixParameter.setParameterValue(0,2); // For debugging non-1 values
    final double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
    final double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());
    final double lengthToRoot = getRescaledLengthToRoot(tip);
    final double marginalPrecisionScalar = 1.0 / lengthToRoot + rootPriorSampleSize;

    double total = 0;
    for (int datum = 0; datum < numData; ++datum) {
        // Pull this datum's observed trait value(s) out of the mean cache.
        System.arraycopy(meanCache, nodeIndex * dim + datum * dimTrait, ascertainedData, 0, dimTrait);

        if (DEBUG_ASCERTAINMENT) {
            System.err.println("Datum #" + datum);
            System.err.println("Value: " + new Vector(ascertainedData));
            System.err.println("Cond : " + lengthToRoot);
            System.err.println("MargV: " + 1.0 / marginalPrecisionScalar);
            System.err.println("MargP: " + marginalPrecisionScalar);
            System.err.println("diffusion prec: " + new Matrix(traitPrecision));
        }

        if (dimTrait > 1) {
            throw new RuntimeException("Still need to implement multivariate ascertainment correction");
        }
        final double observed = ascertainedData[0];
        final double sse = observed * (traitPrecision[0][0] * marginalPrecisionScalar) * observed;

        final double datumLogLikelihood = -LOG_SQRT_2_PI * dimTrait
                + 0.5 * (logDetTraitPrecision + dimTrait * Math.log(marginalPrecisionScalar) - sse);

        if (DEBUG_ASCERTAINMENT) {
            System.err.println("LogLik: " + datumLogLikelihood);
            dr.math.distributions.NormalDistribution normal = new dr.math.distributions.NormalDistribution(0, Math.sqrt(1.0 / (traitPrecision[0][0] * marginalPrecisionScalar)));
            System.err.println("TTTLik: " + normal.logPdf(ascertainedData[0]));
            if (datum >= 10) {
                System.exit(-1);
            }
        }
        total += datumLogLikelihood;
    }
    return total;
}
Also used: NodeRef (dr.evolution.tree.NodeRef), Matrix (dr.math.matrixAlgebra.Matrix), MultivariateNormalDistribution (dr.math.distributions.MultivariateNormalDistribution), Vector (dr.math.matrixAlgebra.Vector)

Example 4 with Matrix

Use of dr.math.matrixAlgebra.Matrix in the beast-mcmc project by beast-dev.

Class IntegratedMultivariateTraitLikelihood, method incrementRemainderDensities.

/**
 * Accumulates the log "remainder" density at node {@code thisIndex} that arises when two
 * children's partial means are combined.
 *
 * For each of the numData data, three quadratic forms in precisionMatrix are accumulated:
 * childSS0 and childSS1 use each child's corrected mean weighted by that child's precision,
 * and crossSS pairs the precision-weighted child means with this node's mean, read from
 * cacheHelper.getMeanCache(). The resulting term is ADDED (+=) to
 * logRemainderDensityCache[thisIndex]. NOTE(review): presumably the caller resets that entry
 * before the first call for a node; confirm.
 *
 * When DEBUG is on and the accumulated value exceeds 1E2, intermediate values are printed to
 * stderr and the JVM is terminated with System.exit(-1).
 *
 * @param precisionMatrix       trait precision matrix, indexed [i][j] for i, j < dimTrait
 * @param logDetPrecisionMatrix log-determinant of precisionMatrix, as supplied by the caller
 * @param thisIndex             index into logRemainderDensityCache for this node
 * @param thisOffset            offset of this node's entries in the (uncorrected) mean cache
 * @param childOffset0          offset of child 0's entries in the corrected mean cache
 * @param childOffset1          offset of child 1's entries in the corrected mean cache
 * @param precision0            scalar precision for child 0 (presumably its branch; confirm)
 * @param precision1            scalar precision for child 1 (presumably its branch; confirm)
 * @param OUFactor0             factor whose log, times dimTrait, is subtracted for child 0;
 *                              presumably an Ornstein-Uhlenbeck scaling (confirm)
 * @param OUFactor1             as OUFactor0, for child 1
 * @param cacheOuterProducts    if true, incrementOuterProducts(...) is called first
 */
private void incrementRemainderDensities(double[][] precisionMatrix, double logDetPrecisionMatrix, int thisIndex, int thisOffset, int childOffset0, int childOffset1, double precision0, double precision1, double OUFactor0, double OUFactor1, boolean cacheOuterProducts) {
    // Effective precision of the two children combined: p0 * p1 / (p0 + p1).
    final double remainderPrecision = precision0 * precision1 / (precision0 + precision1);
    if (cacheOuterProducts) {
        incrementOuterProducts(thisOffset, childOffset0, childOffset1, precision0, precision1);
    }
    for (int k = 0; k < numData; k++) {
        // Per-datum quadratic-form accumulators.
        double childSS0 = 0;
        double childSS1 = 0;
        double crossSS = 0;
        for (int i = 0; i < dimTrait; i++) {
            // In case of no drift, getCorrectedMeanCache() simply returns mean cache
            // final double wChild0i = meanCache[childOffset0 + k * dimTrait + i] * precision0;
            final double wChild0i = cacheHelper.getCorrectedMeanCache()[childOffset0 + k * dimTrait + i] * precision0;
            // final double wChild1i = meanCache[childOffset1 + k * dimTrait + i] * precision1;
            final double wChild1i = cacheHelper.getCorrectedMeanCache()[childOffset1 + k * dimTrait + i] * precision1;
            for (int j = 0; j < dimTrait; j++) {
                // subtract "correction"
                // final double child0j = meanCache[childOffset0 + k * dimTrait + j];
                final double child0j = cacheHelper.getCorrectedMeanCache()[childOffset0 + k * dimTrait + j];
                // subtract "correction"
                // final double child1j = meanCache[childOffset1 + k * dimTrait + j];
                final double child1j = cacheHelper.getCorrectedMeanCache()[childOffset1 + k * dimTrait + j];
                childSS0 += wChild0i * precisionMatrix[i][j] * child0j;
                childSS1 += wChild1i * precisionMatrix[i][j] * child1j;
                // make sure meanCache in following is not "corrected"
                // crossSS += (wChild0i + wChild1i) * precisionMatrix[i][j] * meanCache[thisOffset + k * dimTrait + j];
                crossSS += (wChild0i + wChild1i) * precisionMatrix[i][j] * cacheHelper.getMeanCache()[thisOffset + k * dimTrait + j];
            }
        }
        // Normal-style log term: normalizing constant, log-det of the scaled precision, minus
        // half the residual quadratic form, minus dimTrait times the log OU factors.
        logRemainderDensityCache[thisIndex] += -dimTrait * LOG_SQRT_2_PI + 0.5 * (dimTrait * Math.log(remainderPrecision) + logDetPrecisionMatrix) - 0.5 * (childSS0 + childSS1 - crossSS) - // changeou: OU-factor term (NOTE(review): confirm sign/convention)
        dimTrait * (Math.log(OUFactor0) + Math.log(OUFactor1));
        // Debug guard: abort if the accumulated log density looks implausibly large.
        if (DEBUG && logRemainderDensityCache[thisIndex] > 1E2) {
            System.err.println(thisIndex);
            System.err.println(logRemainderDensityCache[thisIndex]);
            System.err.println("rP = " + remainderPrecision);
            System.err.println("p0 = " + precision0);
            System.err.println("p1 = " + precision1 + "\n");
            System.err.println(new Matrix(precisionMatrix));
            System.err.println(childSS0);
            System.err.println(childSS1);
            System.err.println(crossSS);
            for (int i = 0; i < dimTrait; ++i) {
                System.err.println("\t" + cacheHelper.getCorrectedMeanCache()[childOffset0 + 0 * dimTrait + i] + " " + cacheHelper.getCorrectedMeanCache()[childOffset1 + 0 * dimTrait + i]);
            }
            System.exit(-1);
        }
    // double tempnum = childSS0 + childSS1 - crossSS;
    // System.err.println("childSS0 + childSS1 - crossSS:  " + tempnum);
    }
// System.err.println("logRemainderDensity: " + logRemainderDensityCache[thisIndex]);
// System.err.println("thisIndex: " + thisIndex);
// System.err.println("remainder precision: " + remainderPrecision);
// System.err.println("precision0: " + precision0);
// System.err.println("precision1: " + precision1);
// System.err.println("precision0*precision1: " + precision0*precision1);
// System.err.println("logDetPrecisionMatrix: " + logDetPrecisionMatrix);
}
Also used: SymmetricMatrix (dr.math.matrixAlgebra.SymmetricMatrix), Matrix (dr.math.matrixAlgebra.Matrix)

Example 5 with Matrix

Use of dr.math.matrixAlgebra.Matrix in the beast-mcmc project by beast-dev.

Class ApproximateFactorAnalysisPrecisionMatrix, method computeValuesImp.

/**
 * Recomputes {@code values} as the row-major flattening of the inverse of
 * L * L^T + diag(1 / gamma).
 *
 * Also sets {@code dim} to the row dimension of L. The product L * L^T is symmetric, so only
 * its upper triangle is computed and then mirrored. Each entry is a sum of products taken in
 * the same k order, and floating-point multiplication is commutative, so the result is
 * identical to computing every entry.
 */
private void computeValuesImp() {
    dim = L.getRowDimension();
    final int factorCount = L.getColumnDimension();
    double[][] matrix = new double[dim][dim];

    for (int r = 0; r < dim; ++r) {
        for (int c = r; c < dim; ++c) {
            double dot = 0;
            for (int k = 0; k < factorCount; ++k) {
                dot += L.getParameterValue(r, k) * L.getParameterValue(c, k);
            }
            matrix[r][c] = dot;
            matrix[c][r] = dot;
        }
    }

    // Add the diagonal term: reciprocal of each gamma entry.
    for (int r = 0; r < dim; r++) {
        matrix[r][r] += 1.0 / gamma.getParameterValue(r);
    }

    if (DEBUG) {
        System.err.println("mult:");
        System.err.println(new Matrix(L.getParameterAsMatrix()));
        System.err.println(new Vector(gamma.getParameterValues()) + "\n");
        System.err.println(new Matrix(matrix));
    }

    matrix = new Matrix(matrix).inverse().toComponents();

    // Flatten row-major into the parameter's value array.
    for (int r = 0; r < dim; ++r) {
        for (int c = 0; c < dim; ++c) {
            values[r * dim + c] = matrix[r][c];
        }
    }
}
Also used: Matrix (dr.math.matrixAlgebra.Matrix), Vector (dr.math.matrixAlgebra.Vector)

Aggregations

- Matrix (dr.math.matrixAlgebra.Matrix): 51
- SymmetricMatrix (dr.math.matrixAlgebra.SymmetricMatrix): 17
- Vector (dr.math.matrixAlgebra.Vector): 15
- IllegalDimension (dr.math.matrixAlgebra.IllegalDimension): 14
- SymmetricMatrix.compoundCorrelationSymmetricMatrix (dr.math.matrixAlgebra.SymmetricMatrix.compoundCorrelationSymmetricMatrix): 7
- NodeRef (dr.evolution.tree.NodeRef): 6
- MultivariateNormalDistribution (dr.math.distributions.MultivariateNormalDistribution): 5
- WishartSufficientStatistics (dr.math.distributions.WishartSufficientStatistics): 4
- Parameter (dr.inference.model.Parameter): 3
- DoubleMatrix1D (cern.colt.matrix.DoubleMatrix1D): 2
- DoubleMatrix2D (cern.colt.matrix.DoubleMatrix2D): 2
- DenseDoubleMatrix2D (cern.colt.matrix.impl.DenseDoubleMatrix2D): 2
- Tree (dr.evolution.tree.Tree): 2
- MatrixParameter (dr.inference.model.MatrixParameter): 2
- RobustEigenDecomposition (dr.math.matrixAlgebra.RobustEigenDecomposition): 2
- WrappedMatrix (dr.math.matrixAlgebra.WrappedMatrix): 2
- BranchRates (dr.evolution.tree.BranchRates): 1
- MutableTreeModel (dr.evolution.tree.MutableTreeModel): 1
- CompoundSymmetricMatrix (dr.inference.model.CompoundSymmetricMatrix): 1
- CorrelationSymmetricMatrix (dr.inference.model.CorrelationSymmetricMatrix): 1