Search in sources :

Example 1 with FlatMapFunction

use of org.apache.spark.api.java.function.FlatMapFunction in project deeplearning4j by deeplearning4j.

the class Word2Vec method train.

/**
 * Trains the word2vec model on the given text corpus and installs the result as this model's
 * lookup table.
 *
 * <p>Outline: tokenize the corpus and build the vocabulary via {@link TextPipeline}, pair every
 * vocab-word list with the cumulative sum of the per-sentence counts, run
 * {@code FirstIterationFunction} on every partition, collect the resulting (word, syn0 update)
 * pairs on the driver, and average the updates each word received into a fresh syn0 matrix.
 *
 * @param corpusRDD training corpus; repartitioned to {@code workers} partitions when
 *                  {@code workers > 0}
 * @throws Exception propagated from the Spark / pipeline calls made during training
 */
public void train(JavaRDD<String> corpusRDD) throws Exception {
    log.info("Start training ...");
    if (workers > 0) {
        // RDDs are immutable: repartition() returns a new RDD instead of modifying the receiver,
        // so the result must be reassigned or the requested parallelism is silently ignored.
        corpusRDD = corpusRDD.repartition(workers);
    }
    final JavaSparkContext sc = new JavaSparkContext(corpusRDD.context());
    // Pre-defined variables
    Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
    Map<String, Object> word2vecVarMap = getWord2vecVarMap();

    //////////////////////////////////////
    log.info("Tokenization and building VocabCache ...");
    Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(tokenizerVarMap);
    TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
    pipeline.buildVocabCache();
    pipeline.buildVocabWordListRDD();
    // The total word count travels to the workers inside the word2vec variable map.
    word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
    // (vocab words list) and (sentence count) RDDs; the original author notes both are already cached.
    final JavaRDD<AtomicLong> sentenceWordsCountRDD = pipeline.getSentenceCountRDD();
    final JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
    Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast = pipeline.getBroadCastVocabCache();
    final VocabCache<VocabWord> vocabCache = vocabCacheBroadcast.getValue();
    log.info("Vocab size: {}", vocabCache.numWords());

    //////////////////////////////////////
    log.info("Building Huffman Tree ...");
    // Nothing to do here: the Huffman tree (code and point of every VocabWord in the cache) was
    // already built during TextPipeline.buildVocabCache().

    //////////////////////////////////////
    log.info("Calculating cumulative sum of sentence counts ...");
    final JavaRDD<Long> sentenceCumSumCountRDD = new CountCumSum(sentenceWordsCountRDD).buildCumSum();

    //////////////////////////////////////
    log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
    // zip() pairs elements positionally; Spark requires both RDDs to have the same number of
    // partitions and the same number of elements per partition.
    final JavaPairRDD<List<VocabWord>, Long> vocabWordListSentenceCumSumRDD = vocabWordListRDD.zip(sentenceCumSumCountRDD).setName("vocabWordListSentenceCumSumRDD");

    /////////////////////////////////////
    log.info("Broadcasting word2vec variables to workers ...");
    Broadcast<Map<String, Object>> word2vecVarMapBroadcast = sc.broadcast(word2vecVarMap);
    Broadcast<double[]> expTableBroadcast = sc.broadcast(expTable);

    /////////////////////////////////////
    log.info("Training word2vec sentences ...");
    FlatMapFunction firstIterFunc = new FirstIterationFunction(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast);
    @SuppressWarnings("unchecked") JavaRDD<Pair<VocabWord, INDArray>> indexSyn0UpdateEntryRDD = vocabWordListSentenceCumSumRDD.mapPartitions(firstIterFunc).map(new MapToPairFunction());
    // Bring every (word, syn0 update) pair to the driver.
    List<Pair<VocabWord, INDArray>> syn0UpdateEntries = indexSyn0UpdateEntryRDD.collect();

    INDArray syn0 = Nd4j.zeros(vocabCache.numWords(), layerSize);
    log.info("Averaging results...");
    // First pass: sum the vectors each word received and count how many were added, so the sums
    // can be turned into averages afterwards.
    Map<VocabWord, Integer> updateCounts = new HashMap<>();
    for (Pair<VocabWord, INDArray> syn0UpdateEntry : syn0UpdateEntries) {
        VocabWord word = syn0UpdateEntry.getFirst();
        syn0.getRow(word.getIndex()).addi(syn0UpdateEntry.getSecond());
        updateCounts.merge(word, 1, Integer::sum);
    }
    // Second pass: divide every multiply-updated row by its update count. maxRep ends up as the
    // largest update count of any word (at least 1).
    int maxRep = 1;
    for (Map.Entry<VocabWord, Integer> entry : updateCounts.entrySet()) {
        int count = entry.getValue();
        if (count > 1) {
            maxRep = Math.max(maxRep, count);
            syn0.getRow(entry.getKey().getIndex()).divi(count);
        }
    }

    // NOTE(review): `totals` is never assigned, so the environment report always carries 0 as
    // available memory — confirm whether a real value was intended.
    long totals = 0;
    log.info("Finished calculations...");
    vocab = vocabCache;
    InMemoryLookupTable<VocabWord> inMemoryLookupTable = new InMemoryLookupTable<VocabWord>();
    Environment env = EnvironmentUtils.buildEnvironment();
    // NOTE(review): the largest per-word update count is reported as the number of cores — confirm
    // this is intended.
    env.setNumCores(maxRep);
    env.setAvailableMemory(totals);
    update(env, Event.SPARK);
    inMemoryLookupTable.setVocab(vocabCache);
    inMemoryLookupTable.setVectorLength(layerSize);
    inMemoryLookupTable.setSyn0(syn0);
    lookupTable = inMemoryLookupTable;
    modelUtils.init(lookupTable);
}
Also used : HashMap(java.util.HashMap) VocabWord(org.deeplearning4j.models.word2vec.VocabWord) InMemoryLookupTable(org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) ArrayList(java.util.ArrayList) List(java.util.List) CountCumSum(org.deeplearning4j.spark.text.functions.CountCumSum) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Pair(org.deeplearning4j.berkeley.Pair) TextPipeline(org.deeplearning4j.spark.text.functions.TextPipeline) AtomicLong(java.util.concurrent.atomic.AtomicLong) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) VocabCache(org.deeplearning4j.models.word2vec.wordstore.VocabCache) AtomicLong(java.util.concurrent.atomic.AtomicLong) Environment(org.nd4j.linalg.heartbeat.reports.Environment) HashMap(java.util.HashMap) Map(java.util.Map)

