Search in sources:

Example 61 with JavaPairRDD

Use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.

In class AddContextDataToReadSpark, method addUsingOverlapsPartitioning.

/**
 * Pairs every read with its {@link ReadContextData} (padded reference bases plus the variants that overlap
 * the read). Reads are grouped with overlaps partitioning via {@code SparkSharder}, so no shuffle is needed.
 *
 * @param ctx the Spark context, used to broadcast the reference source and the variant index
 * @param mappedReads reads in coordinate-sorted order
 * @param referenceSource source from which reference bases are fetched
 * @param variants variants in coordinate-sorted order; they are collected to the driver and broadcast
 * @param sequenceDictionary sequence dictionary describing the contigs of the reads
 * @param shardSize upper bound on the length of each shard, in bases
 * @param shardPadding number of extra reference bases fetched around each shard, in bases
 * @return an RDD of (read, context) pairs, in coordinate-sorted order
 */
private static JavaPairRDD<GATKRead, ReadContextData> addUsingOverlapsPartitioning(final JavaSparkContext ctx, final JavaRDD<GATKRead> mappedReads, final ReferenceMultiSource referenceSource, final JavaRDD<GATKVariant> variants, final SAMSequenceDictionary sequenceDictionary, final int shardSize, final int shardPadding) {
    // Shards are deliberately unpadded: padding is only applied when fetching reference bases below.
    final List<ShardBoundary> unpaddedShards = IntervalUtils.getAllIntervalsForReference(sequenceDictionary)
            .stream()
            .flatMap(contigInterval -> Shard.divideIntervalIntoShards(contigInterval, shardSize, 0, sequenceDictionary).stream())
            .collect(Collectors.toList());

    final Broadcast<ReferenceMultiSource> referenceBroadcast = ctx.broadcast(referenceSource);
    final IntervalsSkipList<GATKVariant> allVariantsIndex = new IntervalsSkipList<>(variants.collect());
    final Broadcast<IntervalsSkipList<GATKVariant>> variantIndexBroadcast = ctx.broadcast(allVariantsIndex);

    final int maxReadSpan = Math.min(shardSize, shardPadding);
    final JavaRDD<Shard<GATKRead>> readShards = SparkSharder.shard(ctx, mappedReads, GATKRead.class, sequenceDictionary, unpaddedShards, maxReadSpan);

    return readShards.flatMapToPair(shard -> {
        // Fetch reference bases for the shard expanded by shardPadding (kept within the contig).
        final SimpleInterval paddedShardInterval = shard.getInterval().expandWithinContig(shardPadding, sequenceDictionary);
        final ReferenceBases shardReferenceBases = referenceBroadcast.getValue().getReferenceBases(null, paddedShardInterval);
        final IntervalsSkipList<GATKVariant> variantIndex = variantIndexBroadcast.getValue();

        final Iterator<Tuple2<GATKRead, ReadContextData>> readsWithContext = Iterators.transform(shard.iterator(), read -> {
            // A read that consumes no reference bases (e.g. CIGAR 61S90I) does not form a valid interval;
            // such reads are treated as overlapping no variants.
            final List<GATKVariant> overlapping = SimpleInterval.isValid(read.getContig(), read.getStart(), read.getEnd())
                    ? variantIndex.getOverlapping(new SimpleInterval(read))
                    : Collections.emptyList();
            return new Tuple2<>(read, new ReadContextData(shardReferenceBases, overlapping));
        });

        // Emit only reads whose start position falls inside this shard.
        return Iterators.filter(readsWithContext, pair -> pair._1().getStart() >= shard.getStart() && pair._1().getStart() <= shard.getEnd());
    });
}
Also used : PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Iterators(com.google.common.collect.Iterators) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) ReferenceBases(org.broadinstitute.hellbender.utils.reference.ReferenceBases) ReadContextData(org.broadinstitute.hellbender.engine.ReadContextData) JavaRDD(org.apache.spark.api.java.JavaRDD) Nullable(javax.annotation.Nullable) Broadcast(org.apache.spark.broadcast.Broadcast) IntervalsSkipList(org.broadinstitute.hellbender.utils.collections.IntervalsSkipList) Function(com.google.common.base.Function) Iterator(java.util.Iterator) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) GATKVariant(org.broadinstitute.hellbender.utils.variant.GATKVariant) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) Shard(org.broadinstitute.hellbender.engine.Shard) List(java.util.List) UserException(org.broadinstitute.hellbender.exceptions.UserException) ShardBoundary(org.broadinstitute.hellbender.engine.ShardBoundary) Collections(java.util.Collections) ReadFilterLibrary(org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) ShardBoundary(org.broadinstitute.hellbender.engine.ShardBoundary) PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) Function(com.google.common.base.Function) IntervalsSkipList(org.broadinstitute.hellbender.utils.collections.IntervalsSkipList) Iterator(java.util.Iterator) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) GATKVariant(org.broadinstitute.hellbender.utils.variant.GATKVariant) 
ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) UserException(org.broadinstitute.hellbender.exceptions.UserException) ReadContextData(org.broadinstitute.hellbender.engine.ReadContextData) ReferenceBases(org.broadinstitute.hellbender.utils.reference.ReferenceBases) Tuple2(scala.Tuple2) Shard(org.broadinstitute.hellbender.engine.Shard) Nullable(javax.annotation.Nullable)

Example 62 with JavaPairRDD

Use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.

In class CompareDuplicatesSpark, method runTool.

/**
 * Compares duplicate marking between the tool's primary reads input and {@code input2}.
 * <p>
 * Both inputs are passed through {@code filteredReads} and checked for equal read counts. The reads are then
 * keyed by the MarkDuplicates fragment key ({@link ReadsKey#keyForFragment}) and co-grouped. Each key group
 * is classified by {@code getDupes} into a {@link MatchType}, and the per-type counts decide the outcome:
 * <ul>
 *     <li>a {@link UserException} is thrown if the read counts differ, or if any group is
 *     {@code SIZE_UNEQUAL} or {@code READ_MISMATCH};</li>
 *     <li>if {@code printSummary} is set, the count for every {@link MatchType} is printed in enum
 *     declaration order;</li>
 *     <li>if {@code throwOnDiff} is set, a {@link UserException} is thrown for the first
 *     non-{@code EQUAL} type (in declaration order) that has a count.</li>
 * </ul>
 *
 * @param ctx the Spark context
 */
