Example 26 with JavaRDD

Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.

The class CNLOHCaller, method calcNewRhos.

private double[] calcNewRhos(final List<ACNVModeledSegment> segments, final List<double[][][]> responsibilitiesBySeg, final double lambda, final double[] rhos, final int[] mVals, final int[] nVals, final JavaSparkContext ctx) {
    // Since we pass in the entire responsibilities matrix, we need the correct index for each rho. That, and the
    //  fact that this is a univariate objective function, means we create one objective instance per rho and then
    //  distribute the optimizations across Spark.
    final List<Pair<? extends Function<Double, Double>, SearchInterval>> objectives = IntStream.range(0, rhos.length).mapToObj(i -> new Pair<>(new Function<Double, Double>() {

        @Override
        public Double apply(Double rho) {
            return calculateESmnObjective(rho, segments, responsibilitiesBySeg, mVals, nVals, lambda, i);
        }
    }, new SearchInterval(0.0, 1.0, rhos[i]))).collect(Collectors.toList());
    final JavaRDD<Pair<? extends Function<Double, Double>, SearchInterval>> objectivesRDD = ctx.parallelize(objectives);
    final List<Double> resultsAsDouble = objectivesRDD.map(objective -> optimizeIt(objective.getFirst(), objective.getSecond())).collect();
    return resultsAsDouble.stream().mapToDouble(Double::doubleValue).toArray();
}
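The optimizeIt helper is not shown in this snippet. Given the imports below (BrentOptimizer, MaxEval, GoalType, UnivariateObjectiveFunction), a plausible sketch is a Brent line search over the supplied interval; the tolerances, evaluation budget, and MAXIMIZE goal here are assumptions, not the actual GATK values:

private static double optimizeIt(final Function<Double, Double> objective, final SearchInterval interval) {
    // Brent's method needs relative and absolute convergence tolerances.
    final BrentOptimizer optimizer = new BrentOptimizer(1e-10, 1e-8);
    return optimizer.optimize(
            new MaxEval(100),                                         // evaluation budget (assumed)
            new UnivariateObjectiveFunction(x -> objective.apply(x)), // adapt Function<Double, Double> to UnivariateFunction
            GoalType.MAXIMIZE,                                        // assumed: the E-step objective is maximized
            interval                                                  // also supplies the starting point rhos[i]
    ).getPoint();
}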
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SearchInterval(org.apache.commons.math3.optim.univariate.SearchInterval) RealVector(org.apache.commons.math3.linear.RealVector) Function(java.util.function.Function) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) GammaDistribution(org.apache.commons.math3.distribution.GammaDistribution) Gamma(org.apache.commons.math3.special.Gamma) ACNVModeledSegment(org.broadinstitute.hellbender.tools.exome.ACNVModeledSegment) BaseAbstractUnivariateIntegrator(org.apache.commons.math3.analysis.integration.BaseAbstractUnivariateIntegrator) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) MatrixUtils(org.apache.commons.math3.linear.MatrixUtils) UnivariateObjectiveFunction(org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) SimpsonIntegrator(org.apache.commons.math3.analysis.integration.SimpsonIntegrator) JavaRDD(org.apache.spark.api.java.JavaRDD) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Pair(org.apache.commons.math3.util.Pair) Collectors(java.util.stream.Collectors) BrentOptimizer(org.apache.commons.math3.optim.univariate.BrentOptimizer) Serializable(java.io.Serializable) List(java.util.List) Percentile(org.apache.commons.math3.stat.descriptive.rank.Percentile) Logger(org.apache.logging.log4j.Logger) MathUtils(org.broadinstitute.hellbender.utils.MathUtils) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance) Utils(org.broadinstitute.hellbender.utils.Utils) RealMatrix(org.apache.commons.math3.linear.RealMatrix) VisibleForTesting(com.google.common.annotations.VisibleForTesting) HomoSapiensConstants(org.broadinstitute.hellbender.utils.variant.HomoSapiensConstants) MaxEval(org.apache.commons.math3.optim.MaxEval) LogManager(org.apache.logging.log4j.LogManager)

Example 27 with JavaRDD

Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.

The class HaplotypeCallerSpark, method writeVariants.

/**
     * Writes the variants to the output. This is currently going to be horribly slow and memory-explosive on a full-size file, since it collects all variants to the driver.
     *
     * This will be replaced by a parallel writer similar to what's done with {@link org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink}.
     */
private void writeVariants(JavaRDD<VariantContext> variants) {
    final List<VariantContext> collectedVariants = variants.collect();
    final SAMSequenceDictionary referenceDictionary = getReferenceSequenceDictionary();
    final List<VariantContext> sortedVariants = collectedVariants.stream().sorted((o1, o2) -> IntervalUtils.compareLocatables(o1, o2, referenceDictionary)).collect(Collectors.toList());
    final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgs, getHeaderForReads(), new ReferenceMultiSourceAdapter(getReference(), getAuthHolder()));
    try (final VariantContextWriter writer = hcEngine.makeVCFWriter(output, getBestAvailableSequenceDictionary())) {
        hcEngine.writeHeader(writer, getHeaderForReads().getSequenceDictionary(), Collections.emptySet());
        sortedVariants.forEach(writer::add);
    }
}
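As the Javadoc warns, collecting to the driver does not scale. A minimal alternative sketch (an assumption, not GATK code) is to sort on the cluster first, packing (contig index, start) into one comparable long key so that sortBy can order variants genomically; contigIndexOf is a hypothetical small, serializable map built from the sequence dictionary:

private static JavaRDD<VariantContext> sortByCoordinate(final JavaRDD<VariantContext> variants, final Map<String, Integer> contigIndexOf) {
    // Pack the contig index into the high 32 bits and the start position into the low 32 bits.
    return variants.sortBy(
            vc -> (((long) contigIndexOf.get(vc.getContig())) << 32) | vc.getStart(),
            true,
            variants.getNumPartitions());
}

The driver would then still collect, but only to feed already-sorted records into the single-writer VCF.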
Also used : CommandLineProgramProperties(org.broadinstitute.barclay.argparser.CommandLineProgramProperties) SparkProgramGroup(org.broadinstitute.hellbender.cmdline.programgroups.SparkProgramGroup) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Advanced(org.broadinstitute.barclay.argparser.Advanced) org.broadinstitute.hellbender.cmdline(org.broadinstitute.hellbender.cmdline) ArgumentCollection(org.broadinstitute.barclay.argparser.ArgumentCollection) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SAMFileHeader(htsjdk.samtools.SAMFileHeader) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) Function(java.util.function.Function) ReferenceSequenceFile(htsjdk.samtools.reference.ReferenceSequenceFile) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) ReferenceBases(org.broadinstitute.hellbender.utils.reference.ReferenceBases) HaplotypeCallerArgumentCollection(org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerArgumentCollection) SparkReadShard(org.broadinstitute.hellbender.engine.spark.SparkReadShard) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) org.broadinstitute.barclay.argparser(org.broadinstitute.barclay.argparser) Broadcast(org.apache.spark.broadcast.Broadcast) OverlapDetector(htsjdk.samtools.util.OverlapDetector) Iterator(java.util.Iterator) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Collection(java.util.Collection) GATKSparkTool(org.broadinstitute.hellbender.engine.spark.GATKSparkTool) IOException(java.io.IOException) ReadFilter(org.broadinstitute.hellbender.engine.filters.ReadFilter) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) HaplotypeCaller(org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCaller) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) HaplotypeCallerEngine(org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerEngine) Serializable(java.io.Serializable) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) VariantContextWriter(htsjdk.variant.variantcontext.writer.VariantContextWriter) VariantContext(htsjdk.variant.variantcontext.VariantContext) Utils(org.broadinstitute.hellbender.utils.Utils) ReferenceSequence(htsjdk.samtools.reference.ReferenceSequence) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Collections(java.util.Collections)

Example 28 with JavaRDD

Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.

The class ReadWalkerSpark, method getReads.

/**
     * Loads reads and the corresponding reference and features into a {@link JavaRDD} for the intervals specified.
     *
     * If no intervals were specified, returns all the reads.
     *
     * @return all reads as a {@link JavaRDD}, bounded by intervals if specified.
     */
public JavaRDD<ReadWalkerContext> getReads(JavaSparkContext ctx) {
    SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
    List<SimpleInterval> intervals = hasIntervals() ? getIntervals() : IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
    // use unpadded shards (padding is only needed for reference bases)
    final List<ShardBoundary> intervalShards = intervals.stream().flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, 0, sequenceDictionary).stream()).collect(Collectors.toList());
    JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, readShardSize, shuffle);
    Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
    Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
    return shardedReads.flatMap(getReadsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, readShardPadding));
}
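To make the sharding step concrete (a sketch; the interval and shard size are example values, not defaults): a 1-1,000,000 interval with readShardSize = 100,000 and zero padding divides into ten contiguous, unpadded shard boundaries.

final SimpleInterval chr1 = new SimpleInterval("1", 1, 1_000_000);
// Zero padding mirrors the call above; padding is applied later, when reference bases are fetched.
final List<ShardBoundary> shards = Shard.divideIntervalIntoShards(chr1, 100_000, 0, sequenceDictionary);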
Also used : Broadcast(org.apache.spark.broadcast.Broadcast) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction)

Example 29 with JavaRDD

Use of org.apache.spark.api.java.JavaRDD in project incubator-systemml by apache.

The class WriteSPInstruction, method processMatrixWriteInstruction.

protected void processMatrixWriteInstruction(SparkExecutionContext sec, String fname, OutputInfo oi) throws IOException {
    // get input rdd
    JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
    MatrixCharacteristics mc = sec.getMatrixCharacteristics(input1.getName());
    if (oi == OutputInfo.MatrixMarketOutputInfo || oi == OutputInfo.TextCellOutputInfo) {
        // piggyback nnz maintenance on write
        LongAccumulator aNnz = null;
        if (!mc.nnzKnown()) {
            aNnz = sec.getSparkContext().sc().longAccumulator("nnz");
            in1 = in1.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
        }
        JavaRDD<String> header = null;
        if (oi == OutputInfo.MatrixMarketOutputInfo) {
            ArrayList<String> headerContainer = new ArrayList<>(1);
            // First output the MatrixMarket header: format line, then "rows cols nnz"
            String headerStr = "%%MatrixMarket matrix coordinate real general\n" + mc.getRows() + " " + mc.getCols() + " " + mc.getNonZeros();
            headerContainer.add(headerStr);
            header = sec.getSparkContext().parallelize(headerContainer);
        }
        JavaRDD<String> ijv = RDDConverterUtils.binaryBlockToTextCell(in1, mc);
        if (header != null)
            customSaveTextFile(header.union(ijv), fname, true);
        else
            customSaveTextFile(ijv, fname, false);
        if (!mc.nnzKnown())
            mc.setNonZeros(aNnz.value());
    } else if (oi == OutputInfo.CSVOutputInfo) {
        if (mc.getRows() == 0 || mc.getCols() == 0) {
            throw new IOException("Write of matrices with zero rows or columns" + " not supported (" + mc.getRows() + "x" + mc.getCols() + ").");
        }
        LongAccumulator aNnz = null;
        // piggyback nnz computation on actual write
        if (!mc.nnzKnown()) {
            aNnz = sec.getSparkContext().sc().longAccumulator("nnz");
            in1 = in1.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
        }
        JavaRDD<String> out = RDDConverterUtils.binaryBlockToCsv(in1, mc, (CSVFileFormatProperties) formatProperties, true);
        customSaveTextFile(out, fname, false);
        if (!mc.nnzKnown())
            mc.setNonZeros(aNnz.value());
    } else if (oi == OutputInfo.BinaryBlockOutputInfo) {
        // piggyback nnz computation on actual write
        LongAccumulator aNnz = null;
        if (!mc.nnzKnown()) {
            aNnz = sec.getSparkContext().sc().longAccumulator("nnz");
            in1 = in1.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
        }
        // save binary block rdd on hdfs
        in1.saveAsHadoopFile(fname, MatrixIndexes.class, MatrixBlock.class, SequenceFileOutputFormat.class);
        if (!mc.nnzKnown())
            mc.setNonZeros(aNnz.value());
    } else {
        // unsupported formats: binarycell (not externalized)
        throw new DMLRuntimeException("Unexpected data format: " + OutputInfo.outputInfoToString(oi));
    }
    // write meta data file
    MapReduceTool.writeMetaDataFile(fname + ".mtd", ValueType.DOUBLE, mc, oi, formatProperties);
}
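The "piggyback nnz on write" pattern appears three times above and is worth isolating. A minimal sketch (jsc, data, and outputPath are illustrative names, not SystemML API): increment an accumulator inside a lazy map, run the write action, and only then read the total.

final LongAccumulator aNnz = jsc.sc().longAccumulator("nnz");
final JavaRDD<double[]> counted = data.map(block -> {
    // Count non-zeros while the data streams past on its way to the writer.
    aNnz.add(java.util.Arrays.stream(block).filter(v -> v != 0).count());
    return block;
});
counted.saveAsTextFile(outputPath); // the action executes the map and fills aNnz
final long nnz = aNnz.value();      // valid only after the action has run

Because the counting map is fused into the write job, maintaining the non-zero count costs no extra pass over the data.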
Also used : MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) CSVFileFormatProperties(org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) ArrayList(java.util.ArrayList) IOException(java.io.IOException) ComputeBinaryBlockNnzFunction(org.apache.sysml.runtime.instructions.spark.functions.ComputeBinaryBlockNnzFunction) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) JavaRDD(org.apache.spark.api.java.JavaRDD) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException) LongAccumulator(org.apache.spark.util.LongAccumulator)

Example 30 with JavaRDD

Use of org.apache.spark.api.java.JavaRDD in project incubator-systemml by apache.

The class MLContextTest, method testOutputJavaRDDStringIJVFromMatrixDML.

@Test
public void testOutputJavaRDDStringIJVFromMatrixDML() {
    System.out.println("MLContextTest - output Java RDD String IJV from matrix DML");
    String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
    Script script = dml(s).out("M");
    MLResults results = ml.execute(script);
    JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("M");
    List<String> lines = javaRDDStringIJV.sortBy(row -> row, true, 1).collect();
    Assert.assertEquals("1 1 1.0", lines.get(0));
    Assert.assertEquals("1 2 2.0", lines.get(1));
    Assert.assertEquals("2 1 3.0", lines.get(2));
    Assert.assertEquals("2 2 4.0", lines.get(3));
}
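The ml field comes from the surrounding test class, which is not shown here. A hedged setup sketch (one plausible way to construct it; the local master setting is an assumption for testing):

SparkSession spark = SparkSession.builder()
        .appName("MLContextTest")
        .master("local[*]")
        .getOrCreate();
MLContext ml = new MLContext(spark);

The script registers M as an output via out("M"); getJavaRDDStringIJV then renders the matrix as one "row col value" line per cell, and the sortBy in the test makes the collected ordering deterministic.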
Also used : MatrixFormat(org.apache.sysml.api.mlcontext.MatrixFormat) Arrays(java.util.Arrays) URL(java.net.URL) MatrixObject(org.apache.sysml.runtime.controlprogram.caching.MatrixObject) MatrixIndexes(org.apache.sysml.runtime.matrix.data.MatrixIndexes) MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) Map(java.util.Map) DoubleType(org.apache.spark.sql.types.DoubleType) ScriptFactory.dml(org.apache.sysml.api.mlcontext.ScriptFactory.dml) ScriptFactory.pydmlFromUrl(org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromUrl) DataTypes(org.apache.spark.sql.types.DataTypes) StructField(org.apache.spark.sql.types.StructField) StructType(org.apache.spark.sql.types.StructType) Vector(org.apache.spark.ml.linalg.Vector) DenseVector(org.apache.spark.ml.linalg.DenseVector) MLContextConversionUtil(org.apache.sysml.api.mlcontext.MLContextConversionUtil) Seq(scala.collection.Seq) Tuple4(scala.Tuple4) ScriptFactory.dmlFromLocalFile(org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromLocalFile) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Tuple3(scala.Tuple3) Tuple1(scala.Tuple1) List(java.util.List) Stream(java.util.stream.Stream) ScriptFactory.pydmlFromFile(org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromFile) MLResults(org.apache.sysml.api.mlcontext.MLResults) Vectors(org.apache.spark.ml.linalg.Vectors) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) Function(org.apache.spark.api.java.function.Function) ScriptFactory.dmlFromFile(org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile) RDD(org.apache.spark.rdd.RDD) ScriptFactory.dmlFromInputStream(org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromInputStream) Dataset(org.apache.spark.sql.Dataset) MLContextException(org.apache.sysml.api.mlcontext.MLContextException) MatrixMetadata(org.apache.sysml.api.mlcontext.MatrixMetadata) RDDConverterUtils(org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils) MLContextUtil(org.apache.sysml.api.mlcontext.MLContextUtil) HashMap(java.util.HashMap) DataConverter(org.apache.sysml.runtime.util.DataConverter) ArrayList(java.util.ArrayList) Script(org.apache.sysml.api.mlcontext.Script) Statistics(org.apache.sysml.utils.Statistics) VectorUDT(org.apache.spark.ml.linalg.VectorUDT) ScriptFactory.pydmlFromLocalFile(org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromLocalFile) JavaRDD(org.apache.spark.api.java.JavaRDD) JavaConversions(scala.collection.JavaConversions) MalformedURLException(java.net.MalformedURLException) RowFactory(org.apache.spark.sql.RowFactory) Iterator(scala.collection.Iterator) Matrix(org.apache.sysml.api.mlcontext.Matrix) Assert.assertTrue(org.junit.Assert.assertTrue) IOException(java.io.IOException) Test(org.junit.Test) FileInputStream(java.io.FileInputStream) Row(org.apache.spark.sql.Row) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) File(java.io.File) ScriptFactory.pydml(org.apache.sysml.api.mlcontext.ScriptFactory.pydml) ScriptFactory.pydmlFromInputStream(org.apache.sysml.api.mlcontext.ScriptFactory.pydmlFromInputStream) ScriptFactory.dmlFromUrl(org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromUrl) ScriptExecutor(org.apache.sysml.api.mlcontext.ScriptExecutor) Assert(org.junit.Assert) InputStream(java.io.InputStream)

Aggregations

Types co-occurring with JavaRDD across all 63 examples:

JavaRDD (org.apache.spark.api.java.JavaRDD): 63
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 33
List (java.util.List): 24
GATKRead (org.broadinstitute.hellbender.utils.read.GATKRead): 24
Collectors (java.util.stream.Collectors): 20
JavaPairRDD (org.apache.spark.api.java.JavaPairRDD): 20
Tuple2 (scala.Tuple2): 20
Argument (org.broadinstitute.barclay.argparser.Argument): 17
Broadcast (org.apache.spark.broadcast.Broadcast): 15
SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval): 15
SAMFileHeader (htsjdk.samtools.SAMFileHeader): 14
SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary): 14
IOException (java.io.IOException): 14
UserException (org.broadinstitute.hellbender.exceptions.UserException): 14
CommandLineProgramProperties (org.broadinstitute.barclay.argparser.CommandLineProgramProperties): 13
GATKSparkTool (org.broadinstitute.hellbender.engine.spark.GATKSparkTool): 13
Serializable (java.io.Serializable): 12
IntervalUtils (org.broadinstitute.hellbender.utils.IntervalUtils): 12
java.util (java.util): 11
ArrayList (java.util.ArrayList): 11