Example 2 with FlatMapFunction

use of org.apache.spark.api.java.function.FlatMapFunction in project beam by apache.

the class SparkGroupAlsoByWindowViaWindowSet method groupAlsoByWindow.

/**
 * Groups a keyed stream by key and window on top of Spark's stateful {@code updateStateByKey},
 * carrying Beam state and timers from one micro-batch to the next.
 *
 * <p>Keys and values are encoded with the given coders so they can be shuffled and checkpointed as
 * bytes. For every key, a {@link ReduceFnRunner} with a buffering {@link SystemReduceFn} processes
 * the key's new elements (if any) and then the timers that are ready after advancing the watermark.
 * Fired panes are emitted. The remaining state and timers are kept as Spark state. A key that ends
 * up with neither output nor state is evicted.
 *
 * @param inputDStream      stream of grouped-by-key elements whose values are still windowed
 * @param keyCoder          coder for the key
 * @param wvCoder           coder for the windowed values; it is cast to {@code FullWindowedValueCoder}
 * @param windowingStrategy windowing strategy (window fn and trigger) of the input
 * @param runtimeContext    supplies the pipeline options, including the checkpoint duration
 * @param sourceIds         passed to {@code SparkTimerInternals.forStreamFromSources} together with
 *                          the global watermarks
 * @return stream of fired {@code KV<K, Iterable<InputT>>} panes
 */