@Override
protected void runTool(final JavaSparkContext ctx) {
    JavaRDD<GATKRead> firstReads = filteredReads(getReads(), readArguments.getReadFilesNames().get(0));
    ReadsSparkSource readsSource2 = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency());
    JavaRDD<GATKRead> secondReads = filteredReads(readsSource2.getParallelReads(input2, null, getIntervals(), bamPartitionSplitSize), input2);
    // Start by verifying that we have same number of reads and duplicates in each BAM.
    long firstBamSize = firstReads.count();
    long secondBamSize = secondReads.count();
    if (firstBamSize != secondBamSize) {
        throw new UserException("input bams have different numbers of mapped reads: " + firstBamSize + "," + secondBamSize);
    }
    System.out.println("processing bams with " + firstBamSize + " mapped reads");
    long firstDupesCount = firstReads.filter(GATKRead::isDuplicate).count();
    long secondDupesCount = secondReads.filter(GATKRead::isDuplicate).count();
    if (firstDupesCount != secondDupesCount) {
        // A differing duplicate total is only reported here; the per-key comparison below decides failure.
        System.out.println("BAMs have different number of total duplicates: " + firstDupesCount + "," + secondDupesCount);
    }
    System.out.println("first and second: " + firstDupesCount + "," + secondDupesCount);
    Broadcast<SAMFileHeader> bHeader = ctx.broadcast(getHeaderForReads());
    // Group the reads of each BAM by MarkDuplicates key, then pair up the reads for each BAM.
    JavaPairRDD<String, GATKRead> firstKeyed = firstReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForFragment(bHeader.getValue(), read), read));
    JavaPairRDD<String, GATKRead> secondKeyed = secondReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForFragment(bHeader.getValue(), read), read));
    JavaPairRDD<String, Tuple2<Iterable<GATKRead>, Iterable<GATKRead>>> cogroup = firstKeyed.cogroup(secondKeyed, getRecommendedNumReducers());
    // Produces an RDD of MatchTypes, e.g., EQUAL, DIFFERENT_REPRESENTATIVE_READ, etc. per MarkDuplicates key,
    // which is approximately start position x strand.
    JavaRDD<MatchType> tagged = cogroup.map(v1 -> {
        SAMFileHeader header = bHeader.getValue();
        Iterable<GATKRead> iFirstReads = v1._2()._1();
        Iterable<GATKRead> iSecondReads = v1._2()._2();
        return getDupes(iFirstReads, iSecondReads, header);
    });
    // TODO: We should also produce examples of reads that don't match to make debugging easier (#1263).
    Map<MatchType, Integer> tagCountMap = tagged.mapToPair(v1 -> new Tuple2<>(v1, 1)).reduceByKey((v1, v2) -> v1 + v2).collectAsMap();
    if (tagCountMap.get(MatchType.SIZE_UNEQUAL) != null) {
        throw new UserException("The number of reads by the MarkDuplicates key were unequal, indicating that the BAMs are not the same");
    }
    if (tagCountMap.get(MatchType.READ_MISMATCH) != null) {
        throw new UserException("The reads grouped by the MarkDuplicates key were not the same, indicating that the BAMs are not the same");
    }
    if (printSummary) {
        System.out.println("##############################");
        // Iterate values() directly so the summary follows enum declaration order. A LinkedHashSet copied
        // from a HashSet would inherit the HashSet's identity-hash-based (run-to-run varying) order.
        for (final MatchType s : MatchType.values()) {
            System.out.println(s + ": " + tagCountMap.getOrDefault(s, 0));
        }
    }
    if (throwOnDiff) {
        for (final MatchType s : MatchType.values()) {
            if (s == MatchType.EQUAL) {
                continue;
            }
            final Integer count = tagCountMap.get(s);
            if (count != null) {
                throw new UserException("found difference between the two BAMs: " + s + " with count " + count);
            }
        }
    }
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Broadcast(org.apache.spark.broadcast.Broadcast) CommandLineProgramProperties(org.broadinstitute.barclay.argparser.CommandLineProgramProperties) TestSparkProgramGroup(org.broadinstitute.hellbender.cmdline.programgroups.TestSparkProgramGroup) Argument(org.broadinstitute.barclay.argparser.Argument) ReadsKey(org.broadinstitute.hellbender.utils.read.markduplicates.ReadsKey) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKSparkTool(org.broadinstitute.hellbender.engine.spark.GATKSparkTool) ReadCoordinateComparator(org.broadinstitute.hellbender.utils.read.ReadCoordinateComparator) Set(java.util.Set) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Tuple2(scala.Tuple2) SAMFileHeader(htsjdk.samtools.SAMFileHeader) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) Sets(com.google.common.collect.Sets) ReadUtils(org.broadinstitute.hellbender.utils.read.ReadUtils) List(java.util.List) Lists(com.google.common.collect.Lists) ReadsSparkSource(org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource) UserException(org.broadinstitute.hellbender.exceptions.UserException) Map(java.util.Map) Function(org.apache.spark.api.java.function.Function) JavaRDD(org.apache.spark.api.java.JavaRDD) ReadsSparkSource(org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource) Tuple2(scala.Tuple2) UserException(org.broadinstitute.hellbender.exceptions.UserException) SAMFileHeader(htsjdk.samtools.SAMFileHeader)

Example 63 with JavaPairRDD

Use of org.apache.spark.api.java.JavaPairRDD in project grakn by graknlabs.

In class GraknSparkComputer, method submitWithExecutor.

/**
 * Submits this graph computation (an optional vertex program followed by any map-reduce jobs) to
 * {@code computerService} and returns a future for its result.
 * <p>
 * A random job-group id is generated and appended to the configured output location, so separate
 * submissions write to different places. The submitted task does the following:
 * <ul>
 *     <li>builds the Hadoop and Spark configuration;</li>
 *     <li>loads the input graph RDD and optionally filters, partitions and persists it;</li>
 *     <li>runs the vertex program (or a configured interceptor) until it terminates;</li>
 *     <li>writes the computed graph and executes the map-reduce jobs;</li>
 *     <li>unpersists intermediate RDDs, and removes the output location when persisting nothing.</li>
 * </ul>
 * <p>
 * {@code computerService.shutdown()} is called right after submission. Exceptions raised during the
 * vertex-program / map-reduce phase are rethrown wrapped in a {@link RuntimeException}.
 *
 * @return a future that completes with the output graph and the final (immutable) memory
 */
