
Example 1 with JavaPairRDD

use of org.apache.spark.api.java.JavaPairRDD in project incubator-systemml by apache.

the class DataTransform method spDataTransform.

public static void spDataTransform(ParameterizedBuiltinSPInstruction inst, FrameObject[] inputs, MatrixObject[] outputs, ExecutionContext ec) throws Exception {
    SparkExecutionContext sec = (SparkExecutionContext) ec;
    // Parse transform instruction (the first instruction) to obtain relevant fields
    TransformOperands oprnds = new TransformOperands(inst.getParams(), inputs[0]);
    JobConf job = new JobConf();
    FileSystem fs = IOUtilFunctions.getFileSystem(inputs[0].getFileName());
    checkIfOutputOverlapsWithTxMtd(oprnds.txMtdPath, outputs[0].getFileName(), fs);
    // find the first file in alphabetical ordering of partfiles in directory inputPath 
    String smallestFile = CSVReblockMR.findSmallestFile(job, oprnds.inputPath);
    // find column names and construct output header
    String headerLine = readHeaderLine(fs, oprnds.inputCSVProperties, smallestFile);
    HashMap<String, Integer> colNamesToIds = processColumnNames(fs, oprnds.inputCSVProperties, headerLine, smallestFile);
    int numColumns = colNamesToIds.size();
    String outHeader = getOutputHeader(fs, headerLine, oprnds);
    String tmpPath = MRJobConfiguration.constructTempOutputFilename();
    // Construct RDD for input data
    @SuppressWarnings("unchecked") JavaPairRDD<LongWritable, Text> inputData = (JavaPairRDD<LongWritable, Text>) sec.getRDDHandleForFrameObject(inputs[0], InputInfo.CSVInputInfo);
    JavaRDD<Tuple2<LongWritable, Text>> csvLines = JavaPairRDD.toRDD(inputData).toJavaRDD();
    long numRowsTf = 0, numColumnsTf = 0;
    JavaPairRDD<Long, String> tfPairRDD = null;
    if (!oprnds.isApply) {
        // build specification file with column IDs instead of column names
        String specWithIDs = processSpecFile(fs, oprnds.inputPath, smallestFile, colNamesToIds, oprnds.inputCSVProperties, oprnds.spec);
        // enable GC on colNamesToIds
        colNamesToIds = null;
        // Build transformation metadata, including recode maps, bin definitions, etc.
        // Also, generate part offsets file (counters file), which is to be used in csv-reblock (if needed)
        String partOffsetsFile = MRJobConfiguration.constructTempOutputFilename();
        numRowsTf = GenTfMtdSPARK.runSparkJob(sec, csvLines, oprnds.txMtdPath, specWithIDs, partOffsetsFile, oprnds.inputCSVProperties, numColumns, outHeader);
        // store the specFileWithIDs as transformation metadata
        MapReduceTool.writeStringToHDFS(specWithIDs, oprnds.txMtdPath + "/" + "spec.json");
        numColumnsTf = getNumColumnsTf(fs, outHeader, oprnds.inputCSVProperties.getDelim(), oprnds.txMtdPath);
        tfPairRDD = ApplyTfCSVSPARK.runSparkJob(sec, csvLines, oprnds.txMtdPath, specWithIDs, tmpPath, oprnds.inputCSVProperties, numColumns, outHeader);
        MapReduceTool.deleteFileIfExistOnHDFS(new Path(partOffsetsFile), job);
    } else {
        // enable GC on colNamesToIds
        colNamesToIds = null;
        // copy given transform metadata (applyTxPath) to specified location (txMtdPath)
        MapReduceTool.deleteFileIfExistOnHDFS(new Path(oprnds.txMtdPath), job);
        MapReduceTool.copyFileOnHDFS(oprnds.applyTxPath, oprnds.txMtdPath);
        // path to specification file
        String specWithIDs = (oprnds.spec != null) ? oprnds.spec : MapReduceTool.readStringFromHDFSFile(oprnds.txMtdPath + "/" + "spec.json");
        numColumnsTf = getNumColumnsTf(fs, outHeader, oprnds.inputCSVProperties.getDelim(), oprnds.txMtdPath);
        // Apply transformation metadata, and perform actual transformation 
        tfPairRDD = ApplyTfCSVSPARK.runSparkJob(sec, csvLines, oprnds.txMtdPath, specWithIDs, tmpPath, oprnds.inputCSVProperties, numColumns, outHeader);
    }
    // copy auxiliary data (old and new header lines) from temporary location to txMtdPath
    moveFilesFromTmp(fs, tmpPath, oprnds.txMtdPath);
    // convert to csv output format (serialized longwritable/text)
    JavaPairRDD<LongWritable, Text> outtfPairRDD = RDDConverterUtils.stringToSerializableText(tfPairRDD);
    if (outtfPairRDD != null) {
        MatrixObject outMO = outputs[0];
        String outVar = outMO.getVarName();
        outMO.setRDDHandle(new RDDObject(outtfPairRDD, outVar));
        sec.addLineageRDD(outVar, inst.getParams().get("target"));
        //update output statistics (required for correctness)
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(outVar);
        mcOut.setDimension(numRowsTf, numColumnsTf);
        mcOut.setNonZeros(-1);
    }
}
Also used : Path(org.apache.hadoop.fs.Path) MatrixObject(org.apache.sysml.runtime.controlprogram.caching.MatrixObject) Text(org.apache.hadoop.io.Text) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) Tuple2(scala.Tuple2) FileSystem(org.apache.hadoop.fs.FileSystem) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) RDDObject(org.apache.sysml.runtime.instructions.spark.data.RDDObject) SparkExecutionContext(org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext) LongWritable(org.apache.hadoop.io.LongWritable) JobConf(org.apache.hadoop.mapred.JobConf)
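
The core JavaPairRDD mechanics in this example (keying lines by a long index and converting between the pair RDD and a plain RDD of tuples via JavaPairRDD.toRDD(...).toJavaRDD()) can be reproduced without the SystemML machinery. A minimal, self-contained sketch, assuming a local Spark context and plain String lines in place of the Hadoop LongWritable/Text types used above:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

import java.util.Arrays;

public class PairRddRoundTrip {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("PairRddRoundTrip").setMaster("local[*]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            // stand-in for the (LongWritable, Text) CSV lines above: plain Strings keyed by a long index
            JavaRDD<String> lines = sc.parallelize(Arrays.asList("id,name", "1,alice", "2,bob"));
            JavaPairRDD<Long, String> indexed = lines.zipWithIndex()
                    .mapToPair(t -> new Tuple2<>(t._2(), t._1()));
            // pair RDD -> plain RDD of tuples, the same round trip as spDataTransform's csvLines
            JavaRDD<Tuple2<Long, String>> tuples = JavaPairRDD.toRDD(indexed).toJavaRDD();
            tuples.collect().forEach(System.out::println);
        }
    }
}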

Example 2 with JavaPairRDD

use of org.apache.spark.api.java.JavaPairRDD in project incubator-systemml by apache.

the class MLContextConversionUtil method frameObjectToBinaryBlockFrame.

/**
	 * Convert a {@code FrameObject} to a {@code BinaryBlockFrame}.
	 * 
	 * @param frameObject
	 *            the {@code FrameObject}
	 * @param sparkExecutionContext
	 *            the Spark execution context
	 * @return the {@code FrameObject} converted to a {@code BinaryBlockFrame}
	 */
public static BinaryBlockFrame frameObjectToBinaryBlockFrame(FrameObject frameObject, SparkExecutionContext sparkExecutionContext) {
    try {
        @SuppressWarnings("unchecked") JavaPairRDD<Long, FrameBlock> binaryBlock = (JavaPairRDD<Long, FrameBlock>) sparkExecutionContext.getRDDHandleForFrameObject(frameObject, InputInfo.BinaryBlockInputInfo);
        MatrixCharacteristics matrixCharacteristics = frameObject.getMatrixCharacteristics();
        FrameSchema fs = new FrameSchema(Arrays.asList(frameObject.getSchema()));
        FrameMetadata fm = new FrameMetadata(fs, matrixCharacteristics);
        return new BinaryBlockFrame(binaryBlock, fm);
    } catch (DMLRuntimeException e) {
        throw new MLContextException("DMLRuntimeException while converting frame object to BinaryBlockFrame", e);
    }
}
Also used : FrameBlock(org.apache.sysml.runtime.matrix.data.FrameBlock) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)
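
The cast in this example relies on the frame's binary-block layout, in which each FrameBlock is keyed by a Long row-block index. A minimal sketch of producing such a Long-keyed JavaPairRDD with plain Spark; the block size and String rows are illustrative assumptions, not SystemML's actual FrameBlock format:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

import java.util.Arrays;

public class RowBlockKeying {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("RowBlockKeying").setMaster("local[*]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            JavaRDD<String> rows = sc.parallelize(Arrays.asList("r0", "r1", "r2", "r3", "r4"));
            final long blockSize = 2; // illustrative block height, not SystemML's default
            // key each row by its block index, then gather the rows of each block
            JavaPairRDD<Long, Iterable<String>> blocks = rows.zipWithIndex()
                    .mapToPair(t -> new Tuple2<>(t._2() / blockSize, t._1()))
                    .groupByKey();
            blocks.collect().forEach(System.out::println);
        }
    }
}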

Example 3 with JavaPairRDD

use of org.apache.spark.api.java.JavaPairRDD in project gatk-protected by broadinstitute.

the class HaplotypeCallerSpark method createReadShards.

/**
     * Create an RDD of {@link Shard} from an RDD of {@link GATKRead}
     * @param shardBoundariesBroadcast  broadcast of an {@link OverlapDetector} loaded with the intervals that should be used for creating ReadShards
     * @param reads RDD of {@link GATKRead}
     * @return an RDD of reads grouped into potentially overlapping shards
     */
private static JavaRDD<Shard<GATKRead>> createReadShards(final Broadcast<OverlapDetector<ShardBoundary>> shardBoundariesBroadcast, final JavaRDD<GATKRead> reads) {
    final JavaPairRDD<ShardBoundary, GATKRead> paired = reads.flatMapToPair(read -> {
        final Collection<ShardBoundary> overlappingShards = shardBoundariesBroadcast.value().getOverlaps(read);
        return overlappingShards.stream().map(key -> new Tuple2<>(key, read)).iterator();
    });
    final JavaPairRDD<ShardBoundary, Iterable<GATKRead>> shardsWithReads = paired.groupByKey();
    return shardsWithReads.map(shard -> new SparkReadShard(shard._1(), shard._2()));
}
Also used : CommandLineProgramProperties(org.broadinstitute.barclay.argparser.CommandLineProgramProperties) SparkProgramGroup(org.broadinstitute.hellbender.cmdline.programgroups.SparkProgramGroup) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Advanced(org.broadinstitute.barclay.argparser.Advanced) org.broadinstitute.hellbender.cmdline(org.broadinstitute.hellbender.cmdline) ArgumentCollection(org.broadinstitute.barclay.argparser.ArgumentCollection) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SAMFileHeader(htsjdk.samtools.SAMFileHeader) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) Function(java.util.function.Function) ReferenceSequenceFile(htsjdk.samtools.reference.ReferenceSequenceFile) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) ReferenceBases(org.broadinstitute.hellbender.utils.reference.ReferenceBases) HaplotypeCallerArgumentCollection(org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerArgumentCollection) SparkReadShard(org.broadinstitute.hellbender.engine.spark.SparkReadShard) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) org.broadinstitute.barclay.argparser(org.broadinstitute.barclay.argparser) Broadcast(org.apache.spark.broadcast.Broadcast) OverlapDetector(htsjdk.samtools.util.OverlapDetector) Iterator(java.util.Iterator) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Collection(java.util.Collection) GATKSparkTool(org.broadinstitute.hellbender.engine.spark.GATKSparkTool) IOException(java.io.IOException) ReadFilter(org.broadinstitute.hellbender.engine.filters.ReadFilter) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) HaplotypeCaller(org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCaller) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) HaplotypeCallerEngine(org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerEngine) Serializable(java.io.Serializable) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) VariantContextWriter(htsjdk.variant.variantcontext.writer.VariantContextWriter) VariantContext(htsjdk.variant.variantcontext.VariantContext) Utils(org.broadinstitute.hellbender.utils.Utils) ReferenceSequence(htsjdk.samtools.reference.ReferenceSequence) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Collections(java.util.Collections) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SparkReadShard(org.broadinstitute.hellbender.engine.spark.SparkReadShard) Tuple2(scala.Tuple2)
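
The shard-assignment pattern here (flatMapToPair emits one (shard, read) pair per overlapping shard, then groupByKey gathers the reads of each shard) is independent of the GATK types. A minimal sketch of the same pattern, assuming integer "reads" of fixed length 5 and a bucket width of 10 in place of ShardBoundary and GATKRead:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ShardByKey {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("ShardByKey").setMaster("local[*]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            JavaRDD<Integer> reads = sc.parallelize(Arrays.asList(3, 10, 11, 25));
            final int width = 10; // illustrative shard width
            // a "read" spanning [x, x+5) may overlap two adjacent buckets; emit one pair per bucket
            JavaPairRDD<Integer, Iterable<Integer>> shards = reads
                    .flatMapToPair(x -> {
                        List<Tuple2<Integer, Integer>> out = new ArrayList<>();
                        out.add(new Tuple2<>(x / width, x));
                        if ((x + 5) / width != x / width) {
                            out.add(new Tuple2<>((x + 5) / width, x));
                        }
                        return out.iterator();
                    })
                    .groupByKey();
            shards.collect().forEach(System.out::println);
        }
    }
}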

Example 4 with JavaPairRDD

use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method updateCopyRatioPosteriorExpectationsSpark.

/**
     * The Spark implementation of the E-step update of copy ratio posteriors
     *
     * @return a {@link SubroutineSignal} containing the update size
     */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateCopyRatioPosteriorExpectationsSpark(final double admixingRatio) {
    /* local final member variables for lambda capture */
    final List<LinearlySpacedIndexBlock> targetBlocks = new ArrayList<>();
    targetBlocks.addAll(this.targetBlocks);
    final List<Target> targetList = new ArrayList<>();
    targetList.addAll(processedTargetList);
    final List<String> sampleNameList = new ArrayList<>();
    sampleNameList.addAll(processedSampleNameList);
    final List<SexGenotypeData> sampleSexGenotypeData = new ArrayList<>();
    sampleSexGenotypeData.addAll(processedSampleSexGenotypeData);
    final int numTargetBlocks = targetBlocks.size();
    final CopyRatioExpectationsCalculator<CoverageModelCopyRatioEmissionData, STATE> calculator = this.copyRatioExpectationsCalculator;
    final INDArray sampleReadDepths = Transforms.exp(sampleMeanLogReadDepths, true);
    /* make an RDD of copy ratio posterior expectations */
    final JavaPairRDD<Integer, CopyRatioExpectations> copyRatioPosteriorExpectationsPairRDD =
            /* fetch copy ratio emission data from workers */
            fetchCopyRatioEmissionDataSpark().mapPartitionsToPair(it -> {
        final List<Tuple2<Integer, CopyRatioExpectations>> newPartitionData = new ArrayList<>();
        while (it.hasNext()) {
            final Tuple2<Integer, List<CoverageModelCopyRatioEmissionData>> prevDatum = it.next();
            final int si = prevDatum._1;
            final CopyRatioCallingMetadata copyRatioCallingMetadata = CopyRatioCallingMetadata.builder()
                    .sampleName(sampleNameList.get(si))
                    .sampleSexGenotypeData(sampleSexGenotypeData.get(si))
                    .sampleCoverageDepth(sampleReadDepths.getDouble(si))
                    .emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN)
                    .build();
            newPartitionData.add(new Tuple2<>(prevDatum._1, calculator.getCopyRatioPosteriorExpectations(copyRatioCallingMetadata, targetList, prevDatum._2)));
        }
        return newPartitionData.iterator();
    }, true);
    /* we need to do two things to copyRatioPosteriorExpectationsPairRDD; so we cache it */
    copyRatioPosteriorExpectationsPairRDD.cache(); /* assumed: cache() matching the comment above and the unpersist() below */
    /* step 1. update log chain posterior expectation on the driver node */
    final double[] newSampleLogChainPosteriors = copyRatioPosteriorExpectationsPairRDD
            .mapValues(CopyRatioExpectations::getLogChainPosteriorProbability)
            .collect().stream()
            .sorted(Comparator.comparingInt(t -> t._1))
            .mapToDouble(t -> t._2)
            .toArray();
    sampleLogChainPosteriors.assign(Nd4j.create(newSampleLogChainPosteriors, new int[] { numSamples, 1 }));
    /* step 2. repartition in target space */
    final JavaPairRDD<LinearlySpacedIndexBlock, ImmutablePair<INDArray, INDArray>> blockifiedCopyRatioPosteriorResultsPairRDD =
            copyRatioPosteriorExpectationsPairRDD
                    .flatMapToPair(dat -> targetBlocks.stream()
                            .map(tb -> new Tuple2<>(tb, new Tuple2<>(dat._1,
                                    ImmutablePair.of(dat._2.getLogCopyRatioMeans(tb), dat._2.getLogCopyRatioVariances(tb)))))
                            .iterator())
                    .combineByKey(
                            /* recipe to create a singleton list */
                            Collections::singletonList,
                            /* recipe to add an element to the list */
                            (list, element) -> Stream.concat(list.stream(), Stream.of(element)).collect(Collectors.toList()),
                            /* recipe to concatenate two lists */
                            (list1, list2) -> Stream.concat(list1.stream(), list2.stream()).collect(Collectors.toList()),
                            /* repartition with respect to target space blocks */
                            new HashPartitioner(numTargetBlocks))
                    .mapValues(list -> list.stream()
                            .sorted(Comparator.comparingInt(t -> t._1))
                            .map(p -> p._2)
                            .map(t -> ImmutablePair.of(Nd4j.create(t.left), Nd4j.create(t.right)))
                            .collect(Collectors.toList()))
                    .mapValues(CoverageModelEMWorkspace::stackCopyRatioPosteriorDataForAllSamples);
    /* we do not need copy ratio expectations anymore */
    copyRatioPosteriorExpectationsPairRDD.unpersist();
    /* step 3. merge with computeRDD and update */
    computeRDD = computeRDD.join(blockifiedCopyRatioPosteriorResultsPairRDD).mapValues(t -> t._1.cloneWithUpdatedCopyRatioPosteriors(t._2.left, t._2.right, admixingRatio));
    cacheWorkers("after E-step for copy ratio update");
    /* collect subroutine signals */
    final List<SubroutineSignal> sigs = mapWorkersAndCollect(CoverageModelEMComputeBlock::getLatestMStepSignal);
    final double errorNormInfinity = Collections.max(sigs.stream().map(sig -> sig.<Double>get(StandardSubroutineSignals.RESIDUAL_ERROR_NORM)).collect(Collectors.toList()));
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Tuple2(scala.Tuple2) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) HashPartitioner(org.apache.spark.HashPartitioner)
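
Step 2 above is a textbook four-argument combineByKey: a createCombiner that wraps a value in a singleton list, a mergeValue that appends an element, a mergeCombiners that concatenates two lists, and an explicit HashPartitioner that controls the resulting partitioning. A minimal sketch of the same call with plain String/Integer pairs; the key type and partition count are illustrative:

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class CombineByKeyDemo {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("CombineByKeyDemo").setMaster("local[*]");
        try (JavaSParkContextPlaceholder ignored = null; JavaSparkContext sc = new JavaSparkContext(conf)) {
        }
    }
}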

Example 5 with JavaPairRDD

use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.

the class CoverageModelWPreconditionerSpark method operate.

@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents) {
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);
    }
    long startTimeRFFT = System.nanoTime();
    /* forward rfft */
    final INDArray W_kl = Nd4j.create(fftSize, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li -> W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li)).assign(Nd4j.create(F_tt.getForwardFFT(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li))), new int[] { fftSize, 1 })));
    long endTimeRFFT = System.nanoTime();
    /* apply the preconditioner in the Fourier space */
    long startTimePrecond = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_kl_map = CoverageModelSparkUtils.partitionINDArrayToMap(fourierSpaceBlocks, W_kl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_kl_bc = ctx.broadcast(W_kl_map);
    final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> preconditionedWRDD = linOpPairRDD.mapToPair(p -> {
        final INDArray W_kl_chunk = W_kl_bc.value().get(p._1);
        final INDArray linOp_chunk = p._2;
        final int blockSize = linOp_chunk.shape()[0];
        final List<INDArray> linOpWList = IntStream.range(0, blockSize).parallel().mapToObj(k -> CoverageModelEMWorkspaceMathUtils.linsolve(linOp_chunk.get(NDArrayIndex.point(k)), W_kl_chunk.get(NDArrayIndex.point(k)))).collect(Collectors.toList());
        return new Tuple2<>(p._1, Nd4j.vstack(linOpWList));
    });
    W_kl.assign(CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(preconditionedWRDD, 0));
    W_kl_bc.destroy();
    //        final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> W_kl_RDD = CoverageModelSparkUtils.rddFromINDArray(W_kl,
    //                fourierSpaceBlocks, ctx, true);
    //        W_kl.assign(CoverageModelSparkUtils.assembleINDArrayBlocks(linOpPairRDD.join((W_kl_RDD))
    //                .mapValues(p -> {
    //                    final INDArray linOp = p._1;
    //                    final INDArray W = p._2;
    //                    final int blockSize = linOp.shape()[0];
    //                    final List<INDArray> linOpWList = IntStream.range(0, blockSize).parallel().mapToObj(k ->
    //                            CoverageModelEMWorkspaceMathUtils.linsolve(linOp.get(NDArrayIndex.point(k)),
    //                                    W.get(NDArrayIndex.point(k))))
    //                            .collect(Collectors.toList());
    //                    return Nd4j.vstack(linOpWList);
    //                }), false));
    //        W_kl_RDD.unpersist();
    long endTimePrecond = System.nanoTime();
    /* irfft */
    long startTimeIRFFT = System.nanoTime();
    final INDArray res = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li -> res.get(NDArrayIndex.all(), NDArrayIndex.point(li)).assign(F_tt.getInverseFFT(W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    long endTimeIRFFT = System.nanoTime();
    logger.debug("Local FFT timing: " + (endTimeRFFT - startTimeRFFT + endTimeIRFFT - startTimeIRFFT) / 1000000 + " ms");
    logger.debug("Spark preconditioner application timing: " + (endTimePrecond - startTimePrecond) / 1000000 + " ms");
    return res;
}
Also used : Broadcast(org.apache.spark.broadcast.Broadcast) IntStream(java.util.stream.IntStream) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) DimensionMismatchException(org.apache.commons.math3.exception.DimensionMismatchException) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) Collectors(java.util.stream.Collectors) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) List(java.util.List) Logger(org.apache.logging.log4j.Logger) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Map(java.util.Map) LogManager(org.apache.logging.log4j.LogManager) Nonnull(javax.annotation.Nonnull) DimensionMismatchException(org.apache.commons.math3.exception.DimensionMismatchException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) Map(java.util.Map)
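
The preconditioner avoids a shuffle by broadcasting the small W_kl_map and doing a map-side lookup inside mapToPair, rather than the join-based variant left commented out above. A minimal sketch of that broadcast-plus-mapToPair pattern, assuming a small Map<Integer, Double> lookup table in place of the INDArray blocks:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class BroadcastLookup {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("BroadcastLookup").setMaster("local[*]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            JavaPairRDD<Integer, Double> blocks = sc.parallelizePairs(Arrays.asList(
                    new Tuple2<>(0, 1.0), new Tuple2<>(1, 2.0)));
            // broadcast a small per-key table instead of shuffling a join, as with W_kl_bc above
            Map<Integer, Double> factors = new HashMap<>();
            factors.put(0, 10.0);
            factors.put(1, 20.0);
            final Broadcast<Map<Integer, Double>> factorsBc = sc.broadcast(factors);
            JavaPairRDD<Integer, Double> scaled = blocks.mapToPair(p ->
                    new Tuple2<>(p._1, p._2 * factorsBc.value().get(p._1)));
            scaled.collect().forEach(System.out::println);
            factorsBc.destroy();
        }
    }
}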

Aggregations

JavaPairRDD (org.apache.spark.api.java.JavaPairRDD): 99
MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock): 44
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 42
MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes): 42
MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics): 41
Tuple2 (scala.Tuple2): 35
DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 33
JavaRDD (org.apache.spark.api.java.JavaRDD): 28
List (java.util.List): 27
SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext): 24
FrameBlock (org.apache.sysml.runtime.matrix.data.FrameBlock): 23
Collectors (java.util.stream.Collectors): 22
IOException (java.io.IOException): 17
RDDObject (org.apache.sysml.runtime.instructions.spark.data.RDDObject): 16
LongWritable (org.apache.hadoop.io.LongWritable): 15
Broadcast (org.apache.spark.broadcast.Broadcast): 15
Text (org.apache.hadoop.io.Text): 12
UserException (org.broadinstitute.hellbender.exceptions.UserException): 12
Function (org.apache.spark.api.java.function.Function): 11
MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject): 11