use of org.apache.commons.math3.stat.correlation.Covariance in project jstructure by JonStargaryen.
the class SVDSuperimposer method align.
/**
 * Superimposes the query onto the reference using a singular value decomposition of the covariance matrix
 * composed from both centered coordinate sets.
 *
 * @param reference the container the query is fitted onto
 * @param query the container which is rotated
 * @return the alignment result built from both containers as they were passed in, the rotated comparable query,
 * the RMSD, the translation vector and the rotation matrix
 */
@Override
public StructureAlignmentResult align(AtomContainer reference, AtomContainer query) {
    // all calculations run on the comparable subset - the arguments themselves are only handed to the result
    Pair<GroupContainer, GroupContainer> comparablePair = AbstractAlignmentAlgorithm.comparableGroupContainerPair(reference, query, minimalSetOfAtomNames, maximalSetOfAtomNames);
    AtomContainer comparableReference = comparablePair.getLeft();
    AtomContainer comparableQuery = comparablePair.getRight();

    // centroids are determined first, centering happens afterwards
    double[] referenceCentroid = comparableReference.calculate().centroid().getValue();
    double[] queryCentroid = comparableQuery.calculate().centroid().getValue();
    comparableReference.calculate().center();
    comparableQuery.calculate().center();

    // covariance matrix of the centered coordinates and its SVD
    RealMatrix referenceCoordinates = convertToMatrix(comparableReference);
    RealMatrix queryCoordinates = convertToMatrix(comparableQuery);
    RealMatrix covarianceMatrix = queryCoordinates.transpose().multiply(referenceCoordinates);
    SingularValueDecomposition svd = new SingularValueDecomposition(covarianceMatrix);

    // R = (V * U')'
    RealMatrix uTransposed = svd.getU().transpose();
    RealMatrix rotationMatrix = svd.getV().multiply(uTransposed).transpose();

    // a negative determinant indicates a reflection - flip the sign of the last row of V' and recompose R
    boolean isReflection = new LUDecomposition(rotationMatrix).getDeterminant() < 0;
    if (isReflection) {
        RealMatrix vTransposed = svd.getV().transpose();
        for (int column = 0; column < 3; column++) {
            vTransposed.setEntry(2, column, (0 - vTransposed.getEntry(2, column)));
        }
        rotationMatrix = vTransposed.transpose().multiply(uTransposed).transpose();
    }
    double[][] rotation = rotationMatrix.getData();

    // translation = centroid of reference - (centroid of query * rotation)
    double[] translation = LinearAlgebra.on(referenceCentroid).subtract(LinearAlgebra.on(queryCentroid).multiply(rotation)).getValue();
    logger.trace("rotation matrix\n{}\ntranslation vector\n{}", Arrays.deepToString(rotationMatrix.getData()), Arrays.toString(translation));

    // the atoms are already centered, so only the rotation is applied before the RMSD is calculated
    comparableQuery.calculate().transform(new Transformation(rotation));
    double rmsd = calculateRmsd(comparableReference, comparableQuery);

    return new StructureAlignmentResult(reference, query, comparableQuery, rmsd, translation, rotation);
}
use of org.apache.commons.math3.stat.correlation.Covariance in project GDSC-SMLM by aherbert.
the class MaximumLikelihoodFitter method computeFit.
/*
 * (non-Javadoc)
 *
 * Fits the function parameters by minimising the value of the likelihood wrapper (the code treats the optimum
 * value as a negative log likelihood) using the optimiser selected by searchMethod: Powell (plain, bounded or
 * via a mapping adapter), BOBYQA, CMAES, BFGS or, for any other search method, a bounded conjugate gradient.
 *
 * Parameters:
 * y     - the data; its length is passed to the likelihood wrapper as the number of points
 * y_fit - not referenced anywhere in this method body
 * a     - the parameters; used to build the start point and passed to setSolution with the optimum
 * a_dev - if not null, passed to setDeviations with the bounds computed from the Fisher information
 *
 * Returns FitStatus.OK on success; TOO_MANY_ITERATIONS, TOO_MANY_EVALUATIONS, SINGULAR_NON_LINEAR_MODEL,
 * FAILED_TO_CONVERGE or UNKNOWN when the corresponding exception is caught; INVALID_LIKELIHOOD when the final
 * value is infinite or NaN.
 *
 * Side effects: accumulates into the iterations and evaluations fields, sets the value field and may lazily
 * create the powellFunction field.
 *
 * @see gdsc.smlm.fitting.nonlinear.BaseFunctionSolver#computeFit(double[], double[], double[], double[])
 */
