Use of org.nd4j.linalg.indexing.INDArrayIndex in the gatk project by broadinstitute.
Example: method updateFilteredBiasCovariates of class CoverageModelEMWorkspace.
/**
 * Applies the Fourier regularizer filter to a given bias covariates matrix, partitions the
 * filtered result, and pushes it to the compute block(s).
 *
 * <p>Each column {@code li} in {@code [0, numLatents)} of {@code biasCovariates} is filtered
 * independently with the operator returned by {@code createRegularizerFourierLinearOperator()}.
 * The filtered columns are written into a freshly allocated array of the same shape as
 * {@code biasCovariates}. That array is then stored in the {@code F_W_tl} cache node of the
 * compute blocks, using the strategy selected by
 * {@code config.getBiasCovariatesComputeNodeCommunicationPolicy()}.
 *
 * @param biasCovariates any T x D bias covariates matrix (presumably D == {@code numLatents};
 *                       only the first {@code numLatents} columns are filtered — TODO confirm)
 * @throws IllegalStateException if the configured bias covariates communication policy is not
 *                               handled by this method
 */
@UpdatesRDD
private void updateFilteredBiasCovariates(@Nonnull final INDArray biasCovariates) {
    final INDArray filteredBiasCovariates = Nd4j.create(biasCovariates.shape());
    /* instantiate the Fourier filter */
    final FourierLinearOperatorNDArray regularizerFourierLinearOperator = createRegularizerFourierLinearOperator();
    /* filter each latent column W_{:, l} of W_tl independently */
    for (int li = 0; li < numLatents; li++) {
        final INDArrayIndex[] slice = { NDArrayIndex.all(), NDArrayIndex.point(li) };
        filteredBiasCovariates.get(slice).assign(regularizerFourierLinearOperator.operate(biasCovariates.get(slice)));
    }
    /* send the filtered W to the workers */
    switch (config.getBiasCovariatesComputeNodeCommunicationPolicy()) {
        case BROADCAST_HASH_JOIN:
            /* broadcast the blocks; each compute block picks the slice for its target-space block */
            pushToWorkers(mapINDArrayToBlocks(filteredBiasCovariates),
                    (W, cb) -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.F_W_tl,
                            W.get(cb.getTargetSpaceBlock())));
            break;
        case RDD_JOIN:
            /* chop into per-block pieces and join them with the compute block RDD */
            joinWithWorkersAndMap(chopINDArrayToBlocks(filteredBiasCovariates),
                    p -> p._1.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.F_W_tl, p._2));
            break;
        default:
            /* fail loudly instead of silently leaving stale F_W_tl on the workers */
            throw new IllegalStateException("Unsupported bias covariates compute node communication policy: "
                    + config.getBiasCovariatesComputeNodeCommunicationPolicy());
    }
}
Aggregations