use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.
the class GATKProtectedMathUtils method rowStdDevs.
/**
 * Computes, for every row of a matrix, the square root of that row's variance.
 * <p>
 * The variance of each row is obtained from a single {@link Variance} evaluator that is
 * reused across rows.
 * </p>
 *
 * @param matrix the input matrix.
 * @return never {@code null}; an array with one entry per row of {@code matrix}, in row order.
 * @throws IllegalArgumentException if {@code matrix} is {@code null}.
 */
public static double[] rowStdDevs(final RealMatrix matrix) {
    Utils.nonNull(matrix);
    final Variance evaluator = new Variance();
    final int rowCount = matrix.getRowDimension();
    final double[] stdDevs = new double[rowCount];
    for (int row = 0; row < rowCount; row++) {
        final double rowVariance = evaluator.evaluate(matrix.getRow(row));
        stdDevs[row] = Math.sqrt(rowVariance);
    }
    return stdDevs;
}
use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.
the class GATKProtectedMathUtils method rowVariances.
/**
 * Computes the variance of every row of a matrix.
 * <p>
 * Each row is extracted with {@code matrix.getRow(r)} and handed to a single, reused
 * {@link Variance} evaluator.
 * </p>
 *
 * @param matrix the input matrix; validated with {@code Utils.nonNull}.
 * @return an array with one variance per row of {@code matrix}, in row order.
 */