public FitStatus computeFit(double[] y, double[] y_fit, double[] a, double[] a_dev) {
final int n = y.length;
LikelihoodWrapper maximumLikelihoodFunction = createLikelihoodWrapper((NonLinearFunction) f, n, y, a);
// Tracks the optimiser whose iteration/evaluation counts are collected in the finally block
@SuppressWarnings("rawtypes") BaseOptimizer baseOptimiser = null;
try {
double[] startPoint = getInitialSolution(a);
PointValuePair optimum = null;
if (searchMethod == SearchMethod.POWELL || searchMethod == SearchMethod.POWELL_BOUNDED || searchMethod == SearchMethod.POWELL_ADAPTER) {
// Non-differentiable version using Powell Optimiser
// This is as per the method in Numerical Recipes 10.5 (Direction Set (Powell's) method)
// I could extend the optimiser and implement bounds on the directions moved. However the mapping
// adapter seems to work OK.
final boolean basisConvergence = false;
// Perhaps these thresholds should be tighter?
// The default is to use the sqrt() of the overall tolerance
//final double lineRel = FastMath.sqrt(relativeThreshold);
//final double lineAbs = FastMath.sqrt(absoluteThreshold);
//final double lineRel = relativeThreshold * 1e2;
//final double lineAbs = absoluteThreshold * 1e2;
// Since we are fitting only a small number of parameters then just use the same tolerance
// for each search direction
final double lineRel = relativeThreshold;
final double lineAbs = absoluteThreshold;
CustomPowellOptimizer o = new CustomPowellOptimizer(relativeThreshold, absoluteThreshold, lineRel, lineAbs, null, basisConvergence);
baseOptimiser = o;
// Only limit the iterations when a positive maximum is configured; otherwise null is passed to optimize()
OptimizationData maxIterationData = null;
if (getMaxIterations() > 0)
maxIterationData = new MaxIter(getMaxIterations());
if (searchMethod == SearchMethod.POWELL_ADAPTER) {
// Try using the mapping adapter for a bounded Powell search
MultivariateFunctionMappingAdapter adapter = new MultivariateFunctionMappingAdapter(new MultivariateLikelihood(maximumLikelihoodFunction), lower, upper);
optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(adapter), GoalType.MINIMIZE, new InitialGuess(adapter.boundedToUnbounded(startPoint)));
// Convert the optimum back to the bounded parameter space
double[] solution = adapter.unboundedToBounded(optimum.getPointRef());
optimum = new PointValuePair(solution, optimum.getValue());
} else {
// The Powell function wrapper is created once and then reused on subsequent calls
if (powellFunction == null) {
// NOTE(review): the start of the original comment appears to have been lost. It presumably explained
// that the parameters are re-mapped, as in a reference Python implementation, by using the sqrt of the
// number of photons and background - confirm against the upstream source.
if (mapGaussian) {
Gaussian2DFunction gf = (Gaussian2DFunction) f;
// Re-map signal and background using the sqrt
int[] indices = gf.gradientIndices();
int[] map = new int[indices.length];
int count = 0;
// Background is always first
if (indices[0] == Gaussian2DFunction.BACKGROUND) {
map[count++] = 0;
}
// Look for the Signal in multiple peak 2D Gaussians
for (int i = 1; i < indices.length; i++) if (indices[i] % 6 == Gaussian2DFunction.SIGNAL) {
map[count++] = i;
}
// Only use the mapped wrapper if there is at least one index to map
if (count > 0) {
powellFunction = new MappedMultivariateLikelihood(maximumLikelihoodFunction, Arrays.copyOf(map, count));
}
}
if (powellFunction == null) {
powellFunction = new MultivariateLikelihood(maximumLikelihoodFunction);
}
}
// Update the maximum likelihood function in the Powell function wrapper
powellFunction.fun = maximumLikelihoodFunction;
// The position checker is disabled (null is passed to optimize()); the alternative is kept below
OptimizationData positionChecker = null;
// new org.apache.commons.math3.optim.PositionChecker(relativeThreshold, absoluteThreshold);
// Bounds are only created for POWELL_BOUNDED; otherwise null is passed to optimize()
SimpleBounds simpleBounds = null;
if (powellFunction.isMapped()) {
// Start point and bounds are converted with map() and the optimum is converted back with unmap()
MappedMultivariateLikelihood adapter = (MappedMultivariateLikelihood) powellFunction;
if (searchMethod == SearchMethod.POWELL_BOUNDED)
simpleBounds = new SimpleBounds(adapter.map(lower), adapter.map(upper));
optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(powellFunction), GoalType.MINIMIZE, new InitialGuess(adapter.map(startPoint)), positionChecker, simpleBounds);
double[] solution = adapter.unmap(optimum.getPointRef());
optimum = new PointValuePair(solution, optimum.getValue());
} else {
if (searchMethod == SearchMethod.POWELL_BOUNDED)
simpleBounds = new SimpleBounds(lower, upper);
optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(powellFunction), GoalType.MINIMIZE, new InitialGuess(startPoint), positionChecker, simpleBounds);
}
}
} else if (searchMethod == SearchMethod.BOBYQA) {
// Differentiable approximation using Powell's BOBYQA algorithm.
// This is slower than the Powell optimiser and requires a high number of evaluations.
int numberOfInterpolationPoints = this.getNumberOfFittedParameters() + 2;
BOBYQAOptimizer o = new BOBYQAOptimizer(numberOfInterpolationPoints);
baseOptimiser = o;
optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new InitialGuess(startPoint), new SimpleBounds(lower, upper));
} else if (searchMethod == SearchMethod.CMAES) {
// TODO - Understand why the CMAES optimiser does not fit very well on test data. It appears
// to converge too early and the likelihood scores are not as low as the other optimisers.
// CMAESOptimiser based on Matlab code:
// https://www.lri.fr/~hansen/cmaes.m
// Take the defaults from the Matlab documentation
// Alternative stop fitness: Double.NEGATIVE_INFINITY
double stopFitness = 0;
boolean isActiveCMA = true;
int diagonalOnly = 0;
int checkFeasableCount = 1;
RandomGenerator random = new Well19937c();
boolean generateStatistics = false;
// The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
double[] sigma = new double[lower.length];
for (int i = 0; i < sigma.length; i++) sigma[i] = (upper[i] - lower[i]) / 3;
int popSize = (int) (4 + Math.floor(3 * Math.log(sigma.length)));
// The CMAES optimiser is random and restarting can overcome problems with quick convergence.
// The Apache commons documentation states that convergence should occur between 30N and 300N^2
// function evaluations
// Note: the restart threshold computed here is 30 * N^2 (N = sigma.length), capped at half of the
// maximum evaluations.
final int n30 = FastMath.min(sigma.length * sigma.length * 30, getMaxEvaluations() / 2);
// The evaluations field is reset here because it drives the restart loop below
evaluations = 0;
// data[0] (initial guess) and data[1] (population size) are replaced on later restarts
OptimizationData[] data = new OptimizationData[] { new InitialGuess(startPoint), new CMAESOptimizer.PopulationSize(popSize), new MaxEval(getMaxEvaluations()), new CMAESOptimizer.Sigma(sigma), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new SimpleBounds(lower, upper) };
// Iterate to prevent early convergence
int repeat = 0;
while (evaluations < n30) {
// Post-increment: this is first true on the third pass, so the first two runs use the original start
// point and population size. The optimum is non-null by then as the first pass always assigns it.
if (repeat++ > 1) {
// Update the start point and population size
data[0] = new InitialGuess(optimum.getPointRef());
popSize *= 2;
data[1] = new CMAESOptimizer.PopulationSize(popSize);
}
CMAESOptimizer o = new CMAESOptimizer(getMaxIterations(), stopFitness, isActiveCMA, diagonalOnly, checkFeasableCount, random, generateStatistics, new SimpleValueChecker(relativeThreshold, absoluteThreshold));
// Assigned so that the finally block collects the counts if optimize() throws
baseOptimiser = o;
PointValuePair result = o.optimize(data);
iterations += o.getIterations();
evaluations += o.getEvaluations();
// (remnant of a removed debug statement) o.getEvaluations(), totalEvaluations);
// Keep the result with the lowest value across all restarts
if (optimum == null || result.getValue() < optimum.getValue()) {
optimum = result;
}
}
// Prevent incrementing the iterations again
baseOptimiser = null;
} else if (searchMethod == SearchMethod.BFGS) {
// BFGS can use an approximate line search minimisation whereas Powell and conjugate gradient
// methods require a more accurate line minimisation. The BFGS search does not do a full
// minimisation but takes appropriate steps in the direction of the current gradient.
// Do not use the convergence checker on the value of the function. Use the convergence on the
// point coordinate and gradient
//BFGSOptimizer o = new BFGSOptimizer(new SimpleValueChecker(rel, abs));
BFGSOptimizer o = new BFGSOptimizer();
baseOptimiser = o;
// Configure maximum step length for each dimension using the bounds
// A non-positive range (upper <= lower) results in an unlimited step for that dimension
double[] stepLength = new double[lower.length];
for (int i = 0; i < stepLength.length; i++) {
stepLength[i] = (upper[i] - lower[i]) * 0.3333333;
if (stepLength[i] <= 0)
stepLength[i] = Double.POSITIVE_INFINITY;
}
// The GoalType is always minimise so no need to pass this in
OptimizationData positionChecker = null;
//new org.apache.commons.math3.optim.PositionChecker(relativeThreshold, absoluteThreshold);
// Note: the step lengths above come from lower/upper but the bounds passed to the optimiser are
// lowerConstraint/upperConstraint
optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunctionGradient(new MultivariateVectorLikelihood(maximumLikelihoodFunction)), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), new InitialGuess(startPoint), new SimpleBounds(lowerConstraint, upperConstraint), new BFGSOptimizer.GradientTolerance(relativeThreshold), positionChecker, new BFGSOptimizer.StepLength(stepLength));
} else {
// Any other search method uses the bounded conjugate gradient optimiser: Fletcher-Reeves for
// CONJUGATE_GRADIENT_FR and Polak-Ribiere otherwise.
// The line search algorithm often fails. This is due to searching into a region where the
// function evaluates to a negative so has been clipped. This means the upper bound of the line
// cannot be found.
// Note that running it on an easy problem (200 photons with fixed fitting (no background)) the algorithm
// does sometimes produce results better than the Powell algorithm but it is slower.
BoundedNonLinearConjugateGradientOptimizer o = new BoundedNonLinearConjugateGradientOptimizer((searchMethod == SearchMethod.CONJUGATE_GRADIENT_FR) ? Formula.FLETCHER_REEVES : Formula.POLAK_RIBIERE, new SimpleValueChecker(relativeThreshold, absoluteThreshold));
baseOptimiser = o;
// Note: The gradients may become unstable at the edge of the bounds. Or they will not change
// direction if the true solution is on the bounds since the gradient will always continue
// towards the bounds. This is key to the conjugate gradient method. It searches along a vector
// until the direction of the gradient is in the opposite direction (using dot products, i.e.
// cosine of angle between them)
// NR 10.7 states there is no advantage of the variable metric DFP or BFGS methods over
// conjugate gradient methods. So I will try these first.
// Try this:
// Adapt the conjugate gradient optimiser to use the gradient to pick the search direction
// and then for the line minimisation. However if the function is out of bounds then clip the
// variables at the bounds and continue.
// If the current point is at the bounds and the gradient is to continue out of bounds then
// clip the gradient too.
// Or: just use the gradient for the search direction then use the line minimisation/rest
// as per the Powell optimiser. The bounds should limit the search.
// I tried a Bounded conjugate gradient optimiser with clipped variables:
// This sometimes works. However when the variables go a long way out of the expected range the gradients
// can have vastly different magnitudes. This results in the algorithm stalling since the gradients
// can be close to zero and some of the parameters are no longer adjusted.
// Perhaps this can be looked for and the algorithm then gives up and resorts to a Powell optimiser from
// the current point.
// Changed the bracketing step to very small (default is 1, changed to 0.001). This improves the
// performance. The gradient direction is very sensitive to small changes in the coordinates so a
// tighter bracketing of the line search helps.
// Tried using a non-gradient method for the line search copied from the Powell optimiser:
// This also works when the bracketing step is small but the number of iterations is higher.
// 24.10.2014: I have tried to get conjugate gradient to work but the gradient function
// must not behave suitably for the optimiser. In the current state both methods of using a
// Bounded Conjugate Gradient Optimiser perform poorly relative to other optimisers:
// Simulated : n=1000, signal=200, x=0.53, y=0.47
// LVM : n=1000, signal=171, x=0.537, y=0.471 (1.003s)
// Powell : n=1000, signal=187, x=0.537, y=0.48 (1.238s)
// Gradient based PR (constrained): n=858, signal=161, x=0.533, y=0.474 (2.54s)
// Gradient based PR (bounded): n=948, signal=161, x=0.533, y=0.473 (2.67s)
// Non-gradient based : n=1000, signal=151.47, x=0.535, y=0.474 (1.626s)
// The conjugate optimisers are slower, under predict the signal by the most and in the case of
// the gradient based optimiser, fail to converge on some problems. This is worse when constrained
// fitting is used and not tightly bounded fitting.
// I will leave the code in as an option but would not recommend using it. I may remove it in the
// future.
// Note: It is strange that the non-gradient based line minimisation is more successful.
// It may be that the gradient function is not accurate (due to round off error) or that it is
// simply wrong when far from the optimum. My JUnit tests only evaluate the function within the
// expected range of the answer.
// Note the default step size on the Powell optimiser is 1 but the initial directions are unit vectors.
// So our bracketing step should be a minimum of 1 / average length of the first gradient vector to prevent
// the first step being too large when bracketing.
final double[] gradient = new double[startPoint.length];
maximumLikelihoodFunction.likelihood(startPoint, gradient);
// l is the sum of the squared gradient components (no sqrt is taken)
double l = 0;
for (double d : gradient) l += d * d;
// The step is 1/l when l > 1 (otherwise 1) and is never larger than 0.001
final double bracketingStep = FastMath.min(0.001, ((l > 1) ? 1.0 / l : 1));
//System.out.printf("Bracketing step = %f (length=%f)\n", bracketingStep, l);
o.setUseGradientLineSearch(gradientLineMinimisation);
optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunctionGradient(new MultivariateVectorLikelihood(maximumLikelihoodFunction)), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new InitialGuess(startPoint), new SimpleBounds(lowerConstraint, upperConstraint), new BoundedNonLinearConjugateGradientOptimizer.BracketingStep(bracketingStep));
//maximumLikelihoodFunction.value(solution, gradient);
//System.out.printf("Iter = %d, %g @ %s : %s\n", iterations, ll, Arrays.toString(solution),
// Arrays.toString(gradient));
}
// Store the optimum in the parameter array
final double[] solution = optimum.getPointRef();
setSolution(a, solution);
if (a_dev != null) {
// Assume the Maximum Likelihood estimator returns the optimum fit (achieves the Cramer-Rao
// lower bounds) and so the covariance can be obtained from the Fisher Information Matrix.
FisherInformationMatrix m = new FisherInformationMatrix(maximumLikelihoodFunction.fisherInformation(a));
setDeviations(a_dev, m.crlb(true));
}
// Reverse negative log likelihood for maximum likelihood score
value = -optimum.getValue();
} catch (TooManyIterationsException e) {
//e.printStackTrace();
return FitStatus.TOO_MANY_ITERATIONS;
} catch (TooManyEvaluationsException e) {
//e.printStackTrace();
return FitStatus.TOO_MANY_EVALUATIONS;
} catch (ConvergenceException e) {
//System.out.printf("Singular non linear model = %s\n", e.getMessage());
return FitStatus.SINGULAR_NON_LINEAR_MODEL;
} catch (BFGSOptimizer.LineSearchRoundoffException e) {
//e.printStackTrace();
return FitStatus.FAILED_TO_CONVERGE;
} catch (Exception e) {
// Any other failure is reported as UNKNOWN after printing the stack trace
//System.out.printf("Unknown error = %s\n", e.getMessage());
e.printStackTrace();
return FitStatus.UNKNOWN;
} finally {
// Runs on success and on failure. The optimiser is null after a completed CMAES loop, which has already
// added its counts.
if (baseOptimiser != null) {
iterations += baseOptimiser.getIterations();
evaluations += baseOptimiser.getEvaluations();
}
}
// Check this as likelihood functions can go wrong
if (Double.isInfinite(value) || Double.isNaN(value))
return FitStatus.INVALID_LIKELIHOOD;
return FitStatus.OK;
}
use of org.apache.commons.math3.stat.correlation.Covariance in project gatk by broadinstitute.
the class CoverageModelEMWorkspace method initializeWorkersWithPCA.
/**
 * Derives initial model parameters from a PCA of the worker data and pushes them to the workers.
 *
 * <p>The mean log bias, unexplained variance and (if enabled) bias covariates on the workers are first zeroed
 * and the read depth is updated. An eigen-decomposition of the reduced covariance matrix then yields the
 * isotropic unexplained variance and the scaled bias covariates that make up the initial model.</p>
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private void initializeWorkersWithPCA() {
    logger.info("Initializing model parameters using PCA...");

    /* start from m_t = 0, Psi_t = 0 and W_tl = 0 so that a first estimate of the read depth can be made */
    final int latentDim = config.getNumLatents();
    mapWorkers(cb -> {
        final int blockSize = cb.getTargetSpaceBlock().getNumElements();
        return cb
                .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.m_t,
                        Nd4j.zeros(new int[] { 1, blockSize }))
                .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t,
                        Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() }));
    });
    if (biasCovariatesEnabled) {
        mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(
                CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl,
                Nd4j.zeros(new int[] { cb.getTargetSpaceBlock().getNumElements(), latentDim })));
    }

    /* read depth update that ignores the correction from bias covariates */
    updateReadDepthPosteriorExpectations(1.0, true);

    /* prepare the PCA initialization data on the workers and reduce the covariance matrix */
    final int minReadCount = config.getMinPCAInitializationReadCount();
    mapWorkers(cb -> cb.cloneWithPCAInitializationData(minReadCount, Integer.MAX_VALUE));
    cacheWorkers("PCA initialization");
    final INDArray covarianceMatrix = mapWorkersAndReduce(
            CoverageModelEMComputeBlock::calculateTargetCovarianceMatrixForPCAInitialization, INDArray::add);

    /* eigen-decomposition; the eigenvalues of the sample covariance matrix follow by scaling */
    final ImmutablePair<INDArray, INDArray> eigensystem =
            CoverageModelEMWorkspaceMathUtils.eig(covarianceMatrix, false, logger);
    final INDArray eigenvalues = eigensystem.getLeft().div(numSamples);

    /* isotropic unexplained variance -- see Bishop 12.46 */
    final int residualDim = numTargets - latentDim;
    final double residualEigenvalueSum = eigenvalues.get(NDArrayIndex.interval(latentDim, numSamples))
            .sumNumber().doubleValue();
    final double isotropicVariance = residualEigenvalueSum / residualDim;
    logger.info(String.format("PCA estimate of isotropic unexplained variance: %f", isotropicVariance));

    /* bias factors -- see Bishop 12.45 */
    final INDArray leadingEigenvalues = eigenvalues.get(NDArrayIndex.interval(0, latentDim));
    final INDArray scaleFactors = Transforms.sqrt(leadingEigenvalues.sub(isotropicVariance), false);
    final INDArray biasCovariatesPCA = Nd4j.create(new int[] { numTargets, latentDim });
    for (int latentIndex = 0; latentIndex < latentDim; latentIndex++) {
        final INDArray eigenvector = eigensystem.getRight().getColumn(latentIndex);
        /* [Delta_PCA_st]^T v, assembled from the per-block products; the eigenvector is small enough to be
         * captured by the lambda without broadcasting */
        final INDArray biasCovariate = CoverageModelSparkUtils.assembleINDArrayBlocksFromCollection(
                mapWorkersAndCollect(cb -> ImmutablePair.of(cb.getTargetSpaceBlock(),
                        cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Delta_PCA_st)
                                .transpose().mmul(eigenvector))), 0);
        /* in-place: divide by the 1-norm, then multiply by the scale factor of this latent */
        final double norm = biasCovariate.norm1Number().doubleValue();
        final INDArray unitBiasCovariate = biasCovariate.divi(norm);
        final double scale = scaleFactors.getDouble(latentIndex);
        final INDArray scaledBiasCovariate = unitBiasCovariate.muli(scale);
        biasCovariatesPCA.getColumn(latentIndex).assign(scaledBiasCovariate);
    }

    if (ardEnabled) {
        /* a better estimate of ARD coefficients */
        final INDArray ardCoefficients = Nd4j.zeros(new int[] { 1, latentDim })
                .addi(config.getInitialARDPrecisionRelativeToNoise() / isotropicVariance);
        biasCovariatesARDCoefficients.assign(ardCoefficients);
    }

    /* zero mean log bias, constant unexplained variance, PCA-derived bias covariates */
    final INDArray meanLogBias = Nd4j.zeros(new int[] { 1, numTargets });
    final INDArray unexplainedVariance = Nd4j.zeros(new int[] { 1, numTargets }).addi(isotropicVariance);
    final CoverageModelParameters modelParamsFromPCA = new CoverageModelParameters(processedTargetList,
            meanLogBias, unexplainedVariance, biasCovariatesPCA, biasCovariatesARDCoefficients);

    /* drop the PCA initialization data, then push the model to the workers */
    mapWorkers(CoverageModelEMComputeBlock::cloneWithRemovedPCAInitializationData);
    initializeWorkersWithGivenModel(modelParamsFromPCA);

    /* update bias latent posterior expectations without admixing */
    updateBiasLatentPosteriorExpectations(1.0);
}
use of org.apache.commons.math3.stat.correlation.Covariance in project gatk by broadinstitute.
the class CoverageModelParameters method write.
/**
 * Writes the model to disk.
 *
 * <p>The target list, the target mean log bias and the target unexplained variance are always written. The mean
 * bias covariates and their squared column norms are written only if bias covariates are enabled, and the ARD
 * coefficients only if ARD is enabled as well.</p>
 *
 * @param model the model to write
 * @param outputPath model output path
 */
