Usage of org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray in project gatk-protected by broadinstitute.
From class CoverageModelEMWorkspace, method updateBiasCovariatesRegularized:
/**
 * Updates the bias covariates (W_tl) with regularization by solving the regularized linear system
 * locally with a preconditioned conjugate-gradient solver, then pushes the solution back to the
 * compute blocks.
 *
 * <p>Steps: (1) refresh the worker caches tagged {@code E_STEP_W_REG} and cache the workers;
 * (2) fetch the current W_tl and the right-hand side v_tl from the workers; (3) solve the system
 * given by {@link #getBiasCovariatesRegularizedLinearOperators()} (operator + preconditioner),
 * warm-starting from the current W_tl; (4) push the new W_tl to the workers using the configured
 * communication policy and update F[W]; (5) send the resulting signal to the workers.
 *
 * @param admixingRatio not read by this method (NOTE(review): confirm this is intentional)
 * @return a {@link SubroutineSignal} containing the infinity norm of the change in W_tl
 *         (key: {@code StandardSubroutineSignals.RESIDUAL_ERROR_NORM}) and the number of CG
 *         iterations (key: {@code StandardSubroutineSignals.ITERATIONS})
 * @throws GATKException.ShouldNeverReachHereException if the configured communication policy is unknown
 */
@UpdatesRDD
@EvaluatesRDD
@CachesRDD
private SubroutineSignal updateBiasCovariatesRegularized(final double admixingRatio) {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_W_REG));
    cacheWorkers("after E-step update of bias covariates w/ regularization");

    final INDArray W_tl_old = fetchFromWorkers(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl, 0);
    final INDArray v_tl = fetchFromWorkers(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.v_tl, 0);

    /* initialize the linear operator and its preconditioner */
    final ImmutablePair<GeneralLinearOperator<INDArray>, GeneralLinearOperator<INDArray>> ops =
            getBiasCovariatesRegularizedLinearOperators();
    final GeneralLinearOperator<INDArray> linop = ops.left;
    final GeneralLinearOperator<INDArray> precond = ops.right;

    /* initialize the iterative solver */
    final IterativeLinearSolverNDArray iterSolver = new IterativeLinearSolverNDArray(linop, v_tl, precond,
            config.getWAbsoluteTolerance(), config.getWRelativeTolerance(), config.getWMaxIterations(),
            x -> x.normmaxNumber().doubleValue(), /* norm */
            (x, y) -> x.mul(y).sumNumber().doubleValue(), /* inner product */
            true);

    /* solve, warm-starting from the current W; release operator resources even if the solver throws */
    final long startTime = System.nanoTime();
    final SubroutineSignal sig;
    try {
        sig = iterSolver.solveUsingPreconditionedConjugateGradient(W_tl_old);
    } finally {
        linop.cleanupAfter();
        precond.cleanupAfter();
    }
    final long endTime = System.nanoTime();
    logger.debug("CG execution time for solving the regularized M-step update equation for bias covariates: " +
            (double) (endTime - startTime) / 1000000 + " ms");

    /* check the exit status of the solver and push the new W to workers */
    final ExitStatus exitStatus = sig.get(StandardSubroutineSignals.EXIT_STATUS);
    if (exitStatus == ExitStatus.FAIL_MAX_ITERS) {
        /* hitting the iteration cap means the tolerances were too strict for the budget: loosen them, don't tighten */
        logger.warn("CG iterations for M-step update of bias covariates did not converge. Increase maximum iterations" +
                " and/or increase absolute/relative error tolerances");
    }
    final int iters = sig.<Integer>get(StandardSubroutineSignals.ITERATIONS);
    final INDArray W_tl_new = sig.get(StandardSubroutineSignals.SOLUTION);

    switch (config.getBiasCovariatesComputeNodeCommunicationPolicy()) {
        case BROADCAST_HASH_JOIN:
            pushToWorkers(mapINDArrayToBlocks(W_tl_new), (W, cb) -> cb.cloneWithUpdatedPrimitive(
                    CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl, W.get(cb.getTargetSpaceBlock())));
            break;
        case RDD_JOIN:
            joinWithWorkersAndMap(chopINDArrayToBlocks(W_tl_new), p -> p._1.cloneWithUpdatedPrimitive(
                    CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl, p._2));
            break;
        default:
            throw new GATKException.ShouldNeverReachHereException("Unknown communication policy for M-step update" +
                    " of bias covariates");
    }

    /* update F[W] */
    updateFilteredBiasCovariates(W_tl_new);

    /* size of the update, measured in the infinity norm */
    final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(W_tl_new.sub(W_tl_old));

    /* send the signal to workers for consistency */
    final SubroutineSignal newSig = SubroutineSignal.builder()
            .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity)
            .put(StandardSubroutineSignals.ITERATIONS, iters)
            .build();
    mapWorkers(cb -> cb.cloneWithUpdatedSignal(newSig));
    return newSig;
}
Usage of org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray in project gatk by broadinstitute.
From class CoverageModelEMWorkspace, method updateBiasCovariatesRegularized:
/**
 * Updates the bias covariates (W_tl) with regularization by solving the regularized linear system
 * locally with a preconditioned conjugate-gradient solver, then pushes the solution back to the
 * compute blocks.
 *
 * <p>Steps: (1) refresh the worker caches tagged {@code E_STEP_W_REG} and cache the workers;
 * (2) fetch the current W_tl and the right-hand side v_tl from the workers; (3) solve the system
 * given by {@link #getBiasCovariatesRegularizedLinearOperators()} (operator + preconditioner),
 * warm-starting from the current W_tl; (4) push the new W_tl to the workers using the configured
 * communication policy and update F[W]; (5) send the resulting signal to the workers.
 *
 * @param admixingRatio not read by this method (NOTE(review): confirm this is intentional)
 * @return a {@link SubroutineSignal} containing the infinity norm of the change in W_tl
 *         (key: {@code StandardSubroutineSignals.RESIDUAL_ERROR_NORM}) and the number of CG
 *         iterations (key: {@code StandardSubroutineSignals.ITERATIONS})
 * @throws GATKException.ShouldNeverReachHereException if the configured communication policy is unknown
 */