@SuppressWarnings("PMD.UnusedFormalParameter")
private Future<ComputerResult> submitWithExecutor() {
    jobGroupId = Integer.toString(ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE));
    String jobDescription = this.vertexProgram == null ? this.mapReducers.toString() : this.vertexProgram + "+" + this.mapReducers;
    // Use different output locations
    this.sparkConfiguration.setProperty(Constants.GREMLIN_HADOOP_OUTPUT_LOCATION, this.sparkConfiguration.getString(Constants.GREMLIN_HADOOP_OUTPUT_LOCATION) + "/" + jobGroupId);
    updateConfigKeys(sparkConfiguration);
    final Future<ComputerResult> result = computerService.submit(() -> {
        final long startTime = System.currentTimeMillis();
        // apache and hadoop configurations that are used throughout the graph computer computation
        final org.apache.commons.configuration.Configuration graphComputerConfiguration = new HadoopConfiguration(this.sparkConfiguration);
        if (!graphComputerConfiguration.containsKey(Constants.SPARK_SERIALIZER)) {
            graphComputerConfiguration.setProperty(Constants.SPARK_SERIALIZER, GryoSerializer.class.getCanonicalName());
        }
        graphComputerConfiguration.setProperty(Constants.GREMLIN_HADOOP_GRAPH_WRITER_HAS_EDGES, this.persist.equals(GraphComputer.Persist.EDGES));
        final Configuration hadoopConfiguration = ConfUtil.makeHadoopConfiguration(graphComputerConfiguration);
        final Storage fileSystemStorage = FileSystemStorage.open(hadoopConfiguration);
        final boolean inputFromHDFS = FileInputFormat.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_READER, Object.class));
        final boolean inputFromSpark = PersistedInputRDD.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_READER, Object.class));
        final boolean outputToHDFS = FileOutputFormat.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_WRITER, Object.class));
        final boolean outputToSpark = PersistedOutputRDD.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_WRITER, Object.class));
        final boolean skipPartitioner = graphComputerConfiguration.getBoolean(Constants.GREMLIN_SPARK_SKIP_PARTITIONER, false);
        final boolean skipPersist = graphComputerConfiguration.getBoolean(Constants.GREMLIN_SPARK_SKIP_GRAPH_CACHE, false);
        if (inputFromHDFS) {
            String inputLocation = Constants.getSearchGraphLocation(hadoopConfiguration.get(Constants.GREMLIN_HADOOP_INPUT_LOCATION), fileSystemStorage).orElse(null);
            if (null != inputLocation) {
                try {
                    // Resolve the input path once and apply it to both configurations.
                    final String resolvedInputDir = FileSystem.get(hadoopConfiguration).getFileStatus(new Path(inputLocation)).getPath().toString();
                    graphComputerConfiguration.setProperty(Constants.MAPREDUCE_INPUT_FILEINPUTFORMAT_INPUTDIR, resolvedInputDir);
                    hadoopConfiguration.set(Constants.MAPREDUCE_INPUT_FILEINPUTFORMAT_INPUTDIR, resolvedInputDir);
                } catch (final IOException e) {
                    throw new IllegalStateException(e.getMessage(), e);
                }
            }
        }
        final InputRDD inputRDD;
        final OutputRDD outputRDD;
        final boolean filtered;
        try {
            inputRDD = InputRDD.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_READER, Object.class)) ? hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_READER, InputRDD.class, InputRDD.class).newInstance() : InputFormatRDD.class.newInstance();
            outputRDD = OutputRDD.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_WRITER, Object.class)) ? hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_WRITER, OutputRDD.class, OutputRDD.class).newInstance() : OutputFormatRDD.class.newInstance();
            // if the input class can filter on load, then set the filters
            if (inputRDD instanceof InputFormatRDD && GraphFilterAware.class.isAssignableFrom(hadoopConfiguration.getClass(Constants.GREMLIN_HADOOP_GRAPH_READER, InputFormat.class, InputFormat.class))) {
                GraphFilterAware.storeGraphFilter(graphComputerConfiguration, hadoopConfiguration, this.graphFilter);
                filtered = false;
            } else if (inputRDD instanceof GraphFilterAware) {
                ((GraphFilterAware) inputRDD).setGraphFilter(this.graphFilter);
                filtered = false;
            } else
                filtered = this.graphFilter.hasFilter();
        } catch (final InstantiationException | IllegalAccessException e) {
            throw new IllegalStateException(e.getMessage(), e);
        }
        // create the spark context from the graph computer configuration
        final JavaSparkContext sparkContext = new JavaSparkContext(Spark.create(hadoopConfiguration));
        final Storage sparkContextStorage = SparkContextStorage.open();
        sparkContext.setJobGroup(jobGroupId, jobDescription);
        GraknSparkMemory memory = null;
        // delete output location
        final String outputLocation = hadoopConfiguration.get(Constants.GREMLIN_HADOOP_OUTPUT_LOCATION, null);
        if (null != outputLocation) {
            if (outputToHDFS && fileSystemStorage.exists(outputLocation)) {
                fileSystemStorage.rm(outputLocation);
            }
            if (outputToSpark && sparkContextStorage.exists(outputLocation)) {
                sparkContextStorage.rm(outputLocation);
            }
        }
        // the Spark application name will always be set by SparkContextStorage,
        // thus, INFO the name to make it easier to debug
        logger.debug(Constants.GREMLIN_HADOOP_SPARK_JOB_PREFIX + (null == this.vertexProgram ? "No VertexProgram" : this.vertexProgram) + "[" + this.mapReducers + "]");
        // add the project jars to the cluster
        this.loadJars(hadoopConfiguration, sparkContext);
        updateLocalConfiguration(sparkContext, hadoopConfiguration);
        // create a message-passing friendly rdd from the input rdd
        boolean partitioned = false;
        JavaPairRDD<Object, VertexWritable> loadedGraphRDD = inputRDD.readGraphRDD(graphComputerConfiguration, sparkContext);
        // if there are vertex or edge filters, filter the loaded graph rdd prior to partitioning and persisting
        if (filtered) {
            this.logger.debug("Filtering the loaded graphRDD: " + this.graphFilter);
            loadedGraphRDD = GraknSparkExecutor.applyGraphFilter(loadedGraphRDD, this.graphFilter);
        }
        // if the loaded graphRDD already has a partitioner, reuse it;
        // else partition it with HashPartitioner (unless partitioning is skipped via configuration)
        if (loadedGraphRDD.partitioner().isPresent()) {
            this.logger.debug("Using the existing partitioner associated with the loaded graphRDD: " + loadedGraphRDD.partitioner().get());
        } else {
            if (!skipPartitioner) {
                final Partitioner partitioner = new HashPartitioner(this.workersSet ? this.workers : loadedGraphRDD.partitions().size());
                this.logger.debug("Partitioning the loaded graphRDD: " + partitioner);
                loadedGraphRDD = loadedGraphRDD.partitionBy(partitioner);
                partitioned = true;
                assert loadedGraphRDD.partitioner().isPresent();
            } else {
                // no easy way to test this with a test case
                assert skipPartitioner == !loadedGraphRDD.partitioner().isPresent();
                this.logger.debug("Partitioning has been skipped for the loaded graphRDD via " + Constants.GREMLIN_SPARK_SKIP_PARTITIONER);
            }
        }
        // if workers were set explicitly, make the partition count match them; when the graphRDD was just
        // partitioned above with this.workers partitions, the counts already match,
        // then this coalesce/repartition will not take place
        if (this.workersSet) {
            // ensures that the loaded graphRDD does not have more partitions than workers
            if (loadedGraphRDD.partitions().size() > this.workers) {
                loadedGraphRDD = loadedGraphRDD.coalesce(this.workers);
            } else {
                // ensures that the loaded graphRDD does not have less partitions than workers
                if (loadedGraphRDD.partitions().size() < this.workers) {
                    loadedGraphRDD = loadedGraphRDD.repartition(this.workers);
                }
            }
        }
        // persist the loaded graphRDD at the configured storage level (unless skipped, or unless it came unchanged
        // from a persisted Spark RDD), or else use default cache() which is MEMORY_ONLY
        if (!skipPersist && (!inputFromSpark || partitioned || filtered)) {
            loadedGraphRDD = loadedGraphRDD.persist(StorageLevel.fromString(hadoopConfiguration.get(Constants.GREMLIN_SPARK_GRAPH_STORAGE_LEVEL, "MEMORY_ONLY")));
        }
        // final graph with view
        // (for persisting and/or mapReducing -- may be null and thus, possible to save space/time)
        JavaPairRDD<Object, VertexWritable> computedGraphRDD = null;
        try {
            // ---- vertex program phase ----
            if (null != this.vertexProgram) {
                memory = new GraknSparkMemory(this.vertexProgram, this.mapReducers, sparkContext);
                // if there is a registered VertexProgramInterceptor, use it to bypass the GraphComputer semantics
                if (graphComputerConfiguration.containsKey(Constants.GREMLIN_HADOOP_VERTEX_PROGRAM_INTERCEPTOR)) {
                    try {
                        final GraknSparkVertexProgramInterceptor<VertexProgram> interceptor = (GraknSparkVertexProgramInterceptor) Class.forName(graphComputerConfiguration.getString(Constants.GREMLIN_HADOOP_VERTEX_PROGRAM_INTERCEPTOR)).newInstance();
                        computedGraphRDD = interceptor.apply(this.vertexProgram, loadedGraphRDD, memory);
                    } catch (final ClassNotFoundException | IllegalAccessException | InstantiationException e) {
                        // keep the original exception as the cause so the failing lookup/instantiation is diagnosable
                        throw new IllegalStateException(e.getMessage(), e);
                    }
                } else {
                    // standard GraphComputer semantics
                    // get a configuration that will be propagated to all workers
                    final HadoopConfiguration vertexProgramConfiguration = new HadoopConfiguration();
                    this.vertexProgram.storeState(vertexProgramConfiguration);
                    // set up the vertex program and wire up configurations
                    this.vertexProgram.setup(memory);
                    JavaPairRDD<Object, ViewIncomingPayload<Object>> viewIncomingRDD = null;
                    memory.broadcastMemory(sparkContext);
                    // execute the vertex program
                    while (true) {
                        if (Thread.interrupted()) {
                            sparkContext.cancelAllJobs();
                            throw new TraversalInterruptedException();
                        }
                        memory.setInExecute(true);
                        viewIncomingRDD = GraknSparkExecutor.executeVertexProgramIteration(loadedGraphRDD, viewIncomingRDD, memory, graphComputerConfiguration, vertexProgramConfiguration);
                        memory.setInExecute(false);
                        if (this.vertexProgram.terminate(memory)) {
                            break;
                        } else {
                            memory.incrIteration();
                            memory.broadcastMemory(sparkContext);
                        }
                    }
                    // then generate a view+graph
                    if ((null != outputRDD && !this.persist.equals(Persist.NOTHING)) || !this.mapReducers.isEmpty()) {
                        computedGraphRDD = GraknSparkExecutor.prepareFinalGraphRDD(loadedGraphRDD, viewIncomingRDD, this.vertexProgram.getVertexComputeKeys());
                        assert null != computedGraphRDD && computedGraphRDD != loadedGraphRDD;
                    } else {
                        // ensure that the computedGraphRDD was not created
                        assert null == computedGraphRDD;
                    }
                }
                // ---- vertex program finished ----
                // drop all transient memory keys
                memory.complete();
                // write the computed graph to the respective output (rdd or output format)
                if (null != outputRDD && !this.persist.equals(Persist.NOTHING)) {
                    // the logic holds that a computeGraphRDD must be created at this point
                    assert null != computedGraphRDD;
                    outputRDD.writeGraphRDD(graphComputerConfiguration, computedGraphRDD);
                }
            }
            final boolean computedGraphCreated = computedGraphRDD != null && computedGraphRDD != loadedGraphRDD;
            if (!computedGraphCreated) {
                computedGraphRDD = loadedGraphRDD;
            }
            final Memory.Admin finalMemory = null == memory ? new MapMemory() : new MapMemory(memory);
            // ---- map-reduce phase ----
            if (!this.mapReducers.isEmpty()) {
                // create a mapReduceRDD for executing the map reduce jobs on
                JavaPairRDD<Object, VertexWritable> mapReduceRDD = computedGraphRDD;
                if (computedGraphCreated && !outputToSpark) {
                    // drop all the edges of the graph as they are not used in mapReduce processing
                    mapReduceRDD = computedGraphRDD.mapValues(vertexWritable -> {
                        vertexWritable.get().dropEdges(Direction.BOTH);
                        return vertexWritable;
                    });
                    // if there is only one MapReduce to execute, don't bother wasting the clock cycles.
                    if (this.mapReducers.size() > 1) {
                        mapReduceRDD = mapReduceRDD.persist(StorageLevel.fromString(hadoopConfiguration.get(Constants.GREMLIN_SPARK_GRAPH_STORAGE_LEVEL, "MEMORY_ONLY")));
                    }
                }
                for (final MapReduce mapReduce : this.mapReducers) {
                    // execute the map reduce job
                    final HadoopConfiguration newApacheConfiguration = new HadoopConfiguration(graphComputerConfiguration);
                    mapReduce.storeState(newApacheConfiguration);
                    // map
                    final JavaPairRDD mapRDD = GraknSparkExecutor.executeMap(mapReduceRDD, mapReduce, newApacheConfiguration);
                    // combine
                    final JavaPairRDD combineRDD = mapReduce.doStage(MapReduce.Stage.COMBINE) ? GraknSparkExecutor.executeCombine(mapRDD, newApacheConfiguration) : mapRDD;
                    // reduce
                    final JavaPairRDD reduceRDD = mapReduce.doStage(MapReduce.Stage.REDUCE) ? GraknSparkExecutor.executeReduce(combineRDD, mapReduce, newApacheConfiguration) : combineRDD;
                    // write the map reduce output back to disk and computer result memory
                    if (null != outputRDD) {
                        mapReduce.addResultToMemory(finalMemory, outputRDD.writeMemoryRDD(graphComputerConfiguration, mapReduce.getMemoryKey(), reduceRDD));
                    }
                }
                // if the mapReduceRDD is not simply the computed graph, unpersist the mapReduceRDD
                if (computedGraphCreated && !outputToSpark) {
                    assert loadedGraphRDD != computedGraphRDD;
                    assert mapReduceRDD != computedGraphRDD;
                    mapReduceRDD.unpersist();
                } else {
                    assert mapReduceRDD == computedGraphRDD;
                }
            }
            // if the graphRDD was loaded from Spark, but then partitioned or filtered, it's a different RDD
            if (!inputFromSpark || partitioned || filtered) {
                loadedGraphRDD.unpersist();
            }
            // unpersist a newly computed graph unless it is being kept as a persisted Spark output;
            // if it is kept in the SparkContext, then don't unpersist the computedGraphRDD/loadedGraphRDD
            if ((!outputToSpark || this.persist.equals(GraphComputer.Persist.NOTHING)) && computedGraphCreated) {
                computedGraphRDD.unpersist();
            }
            // delete any file system or rdd data if persist nothing
            if (null != outputLocation && this.persist.equals(GraphComputer.Persist.NOTHING)) {
                if (outputToHDFS) {
                    fileSystemStorage.rm(outputLocation);
                }
                if (outputToSpark) {
                    sparkContextStorage.rm(outputLocation);
                }
            }
            // update runtime and return the newly computed graph
            finalMemory.setRuntime(System.currentTimeMillis() - startTime);
            // clear properties that should not be propagated in an OLAP chain
            graphComputerConfiguration.clearProperty(Constants.GREMLIN_HADOOP_GRAPH_FILTER);
            graphComputerConfiguration.clearProperty(Constants.GREMLIN_HADOOP_VERTEX_PROGRAM_INTERCEPTOR);
            graphComputerConfiguration.clearProperty(Constants.GREMLIN_SPARK_SKIP_GRAPH_CACHE);
            graphComputerConfiguration.clearProperty(Constants.GREMLIN_SPARK_SKIP_PARTITIONER);
            return new DefaultComputerResult(InputOutputHelper.getOutputGraph(graphComputerConfiguration, this.resultGraph, this.persist), finalMemory.asImmutable());
        } catch (Exception e) {
            // So it throws the same exception as tinker does
            throw new RuntimeException(e);
        }
    });
    computerService.shutdown();
    return result;
}
Also used : InputRDD(org.apache.tinkerpop.gremlin.spark.structure.io.InputRDD) PersistedInputRDD(org.apache.tinkerpop.gremlin.spark.structure.io.PersistedInputRDD) TraversalInterruptedException(org.apache.tinkerpop.gremlin.process.traversal.util.TraversalInterruptedException) GryoSerializer(org.apache.tinkerpop.gremlin.spark.structure.io.gryo.GryoSerializer) FileSystem(org.apache.hadoop.fs.FileSystem) GraphFilterAware(org.apache.tinkerpop.gremlin.hadoop.structure.io.GraphFilterAware) GraphComputer(org.apache.tinkerpop.gremlin.process.computer.GraphComputer) LoggerFactory(org.slf4j.LoggerFactory) SparkContextStorage(org.apache.tinkerpop.gremlin.spark.structure.io.SparkContextStorage) Future(java.util.concurrent.Future) Partitioner(org.apache.spark.Partitioner) StorageLevel(org.apache.spark.storage.StorageLevel) Constants(org.apache.tinkerpop.gremlin.hadoop.Constants) Configuration(org.apache.hadoop.conf.Configuration) Path(org.apache.hadoop.fs.Path) ThreadFactory(java.util.concurrent.ThreadFactory) DefaultComputerResult(org.apache.tinkerpop.gremlin.process.computer.util.DefaultComputerResult) InputRDD(org.apache.tinkerpop.gremlin.spark.structure.io.InputRDD) HadoopConfiguration(org.apache.tinkerpop.gremlin.hadoop.structure.HadoopConfiguration) OutputRDD(org.apache.tinkerpop.gremlin.spark.structure.io.OutputRDD) HashPartitioner(org.apache.spark.HashPartitioner) Set(java.util.Set) BasicThreadFactory(org.apache.commons.lang3.concurrent.BasicThreadFactory) Executors(java.util.concurrent.Executors) SparkSingleIterationStrategy(org.apache.tinkerpop.gremlin.spark.process.computer.traversal.strategy.optimization.SparkSingleIterationStrategy) Memory(org.apache.tinkerpop.gremlin.process.computer.Memory) OutputFormatRDD(org.apache.tinkerpop.gremlin.spark.structure.io.OutputFormatRDD) InputFormatRDD(org.apache.tinkerpop.gremlin.spark.structure.io.InputFormatRDD) MapMemory(org.apache.tinkerpop.gremlin.process.computer.util.MapMemory) 
FileConfiguration(org.apache.commons.configuration.FileConfiguration) TraversalStrategies(org.apache.tinkerpop.gremlin.process.traversal.TraversalStrategies) TraversalInterruptedException(org.apache.tinkerpop.gremlin.process.traversal.util.TraversalInterruptedException) ConfigurationUtils(org.apache.commons.configuration.ConfigurationUtils) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ComputerSubmissionHelper(org.apache.tinkerpop.gremlin.hadoop.process.computer.util.ComputerSubmissionHelper) VertexProgram(org.apache.tinkerpop.gremlin.process.computer.VertexProgram) HashSet(java.util.HashSet) VertexWritable(org.apache.tinkerpop.gremlin.hadoop.structure.io.VertexWritable) ComputerResult(org.apache.tinkerpop.gremlin.process.computer.ComputerResult) ThreadLocalRandom(java.util.concurrent.ThreadLocalRandom) PropertiesConfiguration(org.apache.commons.configuration.PropertiesConfiguration) AbstractHadoopGraphComputer(org.apache.tinkerpop.gremlin.hadoop.process.computer.AbstractHadoopGraphComputer) FileInputFormat(org.apache.hadoop.mapreduce.lib.input.FileInputFormat) ExecutorService(java.util.concurrent.ExecutorService) FileSystemStorage(org.apache.tinkerpop.gremlin.hadoop.structure.io.FileSystemStorage) ConfUtil(org.apache.tinkerpop.gremlin.hadoop.structure.util.ConfUtil) ViewIncomingPayload(org.apache.tinkerpop.gremlin.spark.process.computer.payload.ViewIncomingPayload) Logger(org.slf4j.Logger) SparkLauncher(org.apache.spark.launcher.SparkLauncher) InputFormat(org.apache.hadoop.mapreduce.InputFormat) InputOutputHelper(org.apache.tinkerpop.gremlin.spark.structure.io.InputOutputHelper) Spark(org.apache.tinkerpop.gremlin.spark.structure.Spark) IOException(java.io.IOException) SparkInterceptorStrategy(org.apache.tinkerpop.gremlin.spark.process.computer.traversal.strategy.optimization.SparkInterceptorStrategy) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) File(java.io.File) 
PersistedOutputRDD(org.apache.tinkerpop.gremlin.spark.structure.io.PersistedOutputRDD) FileOutputFormat(org.apache.hadoop.mapreduce.lib.output.FileOutputFormat) Direction(org.apache.tinkerpop.gremlin.structure.Direction) HadoopGraph(org.apache.tinkerpop.gremlin.hadoop.structure.HadoopGraph) PersistedInputRDD(org.apache.tinkerpop.gremlin.spark.structure.io.PersistedInputRDD) Storage(org.apache.tinkerpop.gremlin.structure.io.Storage) MapReduce(org.apache.tinkerpop.gremlin.process.computer.MapReduce) VertexWritable(org.apache.tinkerpop.gremlin.hadoop.structure.io.VertexWritable) Configuration(org.apache.hadoop.conf.Configuration) HadoopConfiguration(org.apache.tinkerpop.gremlin.hadoop.structure.HadoopConfiguration) FileConfiguration(org.apache.commons.configuration.FileConfiguration) PropertiesConfiguration(org.apache.commons.configuration.PropertiesConfiguration) Memory(org.apache.tinkerpop.gremlin.process.computer.Memory) MapMemory(org.apache.tinkerpop.gremlin.process.computer.util.MapMemory) ViewIncomingPayload(org.apache.tinkerpop.gremlin.spark.process.computer.payload.ViewIncomingPayload) InputFormatRDD(org.apache.tinkerpop.gremlin.spark.structure.io.InputFormatRDD) MapReduce(org.apache.tinkerpop.gremlin.process.computer.MapReduce) GraphFilterAware(org.apache.tinkerpop.gremlin.hadoop.structure.io.GraphFilterAware) MapMemory(org.apache.tinkerpop.gremlin.process.computer.util.MapMemory) DefaultComputerResult(org.apache.tinkerpop.gremlin.process.computer.util.DefaultComputerResult) ComputerResult(org.apache.tinkerpop.gremlin.process.computer.ComputerResult) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Partitioner(org.apache.spark.Partitioner) HashPartitioner(org.apache.spark.HashPartitioner) Path(org.apache.hadoop.fs.Path) IOException(java.io.IOException) VertexProgram(org.apache.tinkerpop.gremlin.process.computer.VertexProgram) 
TraversalInterruptedException(org.apache.tinkerpop.gremlin.process.traversal.util.TraversalInterruptedException) IOException(java.io.IOException) OutputRDD(org.apache.tinkerpop.gremlin.spark.structure.io.OutputRDD) PersistedOutputRDD(org.apache.tinkerpop.gremlin.spark.structure.io.PersistedOutputRDD) SparkContextStorage(org.apache.tinkerpop.gremlin.spark.structure.io.SparkContextStorage) FileSystemStorage(org.apache.tinkerpop.gremlin.hadoop.structure.io.FileSystemStorage) Storage(org.apache.tinkerpop.gremlin.structure.io.Storage) DefaultComputerResult(org.apache.tinkerpop.gremlin.process.computer.util.DefaultComputerResult) HashPartitioner(org.apache.spark.HashPartitioner) GryoSerializer(org.apache.tinkerpop.gremlin.spark.structure.io.gryo.GryoSerializer) HadoopConfiguration(org.apache.tinkerpop.gremlin.hadoop.structure.HadoopConfiguration)