public static void write(@Nonnull CoverageModelParameters model, @Nonnull final String outputPath) {
    /* make sure the output directory exists */
    createOutputPath(Utils.nonNull(outputPath, "The output path string must be non-null"));

    /* targets list */
    final File targetsOut = new File(outputPath, CoverageModelGlobalConstants.TARGET_LIST_OUTPUT_FILE);
    TargetWriter.writeTargetsToFile(targetsOut, model.getTargetList());
    final List<String> rowNames = model.getTargetList().stream()
            .map(Target::getName)
            .collect(Collectors.toList());

    /* target mean log bias, one row per target */
    final File meanLogBiasOut = new File(outputPath, CoverageModelGlobalConstants.TARGET_MEAN_LOG_BIAS_OUTPUT_FILE);
    Nd4jIOUtils.writeNDArrayMatrixToTextFile(model.getTargetMeanLogBias().transpose(), meanLogBiasOut,
            MEAN_LOG_BIAS_MATRIX_NAME, rowNames, null);

    /* target unexplained variance, one row per target */
    final File unexplainedVarianceOut = new File(outputPath,
            CoverageModelGlobalConstants.TARGET_UNEXPLAINED_VARIANCE_OUTPUT_FILE);
    Nd4jIOUtils.writeNDArrayMatrixToTextFile(model.getTargetUnexplainedVariance().transpose(),
            unexplainedVarianceOut, TARGET_UNEXPLAINED_VARIANCE_MATRIX_NAME, rowNames, null);

    /* everything below concerns bias covariates */
    if (!model.isBiasCovariatesEnabled()) {
        return;
    }

    /* mean bias covariates, one column per latent */
    final List<String> columnNames = IntStream.range(0, model.getNumLatents())
            .mapToObj(li -> String.format(BIAS_COVARIATE_COLUMN_NAME_FORMAT, li))
            .collect(Collectors.toList());
    final File covariatesOut = new File(outputPath, CoverageModelGlobalConstants.MEAN_BIAS_COVARIATES_OUTPUT_FILE);
    Nd4jIOUtils.writeNDArrayMatrixToTextFile(model.getMeanBiasCovariates(), covariatesOut,
            MEAN_BIAS_COVARIATES_MATRIX_NAME, rowNames, columnNames);

    /* norm_2 of mean bias covariates, taken from the diagonal of W^T W */
    final INDArray gram = model.getMeanBiasCovariates().transpose().mmul(model.getMeanBiasCovariates());
    final int latentCount = model.getNumLatents();
    final double[] diagonal = new double[latentCount];
    for (int k = 0; k < latentCount; k++) {
        diagonal[k] = gram.getDouble(k, k);
    }
    final File norm2Out = new File(outputPath, CoverageModelGlobalConstants.MEAN_BIAS_COVARIATES_NORM2_OUTPUT_FILE);
    Nd4jIOUtils.writeNDArrayMatrixToTextFile(Nd4j.create(diagonal, new int[] { 1, model.getNumLatents() }),
            norm2Out, MEAN_BIAS_COVARIATES_NORM_2_MATRIX_NAME, null, columnNames);

    /* ARD coefficients, only if ARD is enabled */
    if (model.isARDEnabled()) {
        final File ardOut = new File(outputPath,
                CoverageModelGlobalConstants.BIAS_COVARIATES_ARD_COEFFICIENTS_OUTPUT_FILE);
        Nd4jIOUtils.writeNDArrayMatrixToTextFile(model.getBiasCovariateARDCoefficients(), ardOut,
                BIAS_COVARIATES_ARD_COEFFICIENTS_MATRIX_NAME, null, columnNames);
    }
}
use of org.apache.commons.math3.stat.correlation.Covariance in project gatk-protected by broadinstitute.
the class CoverageModelEMWorkspace method initializeWorkersWithPCA.
/**
 * Derives initial model parameters from a PCA of the worker data and pushes them to the workers.
 *
 * <p>The mean log bias, unexplained variance and (if enabled) bias covariates on the workers are first zeroed
 * and the read depth is updated. An eigen-decomposition of the reduced covariance matrix then yields the
 * isotropic unexplained variance and the scaled bias covariates that make up the initial model.</p>
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private void initializeWorkersWithPCA() {
    logger.info("Initializing model parameters using PCA...");

    /* start from m_t = 0, Psi_t = 0 and W_tl = 0 so that a first estimate of the read depth can be made */
    final int latentDim = config.getNumLatents();
    mapWorkers(cb -> {
        final int blockSize = cb.getTargetSpaceBlock().getNumElements();
        return cb
                .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.m_t,
                        Nd4j.zeros(new int[] { 1, blockSize }))
                .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t,
                        Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() }));
    });
    if (biasCovariatesEnabled) {
        mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(
                CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl,
                Nd4j.zeros(new int[] { cb.getTargetSpaceBlock().getNumElements(), latentDim })));
    }

    /* read depth update that ignores the correction from bias covariates */
    updateReadDepthPosteriorExpectations(1.0, true);

    /* prepare the PCA initialization data on the workers and reduce the covariance matrix */
    final int minReadCount = config.getMinPCAInitializationReadCount();
    mapWorkers(cb -> cb.cloneWithPCAInitializationData(minReadCount, Integer.MAX_VALUE));
    cacheWorkers("PCA initialization");
    final INDArray covarianceMatrix = mapWorkersAndReduce(
            CoverageModelEMComputeBlock::calculateTargetCovarianceMatrixForPCAInitialization, INDArray::add);

    /* eigen-decomposition; the eigenvalues of the sample covariance matrix follow by scaling */
    final ImmutablePair<INDArray, INDArray> eigensystem =
            CoverageModelEMWorkspaceMathUtils.eig(covarianceMatrix, false, logger);
    final INDArray eigenvalues = eigensystem.getLeft().div(numSamples);

    /* isotropic unexplained variance -- see Bishop 12.46 */
    final int residualDim = numTargets - latentDim;
    final double residualEigenvalueSum = eigenvalues.get(NDArrayIndex.interval(latentDim, numSamples))
            .sumNumber().doubleValue();
    final double isotropicVariance = residualEigenvalueSum / residualDim;
    logger.info(String.format("PCA estimate of isotropic unexplained variance: %f", isotropicVariance));

    /* bias factors -- see Bishop 12.45 */
    final INDArray leadingEigenvalues = eigenvalues.get(NDArrayIndex.interval(0, latentDim));
    final INDArray scaleFactors = Transforms.sqrt(leadingEigenvalues.sub(isotropicVariance), false);
    final INDArray biasCovariatesPCA = Nd4j.create(new int[] { numTargets, latentDim });
    for (int latentIndex = 0; latentIndex < latentDim; latentIndex++) {
        final INDArray eigenvector = eigensystem.getRight().getColumn(latentIndex);
        /* [Delta_PCA_st]^T v, assembled from the per-block products; the eigenvector is small enough to be
         * captured by the lambda without broadcasting */
        final INDArray biasCovariate = CoverageModelSparkUtils.assembleINDArrayBlocksFromCollection(
                mapWorkersAndCollect(cb -> ImmutablePair.of(cb.getTargetSpaceBlock(),
                        cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Delta_PCA_st)
                                .transpose().mmul(eigenvector))), 0);
        /* in-place: divide by the 1-norm, then multiply by the scale factor of this latent */
        final double norm = biasCovariate.norm1Number().doubleValue();
        final INDArray unitBiasCovariate = biasCovariate.divi(norm);
        final double scale = scaleFactors.getDouble(latentIndex);
        final INDArray scaledBiasCovariate = unitBiasCovariate.muli(scale);
        biasCovariatesPCA.getColumn(latentIndex).assign(scaledBiasCovariate);
    }

    if (ardEnabled) {
        /* a better estimate of ARD coefficients */
        final INDArray ardCoefficients = Nd4j.zeros(new int[] { 1, latentDim })
                .addi(config.getInitialARDPrecisionRelativeToNoise() / isotropicVariance);
        biasCovariatesARDCoefficients.assign(ardCoefficients);
    }

    /* zero mean log bias, constant unexplained variance, PCA-derived bias covariates */
    final INDArray meanLogBias = Nd4j.zeros(new int[] { 1, numTargets });
    final INDArray unexplainedVariance = Nd4j.zeros(new int[] { 1, numTargets }).addi(isotropicVariance);
    final CoverageModelParameters modelParamsFromPCA = new CoverageModelParameters(processedTargetList,
            meanLogBias, unexplainedVariance, biasCovariatesPCA, biasCovariatesARDCoefficients);

    /* drop the PCA initialization data, then push the model to the workers */
    mapWorkers(CoverageModelEMComputeBlock::cloneWithRemovedPCAInitializationData);
    initializeWorkersWithGivenModel(modelParamsFromPCA);

    /* update bias latent posterior expectations without admixing */
    updateBiasLatentPosteriorExpectations(1.0);
}
Aggregations