public static double[] rowVariances(final RealMatrix matrix) {
    Utils.nonNull(matrix);
    final Variance evaluator = new Variance();
    final int rowCount = matrix.getRowDimension();
    final double[] variances = new double[rowCount];
    for (int row = 0; row < rowCount; row++) {
        variances[row] = evaluator.evaluate(matrix.getRow(row));
    }
    return variances;
}
use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.
the class CoverageModelEMWorkspace method updateSampleUnexplainedVariance.
/**
 * E-step update of the sample-specific unexplained variance
 * <p>
 * One root-finding job per sample is submitted to a {@link SynchronizedUnivariateSolver} over the
 * interval {@code [0, config.getSampleSpecificVarianceUpperLimit()]}, starting from the midpoint. The
 * objective function is evaluated for several samples at once by querying all workers and summing their
 * results. A sample whose job does not end with status {@code SUCCESS} is assigned 0 before admixing.
 * The solved values are then admixed with the current values using
 * {@code config.getMeanFieldAdmixingRatio()}, stored in the local copy and pushed to the workers.
 * </p>
 *
 * @return a {@link SubroutineSignal} containing the update size (key: "error_norm") and the average
 * number of function evaluations per sample (key: "iterations")
 * @throws RuntimeException if the solver is interrupted; the thread's interrupt flag is restored and the
 * original {@link InterruptedException} is attached as the cause
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
public SubroutineSignal updateSampleUnexplainedVariance() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_GAMMA));
    cacheWorkers("after E-step for sample unexplained variance initialization");

    /* create a compound objective function for simultaneous multi-sample queries */
    final java.util.function.Function<Map<Integer, Double>, Map<Integer, Double>> objFunc = arg -> {
        if (arg.isEmpty()) {
            return Collections.emptyMap();
        }
        final int[] sampleIndices = arg.keySet().stream().mapToInt(i -> i).toArray();
        /* column vector of the queried gamma values, ordered like sampleIndices */
        final INDArray gammaValues = Nd4j.create(Arrays.stream(sampleIndices).mapToDouble(arg::get).toArray(),
                new int[] { sampleIndices.length, 1 });
        /* every worker contributes a partial evaluation; the partial results are summed */
        final INDArray eval = mapWorkersAndReduce(
                cb -> cb.calculateSampleSpecificVarianceObjectiveFunctionMultiSample(sampleIndices, gammaValues),
                INDArray::add);
        final Map<Integer, Double> output = new HashMap<>();
        IntStream.range(0, sampleIndices.length)
                .forEach(evalIdx -> output.put(sampleIndices[evalIdx], eval.getDouble(evalIdx)));
        return output;
    };

    final java.util.function.Function<UnivariateSolverSpecifications, AbstractUnivariateSolver> solverFactory =
            spec -> new RobustBrentSolver(spec.getRelativeAccuracy(), spec.getAbsoluteAccuracy(),
                    spec.getFunctionValueAccuracy(), null,
                    config.getSampleSpecificVarianceSolverNumBisections(),
                    config.getSampleSpecificVarianceSolverRefinementDepth());

    /* instantiate a synchronized multi-sample root finder and add jobs */
    final SynchronizedUnivariateSolver syncSolver = new SynchronizedUnivariateSolver(objFunc, solverFactory, numSamples);
    IntStream.range(0, numSamples).forEach(si -> {
        final double x0 = 0.5 * config.getSampleSpecificVarianceUpperLimit();
        syncSolver.add(si, 0, config.getSampleSpecificVarianceUpperLimit(), x0,
                config.getSampleSpecificVarianceAbsoluteTolerance(),
                config.getSampleSpecificVarianceRelativeTolerance(),
                config.getSampleSpecificVarianceMaximumIterations());
    });

    /* solve and collect statistics */
    final INDArray newSampleUnexplainedVariance = Nd4j.create(numSamples, 1);
    final List<Integer> numberOfEvaluations = new ArrayList<>(numSamples);
    try {
        final Map<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> newSampleSpecificVarianceMap =
                syncSolver.solve();
        for (final Map.Entry<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> entry :
                newSampleSpecificVarianceMap.entrySet()) {
            final int sampleIndex = entry.getKey();
            final SynchronizedUnivariateSolver.UnivariateSolverSummary summary = entry.getValue();
            /* 0 is the fallback for every status other than SUCCESS */
            double val = 0;
            switch (summary.status) {
                case SUCCESS:
                    val = summary.x;
                    break;
                case TOO_MANY_EVALUATIONS:
                    logger.warn("Could not locate the root of gamma -- increase the maximum number of"
                            + " function evaluations");
                    break;
                default:
                    /* make the silent fallback to 0 visible instead of dropping the status */
                    logger.warn("Root finding for gamma of sample index " + sampleIndex
                            + " ended with status " + summary.status + " -- using 0");
                    break;
            }
            newSampleUnexplainedVariance.put(sampleIndex, 0, val);
            numberOfEvaluations.add(summary.evaluations);
        }
    } catch (final InterruptedException ex) {
        /* restore the interrupt flag so callers further up the stack can still observe the interruption */
        Thread.currentThread().interrupt();
        throw new RuntimeException("The update of sample unexplained variance was interrupted -- can not continue", ex);
    }

    /* admix */
    final INDArray newSampleUnexplainedVarianceAdmixed = newSampleUnexplainedVariance
            .mul(config.getMeanFieldAdmixingRatio())
            .addi(sampleUnexplainedVariance.mul(1 - config.getMeanFieldAdmixingRatio()));

    /* calculate the error */
    final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(
            newSampleUnexplainedVarianceAdmixed.sub(sampleUnexplainedVariance));

    /* update local copy */
    sampleUnexplainedVariance.assign(newSampleUnexplainedVarianceAdmixed);

    /* push to workers */
    pushToWorkers(newSampleUnexplainedVarianceAdmixed,
            (arr, cb) -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s,
                    newSampleUnexplainedVarianceAdmixed));

    /* average number of function evaluations per sample, truncated to an integer */
    final int iterations = (int) (numberOfEvaluations.stream().mapToDouble(d -> d).sum() / numSamples);
    return SubroutineSignal.builder()
            .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity)
            .put(StandardSubroutineSignals.ITERATIONS, iterations)
            .build();
}
use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.
the class CoverageModelEMWorkspace method initializeWorkersWithPCA.
/**
 * Initialize model parameters by performing PCA.
 * <p>
 * Outline of the steps performed, in order:
 * </p>
 * <ol>
 *   <li>{@code m_t} and {@code Psi_t} (and {@code W_tl} when bias covariates are enabled) are zeroed on
 *       every worker, and read depth posterior expectations are updated;</li>
 *   <li>workers are given PCA initialization data, the target covariance matrix is obtained by summing
 *       the workers' contributions, and its eigensystem is computed;</li>
 *   <li>the isotropic unexplained variance and the per-latent scale factors are derived from the
 *       eigenvalues (Bishop 12.46 and 12.45);</li>
 *   <li>one bias covariate column per latent is assembled from the workers, divided by its 1-norm and
 *       multiplied by the matching scale factor;</li>
 *   <li>the resulting model parameters are pushed to the workers and bias latent posterior expectations
 *       are updated without admixing.</li>
 * </ol>
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private void initializeWorkersWithPCA() {
    logger.info("Initializing model parameters using PCA...");

    /* initially, set m_t, Psi_t and W_tl to zero to get an estimate of the read depth */
    final int latentCount = config.getNumLatents();
    mapWorkers(cb -> cb
            .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.m_t,
                    Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() }))
            .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t,
                    Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() })));
    if (biasCovariatesEnabled) {
        mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl,
                Nd4j.zeros(new int[] { cb.getTargetSpaceBlock().getNumElements(), latentCount })));
    }

    /* update read depth without taking into account correction from bias covariates */
    updateReadDepthPosteriorExpectations(1.0, true);

    /* fetch sample covariance matrix */
    final int minReadCount = config.getMinPCAInitializationReadCount();
    mapWorkers(cb -> cb.cloneWithPCAInitializationData(minReadCount, Integer.MAX_VALUE));
    cacheWorkers("PCA initialization");
    final INDArray targetCovariance = mapWorkersAndReduce(
            CoverageModelEMComputeBlock::calculateTargetCovarianceMatrixForPCAInitialization, INDArray::add);

    /* perform eigen-decomposition on the target covariance matrix */
    final ImmutablePair<INDArray, INDArray> eigensystem =
            CoverageModelEMWorkspaceMathUtils.eig(targetCovariance, false, logger);
    final INDArray eigenvectors = eigensystem.getRight();

    /* the eigenvalues of sample covariance matrix can be immediately inferred by scaling */
    final INDArray sampleCovarianceEigenvalues = eigensystem.getLeft().div(numSamples);

    /* estimate the isotropic unexplained variance -- see Bishop 12.46 */
    final int residualDim = numTargets - latentCount;
    final double residualEigenvalueSum = sampleCovarianceEigenvalues
            .get(NDArrayIndex.interval(latentCount, numSamples))
            .sumNumber()
            .doubleValue();
    final double isotropicVariance = residualEigenvalueSum / residualDim;
    logger.info(String.format("PCA estimate of isotropic unexplained variance: %f", isotropicVariance));

    /* estimate bias factors -- see Bishop 12.45 */
    final INDArray leadingEigenvaluesMinusNoise = sampleCovarianceEigenvalues
            .get(NDArrayIndex.interval(0, latentCount))
            .sub(isotropicVariance);
    final INDArray scaleFactors = Transforms.sqrt(leadingEigenvaluesMinusNoise, false);
    final INDArray biasCovariatesPCA = Nd4j.create(new int[] { numTargets, latentCount });
    for (int latentIdx = 0; latentIdx < latentCount; latentIdx++) {
        final INDArray eigenvector = eigenvectors.getColumn(latentIdx);
        /* calculate [Delta_PCA_st]^T v */
        /* note: the eigenvector is not broadcast since it is small; plain lambda capture is sufficient */
        final INDArray projected = CoverageModelSparkUtils.assembleINDArrayBlocksFromCollection(
                mapWorkersAndCollect(cb -> ImmutablePair.of(
                        cb.getTargetSpaceBlock(),
                        cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Delta_PCA_st)
                                .transpose()
                                .mmul(eigenvector))),
                0);
        final double oneNorm = projected.norm1Number().doubleValue();
        /* in-place: divide by the 1-norm, then apply the scale factor of this latent */
        final INDArray scaledColumn = projected.divi(oneNorm).muli(scaleFactors.getDouble(latentIdx));
        biasCovariatesPCA.getColumn(latentIdx).assign(scaledColumn);
    }

    if (ardEnabled) {
        /* a better estimate of ARD coefficients */
        biasCovariatesARDCoefficients.assign(Nd4j.zeros(new int[] { 1, latentCount })
                .addi(config.getInitialARDPrecisionRelativeToNoise() / isotropicVariance));
    }

    final INDArray zeroMeanLogBias = Nd4j.zeros(new int[] { 1, numTargets });
    final INDArray isotropicUnexplainedVariance = Nd4j.zeros(new int[] { 1, numTargets }).addi(isotropicVariance);
    final CoverageModelParameters pcaModelParams = new CoverageModelParameters(processedTargetList,
            zeroMeanLogBias, isotropicUnexplainedVariance, biasCovariatesPCA, biasCovariatesARDCoefficients);

    /* clear PCA initialization data from workers */
    mapWorkers(CoverageModelEMComputeBlock::cloneWithRemovedPCAInitializationData);

    /* push model parameters to workers */
    initializeWorkersWithGivenModel(pcaModelParams);

    /* update bias latent posterior expectations without admixing */
    updateBiasLatentPosteriorExpectations(1.0);
}
use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.
the class CoverageModelEMWorkspace method updateTargetUnexplainedVarianceIsotropic.
/**
 * M-step update of unexplained variance in the isotropic mode
 * <p>
 * A single scalar is solved for with a {@link RobustBrentSolver} on the interval
 * {@code [0, config.getTargetSpecificVarianceUpperLimit()]}. Both the objective and the merit function
 * are evaluated by querying every worker and summing the results. If the root cannot be bracketed or
 * the evaluation budget is exhausted, a warning is logged and the previous value (the mean of
 * {@code Psi_t} fetched from the workers) is kept. The chosen value is then written to every entry of
 * {@code Psi_t} on all workers.
 * </p>
 *
 * @return a {@link SubroutineSignal} object containing "error_norm" (absolute change of the isotropic
 * variance) and "iterations" (number of function evaluations reported by the solver) fields
 */
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateTargetUnexplainedVarianceIsotropic() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.M_STEP_PSI));
    cacheWorkers("after M-step update of isotropic unexplained variance initialization");

    final double oldIsotropicTargetSpecificVariance = fetchFromWorkers(
            CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, 1).meanNumber().doubleValue();

    /* objective and merit functions: per-worker contributions are summed */
    final UnivariateFunction objFunc = psi -> mapWorkersAndReduce(
            cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceObjectiveFunction(psi), (a, b) -> a + b);
    final UnivariateFunction meritFunc = psi -> mapWorkersAndReduce(
            cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceMeritFunction(psi), (a, b) -> a + b);

    final RobustBrentSolver solver = new RobustBrentSolver(
            config.getTargetSpecificVarianceRelativeTolerance(),
            config.getTargetSpecificVarianceAbsoluteTolerance(),
            CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY,
            meritFunc,
            config.getTargetSpecificVarianceSolverNumBisections(),
            config.getTargetSpecificVarianceSolverRefinementDepth());

    double newIsotropicTargetSpecificVariance;
    try {
        newIsotropicTargetSpecificVariance = solver.solve(config.getTargetSpecificVarianceMaxIterations(),
                objFunc, 0, config.getTargetSpecificVarianceUpperLimit());
    } catch (NoBracketingException e) {
        /* keep the previous value when no sign change is found on the search interval */
        logger.warn("Root of M-step optimality equation for isotropic unexplained variance could not be bracketed");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    } catch (TooManyEvaluationsException e) {
        /* keep the previous value when the evaluation budget is exhausted */
        logger.warn("Too many evaluations -- increase the number of root-finding iterations for the M-step update"
                + " of unexplained variance");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    }

    /* update the compute block(s) */
    final double errNormInfinity = FastMath.abs(newIsotropicTargetSpecificVariance - oldIsotropicTargetSpecificVariance);
    final int numEvaluations = solver.getEvaluations();
    /* effectively-final copy for capture by the lambda below */
    final double finalizedNewIsotropicTargetSpecificVariance = newIsotropicTargetSpecificVariance;
    mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t,
            Nd4j.zeros(1, cb.getTargetSpaceBlock().getNumElements()).addi(finalizedNewIsotropicTargetSpecificVariance)));

    return SubroutineSignal.builder()
            .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity)
            .put(StandardSubroutineSignals.ITERATIONS, numEvaluations)
            .build();
}
Aggregations