Example 64 with JavaPairRDD

use of org.apache.spark.api.java.JavaPairRDD in project incubator-systemml by apache.

the class MLContextConversionUtil method frameObjectToDataFrame.

/**
 * Converts a {@code FrameObject} into a Spark {@code DataFrame} ({@code Dataset<Row>}).
 * <p>
 * The frame is first fetched from the Spark execution context as a binary-block
 * pair RDD, then handed to {@code FrameRDDConverterUtils.binaryBlockToDataFrame}
 * together with the frame's matrix characteristics and schema.
 *
 * @param frameObject
 *            the {@code FrameObject} to convert
 * @param sparkExecutionContext
 *            the Spark execution context used to obtain the frame's RDD handle
 * @return the {@code FrameObject} converted to a {@code DataFrame}
 * @throws MLContextException
 *             if a {@code DMLRuntimeException} occurs during the conversion
 */
public static Dataset<Row> frameObjectToDataFrame(FrameObject frameObject, SparkExecutionContext sparkExecutionContext) {
    try {
        // The RDD handle comes back without the concrete key/value types, so this cast
        // is unchecked; the binary-block input info requests the <Long, FrameBlock> form.
        @SuppressWarnings("unchecked")
        final JavaPairRDD<Long, FrameBlock> frameBlocks =
                (JavaPairRDD<Long, FrameBlock>) sparkExecutionContext
                        .getRDDHandleForFrameObject(frameObject, InputInfo.BinaryBlockInputInfo);
        final MatrixCharacteristics characteristics = frameObject.getMatrixCharacteristics();
        return FrameRDDConverterUtils.binaryBlockToDataFrame(
                spark(), frameBlocks, characteristics, frameObject.getSchema());
    } catch (DMLRuntimeException cause) {
        throw new MLContextException("DMLRuntimeException while converting frame object to DataFrame", cause);
    }
}
Also used: FrameBlock(org.apache.sysml.runtime.matrix.data.FrameBlock) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) MatrixCharacteristics(org.apache.sysml.runtime.matrix.MatrixCharacteristics) DMLRuntimeException(org.apache.sysml.runtime.DMLRuntimeException)