@UpdatesRDD
@EvaluatesRDD
@CachesRDD
private SubroutineSignal updateBiasCovariatesRegularized(final double admixingRatio) {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_W_REG));
    cacheWorkers("after E-step update of bias covariates w/ regularization");

    final INDArray W_tl_old = fetchFromWorkers(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl, 0);
    final INDArray v_tl = fetchFromWorkers(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.v_tl, 0);

    /* initialize the linear operator and its preconditioner */
    final ImmutablePair<GeneralLinearOperator<INDArray>, GeneralLinearOperator<INDArray>> ops =
            getBiasCovariatesRegularizedLinearOperators();
    final GeneralLinearOperator<INDArray> linop = ops.left;
    final GeneralLinearOperator<INDArray> precond = ops.right;

    /* initialize the iterative solver */
    final IterativeLinearSolverNDArray iterSolver = new IterativeLinearSolverNDArray(linop, v_tl, precond,
            config.getWAbsoluteTolerance(), config.getWRelativeTolerance(), config.getWMaxIterations(),
            x -> x.normmaxNumber().doubleValue(), /* norm */
            (x, y) -> x.mul(y).sumNumber().doubleValue(), /* inner product */
            true);

    /* solve, warm-starting from the current W; release operator resources even if the solver throws */
    final long startTime = System.nanoTime();
    final SubroutineSignal sig;
    try {
        sig = iterSolver.solveUsingPreconditionedConjugateGradient(W_tl_old);
    } finally {
        linop.cleanupAfter();
        precond.cleanupAfter();
    }
    final long endTime = System.nanoTime();
    logger.debug("CG execution time for solving the regularized M-step update equation for bias covariates: " +
            (double) (endTime - startTime) / 1000000 + " ms");

    /* check the exit status of the solver and push the new W to workers */
    final ExitStatus exitStatus = sig.get(StandardSubroutineSignals.EXIT_STATUS);
    if (exitStatus == ExitStatus.FAIL_MAX_ITERS) {
        /* hitting the iteration cap means the tolerances were too strict for the budget: loosen them, don't tighten */
        logger.warn("CG iterations for M-step update of bias covariates did not converge. Increase maximum iterations" +
                " and/or increase absolute/relative error tolerances");
    }
    final int iters = sig.<Integer>get(StandardSubroutineSignals.ITERATIONS);
    final INDArray W_tl_new = sig.get(StandardSubroutineSignals.SOLUTION);

    switch (config.getBiasCovariatesComputeNodeCommunicationPolicy()) {
        case BROADCAST_HASH_JOIN:
            pushToWorkers(mapINDArrayToBlocks(W_tl_new), (W, cb) -> cb.cloneWithUpdatedPrimitive(
                    CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl, W.get(cb.getTargetSpaceBlock())));
            break;
        case RDD_JOIN:
            joinWithWorkersAndMap(chopINDArrayToBlocks(W_tl_new), p -> p._1.cloneWithUpdatedPrimitive(
                    CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl, p._2));
            break;
        default:
            throw new GATKException.ShouldNeverReachHereException("Unknown communication policy for M-step update" +
                    " of bias covariates");
    }

    /* update F[W] */
    updateFilteredBiasCovariates(W_tl_new);

    /* size of the update, measured in the infinity norm */
    final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(W_tl_new.sub(W_tl_old));

    /* send the signal to workers for consistency */
    final SubroutineSignal newSig = SubroutineSignal.builder()
            .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity)
            .put(StandardSubroutineSignals.ITERATIONS, iters)
            .build();
    mapWorkers(cb -> cb.cloneWithUpdatedSignal(newSig));
    return newSig;
}
Aggregations