Example 1 with Seq

Use of scala.collection.Seq in project beam by apache.

From the class SparkGroupAlsoByWindowViaWindowSet, method groupAlsoByWindow.

public static <K, InputT, W extends BoundedWindow>
        JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow(
            JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream,
            final Coder<K> keyCoder,
            final Coder<WindowedValue<InputT>> wvCoder,
            final WindowingStrategy<?, W> windowingStrategy,
            final SparkRuntimeContext runtimeContext,
            final List<Integer> sourceIds) {
    final IterableCoder<WindowedValue<InputT>> itrWvCoder = IterableCoder.of(wvCoder);
    final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder();
    final Coder<? extends BoundedWindow> wCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getWindowCoder();
    final Coder<WindowedValue<KV<K, Iterable<InputT>>>> wvKvIterCoder = FullWindowedValueCoder.of(KvCoder.of(keyCoder, IterableCoder.of(iCoder)), wCoder);
    final TimerInternals.TimerDataCoder timerDataCoder = TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());
    long checkpointDurationMillis = runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class).getCheckpointDurationMillis();
    // we have to switch to the Scala API to avoid Optional in the Java API, see: SPARK-4819.
    // the Scala API is also broader (it gives access to the actual key and the entire iterator).
    // we use coders to convert objects in the PCollection to byte arrays, so they
    // can be transferred over the network for the shuffle and be in serialized form
    // for checkpointing.
    // for readability, we add comments with actual type next to byte[].
    // to shorten line length, we use:
    //---- WindowedValue: WV
    //---- Iterable: Itr
    //---- AccumT: A
    //---- InputT: I
    DStream<Tuple2<ByteArray, byte[]>> /*Itr<WV<I>>*/
    pairDStream = inputDStream.transformToPair(new Function<JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>>, JavaPairRDD<ByteArray, byte[]>>() {

        // we use mapPartitions with the RDD API because it's the only available API
        // that allows us to preserve partitioning.
        @Override
        public JavaPairRDD<ByteArray, byte[]> call(JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> rdd) throws Exception {
            return rdd
                .mapPartitions(
                    TranslationUtils.functionToFlatMapFunction(
                        WindowingHelpers.<KV<K, Iterable<WindowedValue<InputT>>>>unwindowFunction()),
                    true)
                .mapPartitionsToPair(
                    TranslationUtils.<K, Iterable<WindowedValue<InputT>>>toPairFlatMapFunction(), true)
                .mapPartitionsToPair(
                    TranslationUtils.pairFunctionToPairFlatMapFunction(
                        CoderHelpers.toByteFunction(keyCoder, itrWvCoder)),
                    true);
        }
    }).dstream();
    PairDStreamFunctions<ByteArray, byte[]> pairDStreamFunctions =
        DStream.toPairDStreamFunctions(
            pairDStream,
            JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(),
            JavaSparkContext$.MODULE$.<byte[]>fakeClassTag(),
            null);
    int defaultNumPartitions = pairDStreamFunctions.defaultPartitioner$default$1();
    Partitioner partitioner = pairDStreamFunctions.defaultPartitioner(defaultNumPartitions);
    // use updateStateByKey to scan through the state and update elements and timers.
    DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> /*WV<KV<K, Itr<I>>>*/
    firedStream = pairDStreamFunctions.updateStateByKey(new SerializableFunction1<scala.collection.Iterator<Tuple3</*K*/
    ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
    List<byte[]>>>>>, scala.collection.Iterator<Tuple2</*K*/
    ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
    List<byte[]>>>>>() {

        @Override
        public scala.collection.Iterator<Tuple2</*K*/
        ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
        List<byte[]>>>> apply(final scala.collection.Iterator<Tuple3</*K*/
        ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
        List<byte[]>>>>> iter) {
            //--- ACTUAL STATEFUL OPERATION:
            //
            // Input Iterator: the partition (~bundle) of a cogrouping of the input
            // and the previous state (if it exists).
            //
            // Output Iterator: the output key, and the updated state.
            //
            // possible input scenarios for (K, Seq, Option<S>):
            // (1) Option<S>.isEmpty: new data with no previous state.
            // (2) Seq.isEmpty: no new data, but evaluating previous state (timer-like behaviour).
            // (3) Seq.nonEmpty && Option<S>.isDefined: new data with previous state.
            final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn = SystemReduceFn.buffering(((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder());
            final OutputWindowedValueHolder<K, InputT> outputHolder = new OutputWindowedValueHolder<>();
            // use in-memory counters, since Spark Accumulators are not resilient
            // in stateful operators; the counts are logged once this partition is done.
            final MetricsContainerImpl cellProvider = new MetricsContainerImpl("cellProvider");
            final CounterCell droppedDueToClosedWindow = cellProvider.getCounter(
                MetricName.named(
                    SparkGroupAlsoByWindowViaWindowSet.class,
                    GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER));
            final CounterCell droppedDueToLateness = cellProvider.getCounter(
                MetricName.named(
                    SparkGroupAlsoByWindowViaWindowSet.class,
                    GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_LATENESS_COUNTER));
            AbstractIterator<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> /*WV<KV<K, Itr<I>>>*/
            outIter = new AbstractIterator<Tuple2</*K*/
            ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
            List<byte[]>>>>() {

                @Override
                protected Tuple2</*K*/
                ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
                List<byte[]>>> computeNext() {
                    // (possibly) previous-state and (possibly) new data.
                    while (iter.hasNext()) {
                        // for each element in the partition:
                        Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, List<byte[]>>>> next = iter.next();
                        ByteArray encodedKey = next._1();
                        K key = CoderHelpers.fromByteArray(encodedKey.getValue(), keyCoder);
                        Seq<byte[]> seq = next._2();
                        Option<Tuple2<StateAndTimers, List<byte[]>>> prevStateAndTimersOpt = next._3();
                        SparkStateInternals<K> stateInternals;
                        SparkTimerInternals timerInternals = SparkTimerInternals.forStreamFromSources(sourceIds, GlobalWatermarkHolder.get());
                        // get state(internals) per key.
                        if (prevStateAndTimersOpt.isEmpty()) {
                            // no previous state.
                            stateInternals = SparkStateInternals.forKey(key);
                        } else {
                            // with pre-existing state.
                            StateAndTimers prevStateAndTimers = prevStateAndTimersOpt.get()._1();
                            stateInternals = SparkStateInternals.forKeyAndState(key, prevStateAndTimers.getState());
                            Collection<byte[]> serTimers = prevStateAndTimers.getTimers();
                            timerInternals.addTimers(SparkTimerInternals.deserializeTimers(serTimers, timerDataCoder));
                        }
                        ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner =
                            new ReduceFnRunner<>(
                                key,
                                windowingStrategy,
                                ExecutableTriggerStateMachine.create(
                                    TriggerStateMachines.stateMachineForTrigger(
                                        TriggerTranslation.toProto(windowingStrategy.getTrigger()))),
                                stateInternals,
                                timerInternals,
                                outputHolder,
                                new UnsupportedSideInputReader("GroupAlsoByWindow"),
                                reduceFn,
                                runtimeContext.getPipelineOptions());
                        // clear before potential use.
                        outputHolder.clear();
                        if (!seq.isEmpty()) {
                            // new input for key.
                            try {
                                Iterable<WindowedValue<InputT>> elementsIterable = CoderHelpers.fromByteArray(seq.head(), itrWvCoder);
                                Iterable<WindowedValue<InputT>> validElements = LateDataUtils.dropExpiredWindows(key, elementsIterable, timerInternals, windowingStrategy, droppedDueToLateness);
                                reduceFnRunner.processElements(validElements);
                            } catch (Exception e) {
                                throw new RuntimeException("Failed to process element with ReduceFnRunner", e);
                            }
                        } else if (stateInternals.getState().isEmpty()) {
                            // no input and no state -> GC evict now.
                            continue;
                        }
                        try {
                            // advance the watermark to the high watermark (HWM) so due timers can fire.
                            timerInternals.advanceWatermark();
                            // call on timers that are ready.
                            reduceFnRunner.onTimers(timerInternals.getTimersReadyToProcess());
                        } catch (Exception e) {
                            throw new RuntimeException("Failed to process ReduceFnRunner onTimer.", e);
                        }
                        // this is mostly symbolic since actual persist is done by emitting output.
                        reduceFnRunner.persist();
                        // obtain output, if fired.
                        List<WindowedValue<KV<K, Iterable<InputT>>>> outputs = outputHolder.get();
                        if (!outputs.isEmpty() || !stateInternals.getState().isEmpty()) {
                            StateAndTimers updated = new StateAndTimers(stateInternals.getState(), SparkTimerInternals.serializeTimers(timerInternals.getTimers(), timerDataCoder));
                            // persist Spark's state by outputting.
                            List<byte[]> serOutput = CoderHelpers.toByteArrays(outputs, wvKvIterCoder);
                            return new Tuple2<>(encodedKey, new Tuple2<>(updated, serOutput));
                        }
                    // an empty state with no output can be evicted completely - do nothing.
                    }
                    return endOfData();
                }
            };
            // log if there's something to log.
            long lateDropped = droppedDueToLateness.getCumulative();
            if (lateDropped > 0) {
                LOG.info(String.format("Dropped %d elements due to lateness.", lateDropped));
                droppedDueToLateness.inc(-droppedDueToLateness.getCumulative());
            }
            long closedWindowDropped = droppedDueToClosedWindow.getCumulative();
            if (closedWindowDropped > 0) {
                LOG.info(String.format("Dropped %d elements due to closed window.", closedWindowDropped));
                droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative());
            }
            return scala.collection.JavaConversions.asScalaIterator(outIter);
        }
    }, partitioner, true, JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag());
    if (checkpointDurationMillis > 0) {
        firedStream.checkpoint(new Duration(checkpointDurationMillis));
    }
    // go back to Java now.
    JavaPairDStream<ByteArray, Tuple2<StateAndTimers, List<byte[]>>> /*WV<KV<K, Itr<I>>>*/
    javaFiredStream = JavaPairDStream.fromPairDStream(firedStream, JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(), JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag());
    // filter state-only output (nothing to fire) and remove the state from the output.
    return javaFiredStream.filter(new Function<Tuple2</*K*/
    ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
    List<byte[]>>>, Boolean>() {

        @Override
        public Boolean call(Tuple2</*K*/
        ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
        List<byte[]>>> t2) throws Exception {
            // keep only entries that produced output to fire.
            return !t2._2()._2().isEmpty();
        }
    }).flatMap(new FlatMapFunction<Tuple2</*K*/
    ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
    List<byte[]>>>, WindowedValue<KV<K, Iterable<InputT>>>>() {

        @Override
        public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call(Tuple2</*K*/
        ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
        List<byte[]>>> t2) throws Exception {
            // decode the serialized output back into WindowedValues.
            return CoderHelpers.fromByteArrays(t2._2()._2(), wvKvIterCoder);
        }
    });
}
Also used : MetricsContainerImpl(org.apache.beam.runners.core.metrics.MetricsContainerImpl) CounterCell(org.apache.beam.runners.core.metrics.CounterCell) WindowedValue(org.apache.beam.sdk.util.WindowedValue) OutputWindowedValue(org.apache.beam.runners.core.OutputWindowedValue) ByteArray(org.apache.beam.runners.spark.util.ByteArray) List(java.util.List) ArrayList(java.util.ArrayList) ReduceFnRunner(org.apache.beam.runners.core.ReduceFnRunner) SystemReduceFn(org.apache.beam.runners.core.SystemReduceFn) Duration(org.apache.spark.streaming.Duration) TimerInternals(org.apache.beam.runners.core.TimerInternals) Collection(java.util.Collection) Option(scala.Option) Seq(scala.collection.Seq) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) Function(org.apache.spark.api.java.function.Function) UnsupportedSideInputReader(org.apache.beam.runners.core.UnsupportedSideInputReader) AbstractIterator(com.google.common.collect.AbstractIterator) Partitioner(org.apache.spark.Partitioner) FullWindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder) KV(org.apache.beam.sdk.values.KV) SparkPipelineOptions(org.apache.beam.runners.spark.SparkPipelineOptions) JavaRDD(org.apache.spark.api.java.JavaRDD) Tuple2(scala.Tuple2) Tuple3(scala.Tuple3)
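
Much of the ceremony in this example is Java-to-Scala interop: the method drops from the Java DStream API into Scala's DStream/PairDStreamFunctions to reach updateStateByKey, bridging class tags with fakeClassTag and iterators with JavaConversions. Below is a minimal, self-contained sketch of just those two bridges, assuming Spark 1.x/2.x and Scala 2.11 on the classpath (the class name ScalaInteropSketch is ours, for illustration only):

import java.util.Arrays;
import java.util.Iterator;

import org.apache.spark.api.java.JavaSparkContext$;
import scala.collection.JavaConversions;
import scala.reflect.ClassTag;

public class ScalaInteropSketch {
    public static void main(String[] args) {
        // a ClassTag for byte[], as Scala-side Spark APIs require; fakeClassTag
        // is the same trick the Java API (and the code above) uses internally.
        ClassTag<byte[]> bytesTag = JavaSparkContext$.MODULE$.<byte[]>fakeClassTag();
        System.out.println(bytesTag);

        // wrap a java.util.Iterator as a scala.collection.Iterator, the shape
        // expected by the update function passed to updateStateByKey.
        Iterator<String> javaIter = Arrays.asList("a", "b", "c").iterator();
        scala.collection.Iterator<String> scalaIter = JavaConversions.asScalaIterator(javaIter);
        while (scalaIter.hasNext()) {
            System.out.println(scalaIter.next());
        }
    }
}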

Example 2 with Seq

Use of scala.collection.Seq in project incubator-systemml by apache.

From the class MLContextTest, method testInputTupleSeqWithMetadataDML.

@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
public void testInputTupleSeqWithMetadataDML() {
    System.out.println("MLContextTest - Tuple sequence with metadata DML");
    List<String> list1 = new ArrayList<String>();
    list1.add("1,2");
    list1.add("3,4");
    JavaRDD<String> javaRDD1 = sc.parallelize(list1);
    RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
    List<String> list2 = new ArrayList<String>();
    list2.add("5,6");
    list2.add("7,8");
    JavaRDD<String> javaRDD2 = sc.parallelize(list2);
    RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
    MatrixMetadata mm1 = new MatrixMetadata(2, 2);
    MatrixMetadata mm2 = new MatrixMetadata(2, 2);
    Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1);
    Tuple3 tuple2 = new Tuple3("m2", rdd2, mm2);
    List tupleList = new ArrayList();
    tupleList.add(tuple1);
    tupleList.add(tuple2);
    Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
    Script script = dml("print('sums: ' + sum(m1) + ' ' + sum(m2));").in(seq);
    setExpectedStdOut("sums: 10.0 26.0");
    ml.execute(script);
}
Also used : Script(org.apache.sysml.api.mlcontext.Script) Tuple3(scala.Tuple3) ArrayList(java.util.ArrayList) List(java.util.List) MatrixMetadata(org.apache.sysml.api.mlcontext.MatrixMetadata) Seq(scala.collection.Seq) Test(org.junit.Test)
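
All of the Seq-based MLContext tests funnel through the same conversion: build a java.util.List, wrap it with JavaConversions.asScalaBuffer, and call toSeq to obtain the scala.collection.Seq that Script.in expects. Here is a minimal typed sketch of just that conversion, assuming Scala 2.11's JavaConversions; with the generics supplied, the rawtypes suppression seen in the test is unnecessary (SeqConversionSketch is our name for the example):

import java.util.ArrayList;
import java.util.List;

import scala.Tuple2;
import scala.collection.JavaConversions;
import scala.collection.Seq;

public class SeqConversionSketch {
    public static void main(String[] args) {
        // (variable name, value) pairs, analogous to the (name, RDD) tuples above
        List<Tuple2<String, String>> tuples = new ArrayList<>();
        tuples.add(new Tuple2<>("m1", "first"));
        tuples.add(new Tuple2<>("m2", "second"));

        // asScalaBuffer wraps the Java list as a Scala Buffer (a view, not a copy);
        // toSeq exposes it through the scala.collection.Seq interface.
        Seq<Tuple2<String, String>> seq = JavaConversions.asScalaBuffer(tuples).toSeq();

        System.out.println(seq.size());      // 2
        System.out.println(seq.head()._1()); // m1
    }
}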

Example 3 with Seq

Use of scala.collection.Seq in project incubator-systemml by apache.

From the class MLContextTest, method testInputTupleSeqNoMetadataPYDML.

@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
public void testInputTupleSeqNoMetadataPYDML() {
    System.out.println("MLContextTest - Tuple sequence no metadata PYDML");
    List<String> list1 = new ArrayList<String>();
    list1.add("1,2");
    list1.add("3,4");
    JavaRDD<String> javaRDD1 = sc.parallelize(list1);
    RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
    List<String> list2 = new ArrayList<String>();
    list2.add("5,6");
    list2.add("7,8");
    JavaRDD<String> javaRDD2 = sc.parallelize(list2);
    RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
    Tuple2 tuple1 = new Tuple2("m1", rdd1);
    Tuple2 tuple2 = new Tuple2("m2", rdd2);
    List tupleList = new ArrayList();
    tupleList.add(tuple1);
    tupleList.add(tuple2);
    Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
    Script script = pydml("print('sums: ' + sum(m1) + ' ' + sum(m2))").in(seq);
    setExpectedStdOut("sums: 10.0 26.0");
    ml.execute(script);
}
Also used : Script(org.apache.sysml.api.mlcontext.Script) Tuple2(scala.Tuple2) ArrayList(java.util.ArrayList) List(java.util.List) Seq(scala.collection.Seq) Test(org.junit.Test)

Example 4 with Seq

Use of scala.collection.Seq in project incubator-systemml by apache.

From the class MLContextTest, method testOutputScalaSeqPYDML.

@SuppressWarnings({ "unchecked", "rawtypes" })
@Test
public void testOutputScalaSeqPYDML() {
    System.out.println("MLContextTest - output specified as Scala Seq PYDML");
    List outputs = Arrays.asList("x", "y");
    Seq seq = JavaConversions.asScalaBuffer(outputs).toSeq();
    Script script = pydml("a=1\nx=a+1\ny=x+1").out(seq);
    MLResults results = ml.execute(script);
    Assert.assertEquals(2, results.getLong("x"));
    Assert.assertEquals(3, results.getLong("y"));
}
Also used : Script(org.apache.sysml.api.mlcontext.Script) MLResults(org.apache.sysml.api.mlcontext.MLResults) List(java.util.List) ArrayList(java.util.ArrayList) Seq(scala.collection.Seq) Test(org.junit.Test)

Example 5 with Seq

Use of scala.collection.Seq in project systemml by apache.

From the class MLContextTest, method testOutputScalaSeqDML.

@SuppressWarnings({ "unchecked", "rawtypes" })
@Test
public void testOutputScalaSeqDML() {
    System.out.println("MLContextTest - output specified as Scala Seq DML");
    List outputs = Arrays.asList("x", "y");
    Seq seq = JavaConversions.asScalaBuffer(outputs).toSeq();
    Script script = dml("a=1;x=a+1;y=x+1").out(seq);
    MLResults results = ml.execute(script);
    Assert.assertEquals(2, results.getLong("x"));
    Assert.assertEquals(3, results.getLong("y"));
}
Also used : Script(org.apache.sysml.api.mlcontext.Script) MLResults(org.apache.sysml.api.mlcontext.MLResults) List(java.util.List) ArrayList(java.util.ArrayList) Seq(scala.collection.Seq) Test(org.junit.Test)
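
Since the list-to-Seq dance recurs in every one of these tests, a small helper can keep call sites close to the Scala API they target. A hypothetical sketch (seqOf is our invention, not part of SystemML or Spark):

import java.util.Arrays;

import scala.collection.JavaConversions;
import scala.collection.Seq;

public class SeqUtil {
    // hypothetical convenience: build a scala.collection.Seq from Java varargs
    @SafeVarargs
    public static <T> Seq<T> seqOf(T... items) {
        return JavaConversions.asScalaBuffer(Arrays.asList(items)).toSeq();
    }

    public static void main(String[] args) {
        Seq<String> outputs = seqOf("x", "y");
        System.out.println(outputs.size()); // 2
        // e.g., the test above could read: dml("a=1;x=a+1;y=x+1").out(seqOf("x", "y"));
    }
}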

Aggregations

Seq (scala.collection.Seq): 20
ArrayList (java.util.ArrayList): 18
List (java.util.List): 14
Script (org.apache.sysml.api.mlcontext.Script): 12
Test (org.junit.Test): 12
Tuple2 (scala.Tuple2): 6
Tuple3 (scala.Tuple3): 5
MLResults (org.apache.sysml.api.mlcontext.MLResults): 4
MatrixMetadata (org.apache.sysml.api.mlcontext.MatrixMetadata): 4
AggregateCall (org.apache.calcite.rel.core.AggregateCall): 3
HashMap (java.util.HashMap): 2
RelTraitSet (org.apache.calcite.plan.RelTraitSet): 2
RelCollation (org.apache.calcite.rel.RelCollation): 2
RelNode (org.apache.calcite.rel.RelNode): 2
RexInputRef (org.apache.calcite.rex.RexInputRef): 2
RexNode (org.apache.calcite.rex.RexNode): 2
UserDefinedFunction (org.apache.flink.table.functions.UserDefinedFunction): 2
FlinkRelDistribution (org.apache.flink.table.planner.plan.trait.FlinkRelDistribution): 2
DataType (org.apache.flink.table.types.DataType): 2
Option (scala.Option): 2