public static <K, InputT, W extends BoundedWindow> JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow(JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream, final Coder<K> keyCoder, final Coder<WindowedValue<InputT>> wvCoder, final WindowingStrategy<?, W> windowingStrategy, final SparkRuntimeContext runtimeContext, final List<Integer> sourceIds) {
    final IterableCoder<WindowedValue<InputT>> itrWvCoder = IterableCoder.of(wvCoder);
    final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder();
    final Coder<? extends BoundedWindow> wCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getWindowCoder();
    final Coder<WindowedValue<KV<K, Iterable<InputT>>>> wvKvIterCoder = FullWindowedValueCoder.of(KvCoder.of(keyCoder, IterableCoder.of(iCoder)), wCoder);
    final TimerInternals.TimerDataCoder timerDataCoder = TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());
    long checkpointDurationMillis = runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class).getCheckpointDurationMillis();
    // We switch to the Scala API to avoid Optional in the Java API (see SPARK-4819); the Scala API
    // also gives access to the actual key and to the whole partition iterator.
    // Keys and values are encoded with their coders so they can be shuffled over the network and
    // kept in serialized form for checkpointing. Byte-array payloads are annotated below with their
    // decoded types (WV = WindowedValue, Itr = Iterable, I = InputT).

    // (encoded key, Itr<WV<I>>)
    DStream<Tuple2<ByteArray, byte[]>> pairDStream = inputDStream.transformToPair(new Function<JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>>, JavaPairRDD<ByteArray, byte[]>>() {

        // mapPartitions is the only RDD API available here that preserves partitioning.
        @Override
        public JavaPairRDD<ByteArray, byte[]> call(JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> rdd) throws Exception {
            return rdd.mapPartitions(TranslationUtils.functionToFlatMapFunction(WindowingHelpers.<KV<K, Iterable<WindowedValue<InputT>>>>unwindowFunction()), true).mapPartitionsToPair(TranslationUtils.<K, Iterable<WindowedValue<InputT>>>toPairFlatMapFunction(), true).mapPartitionsToPair(TranslationUtils.pairFunctionToPairFlatMapFunction(CoderHelpers.toByteFunction(keyCoder, itrWvCoder)), true);
        }
    }).dstream();
    PairDStreamFunctions<ByteArray, byte[]> pairDStreamFunctions = DStream.toPairDStreamFunctions(pairDStream, JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(), JavaSparkContext$.MODULE$.<byte[]>fakeClassTag(), null);
    int defaultNumPartitions = pairDStreamFunctions.defaultPartitioner$default$1();
    Partitioner partitioner = pairDStreamFunctions.defaultPartitioner(defaultNumPartitions);

    // Use updateStateByKey to scan through the state and update elements and timers.
    // Output value: (state and timers, serialized fired panes of type WV<KV<K, Itr<I>>>).
    DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> firedStream = pairDStreamFunctions.updateStateByKey(new SerializableFunction1<scala.collection.Iterator<Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, List<byte[]>>>>>, scala.collection.Iterator<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>>>() {

        /**
         * Stateful step for one partition.
         *
         * <p>Input: the partition (~bundle) of a cogrouping of the new input with the previous state
         * (if it exists), as (key, new data, previous state). Output: the key and its updated state.
         * Possible input scenarios:
         * (1) previous state empty: new data with no previous state;
         * (2) new data empty: no new data, only evaluating previous state (timer-like behaviour);
         * (3) both present: new data with previous state.
         */
        @Override
        public scala.collection.Iterator<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> apply(final scala.collection.Iterator<Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, List<byte[]>>>>> iter) {
            final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn = SystemReduceFn.buffering(((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder());
            final OutputWindowedValueHolder<K, InputT> outputHolder = new OutputWindowedValueHolder<>();
            // Use in-memory metric cells, since Spark Accumulators are not resilient in stateful
            // operators, once done with this partition.
            final MetricsContainerImpl cellProvider = new MetricsContainerImpl("cellProvider");
            // NOTE(review): nothing in this method increments droppedDueToClosedWindow; confirm
            // whether it is expected to be wired into the runner.
            final CounterCell droppedDueToClosedWindow = cellProvider.getCounter(MetricName.named(SparkGroupAlsoByWindowViaWindowSet.class, GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER));
            final CounterCell droppedDueToLateness = cellProvider.getCounter(MetricName.named(SparkGroupAlsoByWindowViaWindowSet.class, GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_LATENESS_COUNTER));

            AbstractIterator<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> outIter = new AbstractIterator<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>>() {

                @Override
                protected Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>> computeNext() {
                    // Each element carries (possibly) previous state and (possibly) new data.
                    while (iter.hasNext()) {
                        Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, List<byte[]>>>> next = iter.next();
                        ByteArray encodedKey = next._1();
                        K key = CoderHelpers.fromByteArray(encodedKey.getValue(), keyCoder);
                        Seq<byte[]> seq = next._2();
                        Option<Tuple2<StateAndTimers, List<byte[]>>> prevStateAndTimersOpt = next._3();
                        SparkStateInternals<K> stateInternals;
                        SparkTimerInternals timerInternals = SparkTimerInternals.forStreamFromSources(sourceIds, GlobalWatermarkHolder.get());
                        // Get state (internals) per key.
                        if (prevStateAndTimersOpt.isEmpty()) {
                            // No previous state.
                            stateInternals = SparkStateInternals.forKey(key);
                        } else {
                            // With pre-existing state: restore it and its timers.
                            StateAndTimers prevStateAndTimers = prevStateAndTimersOpt.get()._1();
                            stateInternals = SparkStateInternals.forKeyAndState(key, prevStateAndTimers.getState());
                            Collection<byte[]> serTimers = prevStateAndTimers.getTimers();
                            timerInternals.addTimers(SparkTimerInternals.deserializeTimers(serTimers, timerDataCoder));
                        }
                        ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner = new ReduceFnRunner<>(key, windowingStrategy, ExecutableTriggerStateMachine.create(TriggerStateMachines.stateMachineForTrigger(TriggerTranslation.toProto(windowingStrategy.getTrigger()))), stateInternals, timerInternals, outputHolder, new UnsupportedSideInputReader("GroupAlsoByWindow"), reduceFn, runtimeContext.getPipelineOptions());
                        // Clear before potential use: the holder is shared by all keys of the partition.
                        outputHolder.clear();
                        if (!seq.isEmpty()) {
                            // New input for this key.
                            try {
                                Iterable<WindowedValue<InputT>> elementsIterable = CoderHelpers.fromByteArray(seq.head(), itrWvCoder);
                                Iterable<WindowedValue<InputT>> validElements = LateDataUtils.dropExpiredWindows(key, elementsIterable, timerInternals, windowingStrategy, droppedDueToLateness);
                                reduceFnRunner.processElements(validElements);
                            } catch (Exception e) {
                                throw new RuntimeException("Failed to process element with ReduceFnRunner", e);
                            }
                        } else if (stateInternals.getState().isEmpty()) {
                            // No input and no state -> GC evict now.
                            continue;
                        }
                        try {
                            // Advance the watermark to HWM to fire by timers.
                            timerInternals.advanceWatermark();
                            // Call on timers that are ready.
                            reduceFnRunner.onTimers(timerInternals.getTimersReadyToProcess());
                        } catch (Exception e) {
                            throw new RuntimeException("Failed to process ReduceFnRunner onTimer.", e);
                        }
                        // Mostly symbolic, since the actual persist is done by emitting output.
                        reduceFnRunner.persist();
                        // Obtain output, if fired.
                        List<WindowedValue<KV<K, Iterable<InputT>>>> outputs = outputHolder.get();
                        if (!outputs.isEmpty() || !stateInternals.getState().isEmpty()) {
                            StateAndTimers updated = new StateAndTimers(stateInternals.getState(), SparkTimerInternals.serializeTimers(timerInternals.getTimers(), timerDataCoder));
                            // Persist Spark's state by outputting it.
                            List<byte[]> serOutput = CoderHelpers.toByteArrays(outputs, wvKvIterCoder);
                            return new Tuple2<>(encodedKey, new Tuple2<>(updated, serOutput));
                        }
                        // An empty state with no output can be evicted completely - do nothing.
                    }
                    // This iterator is lazy: elements are only processed while Spark consumes it, so the
                    // drop counters only hold their final values once the input is exhausted. Log them
                    // here; in apply() they would always still be zero.
                    long lateDropped = droppedDueToLateness.getCumulative();
                    if (lateDropped > 0) {
                        LOG.info(String.format("Dropped %d elements due to lateness.", lateDropped));
                        droppedDueToLateness.inc(-droppedDueToLateness.getCumulative());
                    }
                    long closedWindowDropped = droppedDueToClosedWindow.getCumulative();
                    if (closedWindowDropped > 0) {
                        LOG.info(String.format("Dropped %d elements due to closed window.", closedWindowDropped));
                        droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative());
                    }
                    return endOfData();
                }
            };
            return scala.collection.JavaConversions.asScalaIterator(outIter);
        }
    }, partitioner, true, JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag());
    if (checkpointDurationMillis > 0) {
        firedStream.checkpoint(new Duration(checkpointDurationMillis));
    }
    // Go back to Java now.
    JavaPairDStream<ByteArray, Tuple2<StateAndTimers, List<byte[]>>> javaFiredStream = JavaPairDStream.fromPairDStream(firedStream, JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(), JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag());
    // Filter out state-only entries (nothing fired), then drop the state from the output.
    return javaFiredStream.filter(new Function<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>, Boolean>() {

        @Override
        public Boolean call(Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>> t2) throws Exception {
            // Keep only entries that carry fired output.
            return !t2._2()._2().isEmpty();
        }
    }).flatMap(new FlatMapFunction<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>, WindowedValue<KV<K, Iterable<InputT>>>>() {

        @Override
        public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call(Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>> t2) throws Exception {
            // Deserialize the fired panes back into windowed values.
            return CoderHelpers.fromByteArrays(t2._2()._2(), wvKvIterCoder);
        }
    });
}
Also used : MetricsContainerImpl(org.apache.beam.runners.core.metrics.MetricsContainerImpl) CounterCell(org.apache.beam.runners.core.metrics.CounterCell) WindowedValue(org.apache.beam.sdk.util.WindowedValue) OutputWindowedValue(org.apache.beam.runners.core.OutputWindowedValue) ByteArray(org.apache.beam.runners.spark.util.ByteArray) List(java.util.List) ArrayList(java.util.ArrayList) ReduceFnRunner(org.apache.beam.runners.core.ReduceFnRunner) SystemReduceFn(org.apache.beam.runners.core.SystemReduceFn) Duration(org.apache.spark.streaming.Duration) TimerInternals(org.apache.beam.runners.core.TimerInternals) Collection(java.util.Collection) Option(scala.Option) Seq(scala.collection.Seq) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) Function(org.apache.spark.api.java.function.Function) UnsupportedSideInputReader(org.apache.beam.runners.core.UnsupportedSideInputReader) AbstractIterator(com.google.common.collect.AbstractIterator) AbstractIterator(com.google.common.collect.AbstractIterator) Partitioner(org.apache.spark.Partitioner) FullWindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder) KV(org.apache.beam.sdk.values.KV) SparkPipelineOptions(org.apache.beam.runners.spark.SparkPipelineOptions) JavaRDD(org.apache.spark.api.java.JavaRDD) Tuple2(scala.Tuple2) Tuple3(scala.Tuple3)

