Example 1 with PairFlatMapFunction

Use of org.apache.spark.api.java.function.PairFlatMapFunction in project cdap by caskdata.

From the class SparkPageRankProgram, method run().

@Override
public void run(JavaSparkExecutionContext sec) throws Exception {
    JavaSparkContext jsc = new JavaSparkContext();
    LOG.info("Processing backlinkURLs data");
    JavaPairRDD<Long, String> backlinkURLs = sec.fromStream("backlinkURLStream", String.class);
    int iterationCount = getIterationCount(sec);
    LOG.info("Grouping data by key");
    // Grouping backlinks by unique URL in key
    JavaPairRDD<String, Iterable<String>> links = backlinkURLs.values().mapToPair(new PairFunction<String, String, String>() {

        @Override
        public Tuple2<String, String> call(String s) {
            String[] parts = SPACES.split(s);
            return new Tuple2<>(parts[0], parts[1]);
        }
    }).distinct().groupByKey().cache();
    // Initialize default rank for each key URL
    JavaPairRDD<String, Double> ranks = links.mapValues(new Function<Iterable<String>, Double>() {

        @Override
        public Double call(Iterable<String> rs) {
            return 1.0;
        }
    });
    // Calculates and updates URL ranks continuously using PageRank algorithm.
    for (int current = 0; current < iterationCount; current++) {
        LOG.debug("Processing data with PageRank algorithm. Iteration {}/{}", current + 1, (iterationCount));
        // Calculates URL contributions to the rank of other URLs.
        JavaPairRDD<String, Double> contribs = links.join(ranks).values().flatMapToPair(new PairFlatMapFunction<Tuple2<Iterable<String>, Double>, String, Double>() {

            @Override
            public Iterable<Tuple2<String, Double>> call(Tuple2<Iterable<String>, Double> s) {
                LOG.debug("Processing {} with rank {}", s._1(), s._2());
                int urlCount = Iterables.size(s._1());
                List<Tuple2<String, Double>> results = new ArrayList<>();
                for (String n : s._1()) {
                    results.add(new Tuple2<>(n, s._2() / urlCount));
                }
                return results;
            }
        });
        // Re-calculates URL ranks based on backlink contributions.
        ranks = contribs.reduceByKey(new Sum()).mapValues(new Function<Double, Double>() {

            @Override
            public Double call(Double sum) {
                return 0.15 + sum * 0.85;
            }
        });
    }
    LOG.info("Writing ranks data");
    final ServiceDiscoverer discoveryServiceContext = sec.getServiceDiscoverer();
    final Metrics sparkMetrics = sec.getMetrics();
    JavaPairRDD<byte[], Integer> ranksRaw = ranks.mapToPair(new PairFunction<Tuple2<String, Double>, byte[], Integer>() {

        @Override
        public Tuple2<byte[], Integer> call(Tuple2<String, Double> tuple) throws Exception {
            LOG.debug("URL {} has rank {}", Arrays.toString(tuple._1().getBytes(Charsets.UTF_8)), tuple._2());
            URL serviceURL = discoveryServiceContext.getServiceURL(SparkPageRankApp.SERVICE_HANDLERS);
            if (serviceURL == null) {
                throw new RuntimeException("Failed to discover service: " + SparkPageRankApp.SERVICE_HANDLERS);
            }
            try {
                URLConnection connection = new URL(serviceURL, String.format("%s/%s", SparkPageRankApp.SparkPageRankServiceHandler.TRANSFORM_PATH, tuple._2().toString())).openConnection();
                try (BufferedReader reader = new BufferedReader(new InputStreamReader(connection.getInputStream(), Charsets.UTF_8))) {
                    String pr = reader.readLine();
                    // Parse the service response once instead of three times
                    int rank = Integer.parseInt(pr);
                    if (rank == POPULAR_PAGE_THRESHOLD) {
                        sparkMetrics.count(POPULAR_PAGES, 1);
                    } else if (rank <= UNPOPULAR_PAGE_THRESHOLD) {
                        sparkMetrics.count(UNPOPULAR_PAGES, 1);
                    } else {
                        sparkMetrics.count(REGULAR_PAGES, 1);
                    }
                    return new Tuple2<>(tuple._1().getBytes(Charsets.UTF_8), rank);
                }
            } catch (Exception e) {
                LOG.warn("Failed to read the Stream for service {}", SparkPageRankApp.SERVICE_HANDLERS, e);
                throw Throwables.propagate(e);
            }
        }
    });
    // Store calculated results in output Dataset.
    // All calculated results are stored in one row.
    // Each result, the calculated URL rank based on backlink contributions, is an entry of the row.
    // The value of the entry is the URL rank.
    sec.saveAsDataset(ranksRaw, "ranks");
    LOG.info("PageRanks successfuly computed and written to \"ranks\" dataset");
}
Also used : URL(java.net.URL) PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) Function(org.apache.spark.api.java.function.Function) PairFunction(org.apache.spark.api.java.function.PairFunction) Metrics(co.cask.cdap.api.metrics.Metrics) ArrayList(java.util.ArrayList) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ServiceDiscoverer(co.cask.cdap.api.ServiceDiscoverer) InputStreamReader(java.io.InputStreamReader) URLConnection(java.net.URLConnection) Tuple2(scala.Tuple2) BufferedReader(java.io.BufferedReader)
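
The Sum class passed to reduceByKey above is referenced but not shown in this listing. Spark's JavaPairRDD.reduceByKey takes a Function2 over the value type, so a minimal sketch of Sum, assuming it does plain addition of partial rank contributions (the actual CDAP class body may differ), looks like this:

import org.apache.spark.api.java.function.Function2;

// Hedged reconstruction: a commutative, associative reducer that adds two
// partial PageRank contributions. The real class is not part of this listing.
public class Sum implements Function2<Double, Double, Double> {
    @Override
    public Double call(Double a, Double b) throws Exception {
        return a + b;
    }
}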

Example 2 with PairFlatMapFunction

Use of org.apache.spark.api.java.function.PairFlatMapFunction in project gatk by broadinstitute.

From the class AddContextDataToReadSpark, method addUsingOverlapsPartitioning().

/**
     * Add context data ({@link ReadContextData}) to reads, using overlaps partitioning to avoid a shuffle.
     * @param ctx the Spark context
     * @param mappedReads the coordinate-sorted reads
     * @param referenceSource the reference source
     * @param variants the coordinate-sorted variants
     * @param sequenceDictionary the sequence dictionary for the reads
     * @param shardSize the maximum size of each shard, in bases
     * @param shardPadding amount of extra context around each shard, in bases
     * @return an RDD of read-context pairs, in coordinate-sorted order
     */
private static JavaPairRDD<GATKRead, ReadContextData> addUsingOverlapsPartitioning(final JavaSparkContext ctx, final JavaRDD<GATKRead> mappedReads, final ReferenceMultiSource referenceSource, final JavaRDD<GATKVariant> variants, final SAMSequenceDictionary sequenceDictionary, final int shardSize, final int shardPadding) {
    final List<SimpleInterval> intervals = IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
    // use unpadded shards (padding is only needed for reference bases)
    final List<ShardBoundary> intervalShards = intervals.stream().flatMap(interval -> Shard.divideIntervalIntoShards(interval, shardSize, 0, sequenceDictionary).stream()).collect(Collectors.toList());
    final Broadcast<ReferenceMultiSource> bReferenceSource = ctx.broadcast(referenceSource);
    final IntervalsSkipList<GATKVariant> variantSkipList = new IntervalsSkipList<>(variants.collect());
    final Broadcast<IntervalsSkipList<GATKVariant>> variantsBroadcast = ctx.broadcast(variantSkipList);
    int maxLocatableSize = Math.min(shardSize, shardPadding);
    JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, mappedReads, GATKRead.class, sequenceDictionary, intervalShards, maxLocatableSize);
    return shardedReads.flatMapToPair(new PairFlatMapFunction<Shard<GATKRead>, GATKRead, ReadContextData>() {

        private static final long serialVersionUID = 1L;

        @Override
        public Iterator<Tuple2<GATKRead, ReadContextData>> call(Shard<GATKRead> shard) throws Exception {
            // get reference bases for this shard (padded)
            SimpleInterval paddedInterval = shard.getInterval().expandWithinContig(shardPadding, sequenceDictionary);
            ReferenceBases referenceBases = bReferenceSource.getValue().getReferenceBases(null, paddedInterval);
            final IntervalsSkipList<GATKVariant> intervalsSkipList = variantsBroadcast.getValue();
            Iterator<Tuple2<GATKRead, ReadContextData>> transform = Iterators.transform(shard.iterator(), new Function<GATKRead, Tuple2<GATKRead, ReadContextData>>() {

                @Nullable
                @Override
                public Tuple2<GATKRead, ReadContextData> apply(@Nullable GATKRead r) {
                    List<GATKVariant> overlappingVariants;
                    if (SimpleInterval.isValid(r.getContig(), r.getStart(), r.getEnd())) {
                        overlappingVariants = intervalsSkipList.getOverlapping(new SimpleInterval(r));
                    } else {
                        // Sometimes we have reads that do not form valid intervals (reads that do not
                        // consume any reference bases, e.g. CIGAR 61S90I). In those cases, we'll just
                        // say that nothing overlaps the read.
                        overlappingVariants = Collections.emptyList();
                    }
                    return new Tuple2<>(r, new ReadContextData(referenceBases, overlappingVariants));
                }
            });
            // only include reads that start in the shard
            return Iterators.filter(transform, r -> r._1().getStart() >= shard.getStart() && r._1().getStart() <= shard.getEnd());
        }
    });
}
Also used : PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Iterators(com.google.common.collect.Iterators) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) ReferenceBases(org.broadinstitute.hellbender.utils.reference.ReferenceBases) ReadContextData(org.broadinstitute.hellbender.engine.ReadContextData) JavaRDD(org.apache.spark.api.java.JavaRDD) Nullable(javax.annotation.Nullable) Broadcast(org.apache.spark.broadcast.Broadcast) IntervalsSkipList(org.broadinstitute.hellbender.utils.collections.IntervalsSkipList) Function(com.google.common.base.Function) Iterator(java.util.Iterator) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) GATKVariant(org.broadinstitute.hellbender.utils.variant.GATKVariant) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) Shard(org.broadinstitute.hellbender.engine.Shard) List(java.util.List) UserException(org.broadinstitute.hellbender.exceptions.UserException) ShardBoundary(org.broadinstitute.hellbender.engine.ShardBoundary) Collections(java.util.Collections) ReadFilterLibrary(org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary)
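
Note the API difference between the two examples: Example 1's PairFlatMapFunction.call returns Iterable<Tuple2<String, Double>> (the Spark 1.x signature), while this example's call returns Iterator<Tuple2<GATKRead, ReadContextData>> (the Spark 2.x signature). Against the Spark 2.x API, the contribution step from Example 1 could also be written as a lambda; a sketch, assuming the same links and ranks RDDs as in Example 1:

// Hypothetical Spark 2.x rewrite of Example 1's contribution step; note the
// trailing .iterator(), since Spark 2.x expects an Iterator, not an Iterable.
JavaPairRDD<String, Double> contribs = links.join(ranks).values()
    .flatMapToPair(pair -> {
        int urlCount = Iterables.size(pair._1());
        List<Tuple2<String, Double>> results = new ArrayList<>(urlCount);
        for (String neighbor : pair._1()) {
            // Each outgoing link receives an equal share of this page's rank
            results.add(new Tuple2<>(neighbor, pair._2() / urlCount));
        }
        return results.iterator();
    });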

Aggregations

List (java.util.List) 2
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext) 2
PairFlatMapFunction (org.apache.spark.api.java.function.PairFlatMapFunction) 2
Tuple2 (scala.Tuple2) 2
ServiceDiscoverer (co.cask.cdap.api.ServiceDiscoverer) 1
Metrics (co.cask.cdap.api.metrics.Metrics) 1
BufferedReader (java.io.BufferedReader) 1
InputStreamReader (java.io.InputStreamReader) 1
URL (java.net.URL) 1
URLConnection (java.net.URLConnection) 1
ArrayList (java.util.ArrayList) 1
Function (org.apache.spark.api.java.function.Function) 1
PairFunction (org.apache.spark.api.java.function.PairFunction) 1
Function (com.google.common.base.Function) 1
Iterators (com.google.common.collect.Iterators) 1
SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary) 1
Collections (java.util.Collections) 1
Iterator (java.util.Iterator) 1
Collectors (java.util.stream.Collectors) 1
Nullable (javax.annotation.Nullable) 1