Example 65 with JavaPairRDD

use of org.apache.spark.api.java.JavaPairRDD in project incubator-systemml by apache.

the class MLContextUtil method convertInputType.

/**
 * Convert input types to internal SystemML representations.
 * <p>
 * Dispatches on the runtime type of {@code parameterValue}:
 * <ul>
 * <li>{@code JavaRDD<String>} / {@code RDD<String>}: converted to a matrix or frame
 * according to {@code metadata} (IJV format if the metadata says so, CSV otherwise).
 * Without metadata, the first line is inspected: an all-numeric line yields a matrix,
 * anything else a frame.</li>
 * <li>{@code Dataset<Row>}: passed through {@code MLUtils.convertVectorColumnsToML},
 * then converted to a matrix or frame according to {@code metadata}, or by inspecting
 * the data frame when no metadata is given.</li>
 * <li>{@code MatrixBlock}, {@code double[][]}, {@code URL}: converted to a matrix
 * object, with {@code metadata} cast to {@code MatrixMetadata}.</li>
 * <li>{@code FrameBlock}: converted to a frame object, with {@code metadata} cast to
 * {@code FrameMetadata}.</li>
 * <li>{@code Matrix} / {@code Frame}: if it has binary blocks but no backing
 * matrix/frame object, the binary blocks are converted (using the value's own metadata
 * when {@code metadata} is {@code null}); otherwise {@code toMatrixObject()} /
 * {@code toFrameObject()} is returned.</li>
 * <li>{@code Integer}, {@code Double}, {@code String}, {@code Boolean}: wrapped in the
 * corresponding scalar object.</li>
 * </ul>
 *
 * @param parameterName
 *            The name of the input parameter
 * @param parameterValue
 *            The value of the input parameter
 * @param metadata
 *            matrix/frame metadata; may be {@code null}
 * @return input in SystemML data representation, or {@code null} if the value's type
 *         is not supported (or an RDD/Dataset is paired with metadata that is neither
 *         {@code MatrixMetadata} nor {@code FrameMetadata})
 * @throws MLContextException
 *             if {@code parameterName} or {@code parameterValue} is {@code null}
 */