Example 3 with FlatMapFunction

use of org.apache.spark.api.java.function.FlatMapFunction in project gatk by broadinstitute.

the class VariantWalkerSpark method getVariantsFunction.

/**
 * Builds the Spark function that turns one variant shard into walker contexts.
 *
 * <p>For each shard, the reference bases over the shard interval (expanded by
 * {@code variantShardPadding} within its contig) are fetched once. Every variant whose start lies
 * inside the unpadded shard bounds is then paired with an empty reads context, a reference context
 * and a feature context over the variant's own interval. Variants starting outside the shard bounds
 * are skipped.
 *
 * @param bReferenceSource    broadcast reference source, or {@code null} for no reference data
 * @param bFeatureManager     broadcast feature manager, or {@code null} for no feature data
 * @param sequenceDictionary  dictionary used when padding the shard and for the in-memory reference
 * @param variantShardPadding amount by which each shard interval is expanded
 * @return a function mapping a shard to an iterator over its walker contexts
 */
private static FlatMapFunction<Shard<VariantContext>, VariantWalkerContext> getVariantsFunction(final Broadcast<ReferenceMultiSource> bReferenceSource, final Broadcast<FeatureManager> bFeatureManager, final SAMSequenceDictionary sequenceDictionary, final int variantShardPadding) {
    return shard -> {
        final SimpleInterval paddedShardInterval = shard.getInterval().expandWithinContig(variantShardPadding, sequenceDictionary);
        final ReferenceDataSource shardReference;
        if (bReferenceSource == null) {
            shardReference = null;
        } else {
            shardReference = new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, paddedShardInterval), sequenceDictionary);
        }
        final FeatureManager shardFeatures = (bFeatureManager != null) ? bFeatureManager.getValue() : null;
        return StreamSupport.stream(shard.spliterator(), false)
                .filter(variant -> variant.getStart() >= shard.getStart() && variant.getStart() <= shard.getEnd())
                .map(variant -> {
                    final SimpleInterval variantInterval = new SimpleInterval(variant);
                    return new VariantWalkerContext(variant, new ReadsContext(), new ReferenceContext(shardReference, variantInterval), new FeatureContext(shardFeatures, variantInterval));
                })
                .iterator();
    };
}
Also used : Broadcast(org.apache.spark.broadcast.Broadcast) VCFHeader(htsjdk.variant.vcf.VCFHeader) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) IndexUtils(org.broadinstitute.hellbender.utils.IndexUtils) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) VariantFilterLibrary(org.broadinstitute.hellbender.engine.filters.VariantFilterLibrary) StandardArgumentDefinitions(org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) VariantFilter(org.broadinstitute.hellbender.engine.filters.VariantFilter) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) VariantContext(htsjdk.variant.variantcontext.VariantContext) VariantsSparkSource(org.broadinstitute.hellbender.engine.spark.datasources.VariantsSparkSource) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) VariantContext(htsjdk.variant.variantcontext.VariantContext) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval)

Example 4 with FlatMapFunction

use of org.apache.spark.api.java.function.FlatMapFunction in project gatk by broadinstitute.

the class ReadsSparkSourceUnitTest method testPutPairsInSamePartition.

/**
 * Checks that {@code putPairsInSamePartition} keeps the requested number of partitions and places
 * the expected number of reads in each one. It also sanity-checks that the expected per-partition
 * counts add up to two reads per pair.
 */
@Test(dataProvider = "readPairsAndPartitions")
public void testPutPairsInSamePartition(int numPairs, int numPartitions, int[] expectedReadsPerPartition) throws IOException {
    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final SAMFileHeader header = ArtificialReadUtils.createArtificialSamHeader();
    header.setSortOrder(SAMFileHeader.SortOrder.queryname);
    final JavaRDD<GATKRead> reads = createPairedReads(ctx, header, numPairs, numPartitions);
    final JavaRDD<GATKRead> pairedReads = new ReadsSparkSource(ctx).putPairsInSamePartition(header, reads);

    // Collapse every partition into a single list so that partition sizes can be inspected on the driver.
    final FlatMapFunction<Iterator<GATKRead>, List<GATKRead>> partitionAsList = readIter -> Iterators.singletonIterator(Lists.newArrayList(readIter));
    final List<List<GATKRead>> partitions = pairedReads.mapPartitions(partitionAsList).collect();

    assertEquals(partitions.size(), numPartitions);
    int partitionIndex = 0;
    for (final List<GATKRead> partition : partitions) {
        assertEquals(partition.size(), expectedReadsPerPartition[partitionIndex]);
        partitionIndex++;
    }
    assertEquals(Arrays.stream(expectedReadsPerPartition).sum(), numPairs * 2);
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) Test(org.testng.annotations.Test)