public static Data convertInputType(String parameterName, Object parameterValue, Metadata metadata) {
    String name = parameterName;
    Object value = parameterValue;
    // instanceof already yields false for null, so no separate null guard is needed.
    boolean hasMetadata = metadata != null;
    boolean hasMatrixMetadata = metadata instanceof MatrixMetadata;
    boolean hasFrameMetadata = metadata instanceof FrameMetadata;
    if (name == null) {
        throw new MLContextException("Input parameter name is null");
    } else if (value == null) {
        throw new MLContextException("Input parameter value is null for: " + parameterName);
    } else if (value instanceof JavaRDD<?>) {
        @SuppressWarnings("unchecked")
        JavaRDD<String> javaRDD = (JavaRDD<String>) value;
        if (hasMatrixMetadata) {
            MatrixMetadata matrixMetadata = (MatrixMetadata) metadata;
            if (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV) {
                return MLContextConversionUtil.javaRDDStringIJVToMatrixObject(javaRDD, matrixMetadata);
            } else {
                return MLContextConversionUtil.javaRDDStringCSVToMatrixObject(javaRDD, matrixMetadata);
            }
        } else if (hasFrameMetadata) {
            FrameMetadata frameMetadata = (FrameMetadata) metadata;
            if (frameMetadata.getFrameFormat() == FrameFormat.IJV) {
                return MLContextConversionUtil.javaRDDStringIJVToFrameObject(javaRDD, frameMetadata);
            } else {
                return MLContextConversionUtil.javaRDDStringCSVToFrameObject(javaRDD, frameMetadata);
            }
        } else if (!hasMetadata) {
            // No metadata: guess matrix vs. frame from the first CSV line.
            String firstLine = javaRDD.first();
            if (isCSVLineAllNumbers(firstLine)) {
                return MLContextConversionUtil.javaRDDStringCSVToMatrixObject(javaRDD);
            } else {
                return MLContextConversionUtil.javaRDDStringCSVToFrameObject(javaRDD);
            }
        }
    } else if (value instanceof RDD<?>) {
        @SuppressWarnings("unchecked")
        RDD<String> rdd = (RDD<String>) value;
        if (hasMatrixMetadata) {
            MatrixMetadata matrixMetadata = (MatrixMetadata) metadata;
            if (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV) {
                return MLContextConversionUtil.rddStringIJVToMatrixObject(rdd, matrixMetadata);
            } else {
                return MLContextConversionUtil.rddStringCSVToMatrixObject(rdd, matrixMetadata);
            }
        } else if (hasFrameMetadata) {
            FrameMetadata frameMetadata = (FrameMetadata) metadata;
            if (frameMetadata.getFrameFormat() == FrameFormat.IJV) {
                return MLContextConversionUtil.rddStringIJVToFrameObject(rdd, frameMetadata);
            } else {
                return MLContextConversionUtil.rddStringCSVToFrameObject(rdd, frameMetadata);
            }
        } else if (!hasMetadata) {
            // No metadata: guess matrix vs. frame from the first CSV line.
            String firstLine = rdd.first();
            if (isCSVLineAllNumbers(firstLine)) {
                return MLContextConversionUtil.rddStringCSVToMatrixObject(rdd);
            } else {
                return MLContextConversionUtil.rddStringCSVToFrameObject(rdd);
            }
        }
    } else if (value instanceof MatrixBlock) {
        MatrixBlock matrixBlock = (MatrixBlock) value;
        return MLContextConversionUtil.matrixBlockToMatrixObject(name, matrixBlock, (MatrixMetadata) metadata);
    } else if (value instanceof FrameBlock) {
        FrameBlock frameBlock = (FrameBlock) value;
        return MLContextConversionUtil.frameBlockToFrameObject(name, frameBlock, (FrameMetadata) metadata);
    } else if (value instanceof Dataset<?>) {
        @SuppressWarnings("unchecked")
        Dataset<Row> dataFrame = (Dataset<Row>) value;
        dataFrame = MLUtils.convertVectorColumnsToML(dataFrame);
        if (hasMatrixMetadata) {
            return MLContextConversionUtil.dataFrameToMatrixObject(dataFrame, (MatrixMetadata) metadata);
        } else if (hasFrameMetadata) {
            return MLContextConversionUtil.dataFrameToFrameObject(dataFrame, (FrameMetadata) metadata);
        } else if (!hasMetadata) {
            if (doesDataFrameLookLikeMatrix(dataFrame)) {
                return MLContextConversionUtil.dataFrameToMatrixObject(dataFrame);
            } else {
                return MLContextConversionUtil.dataFrameToFrameObject(dataFrame);
            }
        }
    } else if (value instanceof Matrix) {
        Matrix matrix = (Matrix) value;
        if (matrix.hasBinaryBlocks() && !matrix.hasMatrixObject()) {
            // Use a local rather than reassigning the parameter; fall back to the
            // matrix's own metadata when the caller supplied none.
            Metadata effectiveMetadata = hasMetadata ? metadata : matrix.getMatrixMetadata();
            JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks = matrix.toBinaryBlocks();
            return MLContextConversionUtil.binaryBlocksToMatrixObject(binaryBlocks, (MatrixMetadata) effectiveMetadata);
        } else {
            return matrix.toMatrixObject();
        }
    } else if (value instanceof Frame) {
        Frame frame = (Frame) value;
        if (frame.hasBinaryBlocks() && !frame.hasFrameObject()) {
            // Use a local rather than reassigning the parameter; fall back to the
            // frame's own metadata when the caller supplied none.
            Metadata effectiveMetadata = hasMetadata ? metadata : frame.getFrameMetadata();
            JavaPairRDD<Long, FrameBlock> binaryBlocks = frame.toBinaryBlocks();
            return MLContextConversionUtil.binaryBlocksToFrameObject(binaryBlocks, (FrameMetadata) effectiveMetadata);
        } else {
            return frame.toFrameObject();
        }
    } else if (value instanceof double[][]) {
        double[][] doubleMatrix = (double[][]) value;
        return MLContextConversionUtil.doubleMatrixToMatrixObject(name, doubleMatrix, (MatrixMetadata) metadata);
    } else if (value instanceof URL) {
        URL url = (URL) value;
        return MLContextConversionUtil.urlToMatrixObject(url, (MatrixMetadata) metadata);
    } else if (value instanceof Integer) {
        return new IntObject((Integer) value);
    } else if (value instanceof Double) {
        return new DoubleObject((Double) value);
    } else if (value instanceof String) {
        return new StringObject((String) value);
    } else if (value instanceof Boolean) {
        return new BooleanObject((Boolean) value);
    }
    // Unsupported value type, or RDD/Dataset with metadata of an unrecognized kind.
    return null;
}
Also used: MatrixBlock(org.apache.sysml.runtime.matrix.data.MatrixBlock) DoubleObject(org.apache.sysml.runtime.instructions.cp.DoubleObject) URL(java.net.URL) RDD(org.apache.spark.rdd.RDD) JavaRDD(org.apache.spark.api.java.JavaRDD) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) IntObject(org.apache.sysml.runtime.instructions.cp.IntObject) FrameBlock(org.apache.sysml.runtime.matrix.data.FrameBlock) StringObject(org.apache.sysml.runtime.instructions.cp.StringObject) Dataset(org.apache.spark.sql.Dataset) MatrixObject(org.apache.sysml.runtime.controlprogram.caching.MatrixObject) FrameObject(org.apache.sysml.runtime.controlprogram.caching.FrameObject) BooleanObject(org.apache.sysml.runtime.instructions.cp.BooleanObject) Row(org.apache.spark.sql.Row)

Aggregations

JavaPairRDD (org.apache.spark.api.java.JavaPairRDD): 99, MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock): 44, JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 42, MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes): 42, MatrixCharacteristics (org.apache.sysml.runtime.matrix.MatrixCharacteristics): 41, Tuple2 (scala.Tuple2): 35, DMLRuntimeException (org.apache.sysml.runtime.DMLRuntimeException): 33, JavaRDD (org.apache.spark.api.java.JavaRDD): 28, List (java.util.List): 27, SparkExecutionContext (org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext): 24, FrameBlock (org.apache.sysml.runtime.matrix.data.FrameBlock): 23, Collectors (java.util.stream.Collectors): 22, IOException (java.io.IOException): 17, RDDObject (org.apache.sysml.runtime.instructions.spark.data.RDDObject): 16, LongWritable (org.apache.hadoop.io.LongWritable): 15, Broadcast (org.apache.spark.broadcast.Broadcast): 15, Text (org.apache.hadoop.io.Text): 12, UserException (org.broadinstitute.hellbender.exceptions.UserException): 12, Function (org.apache.spark.api.java.function.Function): 11, MatrixObject (org.apache.sysml.runtime.controlprogram.caching.MatrixObject): 11