Example 5 with FlatMapFunction

use of org.apache.spark.api.java.function.FlatMapFunction in project beijingThirdPeriod by weidongcao.

the class SparkOperateBcp method run.

/**
 * Loads the BCP (tab-separated) files of one content type and keeps only rows whose field count
 * matches a known BCP layout. The surviving rows are written to HBase.
 *
 * <p>The Spark context is always closed, even when reading or writing fails.
 *
 * @param task describes the content type, the BCP input path and the expected column names
 */
public static void run(TaskBean task) {
    logger.info("开始处理 {} 的BCP数据", task.getContentType());
    SparkConf conf = new SparkConf().setAppName(task.getContentType());
    // try-with-resources: previously sc.close() was skipped when anything below threw, leaking the
    // SparkContext (Spark allows only one active context per JVM).
    try (JavaSparkContext sc = new JavaSparkContext(conf)) {
        JavaRDD<String> originalRDD = sc.textFile(task.getBcpPath());
        // Split every line into its tab-separated fields. A per-line map avoids buffering a whole
        // partition in memory as the former mapPartitions version did.
        // NOTE(review): String.split drops trailing empty strings, so a row whose last fields are
        // empty yields fewer fields and is rejected by the filter below — confirm that is intended.
        JavaRDD<String[]> fieldsRDD = originalRDD.map(line -> line.split("\t"));
        JavaRDD<String[]> filterValuesRDD = fieldsRDD.filter(fields -> hasKnownFieldCount(task, fields.length));
        // Write the BCP data into HBase.
        bcpWriteIntoHBase(filterValuesRDD, task);
    }
}

/**
 * Returns whether a BCP row with {@code fieldCount} fields matches one of the accepted layouts.
 *
 * <p>The column-name array has no id field (the HBase RowKey / Solr Sid), hence the {@code + 1}.
 * Because BCP files may have been upgraded with extra fields, three layouts are accepted:
 * <ul>
 *   <li>exactly {@code columns + 1} fields;</li>
 *   <li>three fields fewer (the "service_code_out", "terminal_longitude", "terminal_latitude"
 *       group added to FTP and IM_CHAT);</li>
 *   <li>for HTTP only, seven fields fewer (the three above plus "manufacturer_code", "zipname",
 *       "bcpname", "rownumber").</li>
 * </ul>
 */
private static boolean hasKnownFieldCount(TaskBean task, int fieldCount) {
    final int expected = task.getColumns().length + 1;
    if (expected == fieldCount) {
        return true;
    }
    if (expected == fieldCount + 3) {
        return true;
    }
    return BigDataConstants.CONTENT_TYPE_HTTP.equalsIgnoreCase(task.getContentType()) && expected == fieldCount + 3 + 4;
}
Also used : PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) Date(java.util.Date) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) LoggerFactory(org.slf4j.LoggerFactory) ArrayUtils(org.apache.commons.lang3.ArrayUtils) VoidFunction(org.apache.spark.api.java.function.VoidFunction) DateFormatUtils(com.rainsoft.utils.DateFormatUtils) ArrayList(java.util.ArrayList) TaskBean(com.rainsoft.domain.TaskBean) ClassPathXmlApplicationContext(org.springframework.context.support.ClassPathXmlApplicationContext) BigDataConstants(com.rainsoft.BigDataConstants) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) Logger(org.slf4j.Logger) Iterator(java.util.Iterator) SolrUtil(com.rainsoft.utils.SolrUtil) SparkConf(org.apache.spark.SparkConf) RowkeyColumnSecondarySort(com.rainsoft.hbase.RowkeyColumnSecondarySort) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) SolrClient(org.apache.solr.client.solrj.SolrClient) Serializable(java.io.Serializable) HBaseUtils(com.rainsoft.utils.HBaseUtils) List(java.util.List) AbstractApplicationContext(org.springframework.context.support.AbstractApplicationContext) FieldConstants(com.rainsoft.FieldConstants) Function(org.apache.spark.api.java.function.Function) SolrInputDocument(org.apache.solr.common.SolrInputDocument) PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) VoidFunction(org.apache.spark.api.java.function.VoidFunction) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) Function(org.apache.spark.api.java.function.Function) Iterator(java.util.Iterator) ArrayList(java.util.ArrayList) List(java.util.List) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SparkConf(org.apache.spark.SparkConf)

Aggregations

FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction)15 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)12 List (java.util.List)10 JavaRDD (org.apache.spark.api.java.JavaRDD)9 ArrayList (java.util.ArrayList)8 Collectors (java.util.stream.Collectors)5 Function (org.apache.spark.api.java.function.Function)5 Tuple2 (scala.Tuple2)5 SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary)4 JavaPairRDD (org.apache.spark.api.java.JavaPairRDD)4 Iterator (java.util.Iterator)3 StreamSupport (java.util.stream.StreamSupport)3 SparkConf (org.apache.spark.SparkConf)3 Function2 (org.apache.spark.api.java.function.Function2)3 Broadcast (org.apache.spark.broadcast.Broadcast)3 Argument (org.broadinstitute.barclay.argparser.Argument)3 org.broadinstitute.hellbender.engine (org.broadinstitute.hellbender.engine)3 ReferenceMultiSource (org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource)3 IntervalUtils (org.broadinstitute.hellbender.utils.IntervalUtils)3 SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval)3