
Example 1 with CoGroupByKey

Use of org.apache.beam.sdk.transforms.join.CoGroupByKey in project beam by apache.

From the class SparkCoGroupByKeyStreamingTest, method testInStreamingMode:

@Category(StreamingTest.class)
@Test
public void testInStreamingMode() throws Exception {
    Instant instant = new Instant(0);
    // Two sources emit three values per key: key 1 in the first batch, key 2 one
    // second later; the watermark then advances to infinity to flush all windows.
    CreateStream<KV<Integer, Integer>> source1 = CreateStream
            .of(KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), batchDuration())
            .emptyBatch()
            .advanceWatermarkForNextBatch(instant)
            .nextBatch(
                    TimestampedValue.of(KV.of(1, 1), instant),
                    TimestampedValue.of(KV.of(1, 2), instant),
                    TimestampedValue.of(KV.of(1, 3), instant))
            .advanceWatermarkForNextBatch(instant.plus(Duration.standardSeconds(1L)))
            .nextBatch(
                    TimestampedValue.of(KV.of(2, 4), instant.plus(Duration.standardSeconds(1L))),
                    TimestampedValue.of(KV.of(2, 5), instant.plus(Duration.standardSeconds(1L))),
                    TimestampedValue.of(KV.of(2, 6), instant.plus(Duration.standardSeconds(1L))))
            .advanceNextBatchWatermarkToInfinity();
    CreateStream<KV<Integer, Integer>> source2 = CreateStream
            .of(KvCoder.of(VarIntCoder.of(), VarIntCoder.of()), batchDuration())
            .emptyBatch()
            .advanceWatermarkForNextBatch(instant)
            .nextBatch(
                    TimestampedValue.of(KV.of(1, 11), instant),
                    TimestampedValue.of(KV.of(1, 12), instant),
                    TimestampedValue.of(KV.of(1, 13), instant))
            .advanceWatermarkForNextBatch(instant.plus(Duration.standardSeconds(1L)))
            .nextBatch(
                    TimestampedValue.of(KV.of(2, 14), instant.plus(Duration.standardSeconds(1L))),
                    TimestampedValue.of(KV.of(2, 15), instant.plus(Duration.standardSeconds(1L))),
                    TimestampedValue.of(KV.of(2, 16), instant.plus(Duration.standardSeconds(1L))))
            .advanceNextBatchWatermarkToInfinity();
    // Window both inputs identically so the join can align their elements.
    PCollection<KV<Integer, Integer>> input1 = pipeline
            .apply("create source1", source1)
            .apply("window input1", Window.<KV<Integer, Integer>>into(FixedWindows.of(Duration.standardSeconds(3L))).withAllowedLateness(Duration.ZERO));
    PCollection<KV<Integer, Integer>> input2 = pipeline
            .apply("create source2", source2)
            .apply("window input2", Window.<KV<Integer, Integer>>into(FixedWindows.of(Duration.standardSeconds(3L))).withAllowedLateness(Duration.ZERO));
    // Join the two windowed inputs on their integer keys.
    PCollection<KV<Integer, CoGbkResult>> output = KeyedPCollectionTuple
            .of(INPUT1_TAG, input1)
            .and(INPUT2_TAG, input2)
            .apply(CoGroupByKey.create());
    PAssert.that("Wrong output of the join using CoGroupByKey in streaming mode", output).satisfies((SerializableFunction<Iterable<KV<Integer, CoGbkResult>>, Void>) input -> {
        assertEquals("Wrong size of the output PCollection", 2, Iterables.size(input));
        for (KV<Integer, CoGbkResult> element : input) {
            if (element.getKey() == 1) {
                Iterable<Integer> input1Elements = element.getValue().getAll(INPUT1_TAG);
                assertEquals("Wrong number of values for output elements for tag input1 and key 1", 3, Iterables.size(input1Elements));
                assertThat("Elements of PCollection input1 for key \"1\" are not present in the output PCollection", input1Elements, containsInAnyOrder(1, 2, 3));
                Iterable<Integer> input2Elements = element.getValue().getAll(INPUT2_TAG);
                assertEquals("Wrong number of values for output elements for tag input2 and key 1", 3, Iterables.size(input2Elements));
                assertThat("Elements of PCollection input2 for key \"1\" are not present in the output PCollection", input2Elements, containsInAnyOrder(11, 12, 13));
            } else if (element.getKey() == 2) {
                Iterable<Integer> input1Elements = element.getValue().getAll(INPUT1_TAG);
                assertEquals("Wrong number of values for output elements for tag input1 and key 2", 3, Iterables.size(input1Elements));
                assertThat("Elements of PCollection input1 for key \"2\" are not present in the output PCollection", input1Elements, containsInAnyOrder(4, 5, 6));
                Iterable<Integer> input2Elements = element.getValue().getAll(INPUT2_TAG);
                assertEquals("Wrong number of values for output elements for tag input2 and key 2", 3, Iterables.size(input2Elements));
                assertThat("Elements of PCollection input2 for key \"2\" are not present in the output PCollection", input2Elements, containsInAnyOrder(14, 15, 16));
            } else {
                fail("Unknown key in the output PCollection");
            }
        }
        return null;
    });
    pipeline.run();
}
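
The streaming scaffolding above can obscure the core pattern. Below is a minimal batch-mode sketch of the same join, assuming the default runner; the tag and variable names are illustrative, not the constants from the test.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.transforms.join.CoGroupByKey;
import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TupleTag;

public class MinimalCoGroupByKey {
    public static void main(String[] args) {
        Pipeline p = Pipeline.create();
        // Tags identify each input inside the CoGbkResult; the names are illustrative.
        TupleTag<Integer> leftTag = new TupleTag<>();
        TupleTag<Integer> rightTag = new TupleTag<>();
        PCollection<KV<Integer, Integer>> left =
                p.apply("left", Create.of(KV.of(1, 1), KV.of(1, 2), KV.of(2, 4)));
        PCollection<KV<Integer, Integer>> right =
                p.apply("right", Create.of(KV.of(1, 11), KV.of(2, 14)));
        // Group both inputs by key: each output element holds, for one key,
        // all values from both inputs, retrievable per input through its tag.
        PCollection<KV<Integer, CoGbkResult>> joined = KeyedPCollectionTuple
                .of(leftTag, left)
                .and(rightTag, right)
                .apply(CoGroupByKey.create());
        p.run().waitUntilFinish();
    }
}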
Also used: KV (org.apache.beam.sdk.values.KV), StreamingTest (org.apache.beam.runners.spark.StreamingTest), Duration (org.joda.time.Duration), SerializableFunction (org.apache.beam.sdk.transforms.SerializableFunction), ReuseSparkContextRule (org.apache.beam.runners.spark.ReuseSparkContextRule), CoGbkResult (org.apache.beam.sdk.transforms.join.CoGbkResult), TupleTag (org.apache.beam.sdk.values.TupleTag), CreateStream (org.apache.beam.runners.spark.io.CreateStream), TestPipeline (org.apache.beam.sdk.testing.TestPipeline), Iterables (org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables), Window (org.apache.beam.sdk.transforms.windowing.Window), Assert.fail (org.junit.Assert.fail), MatcherAssert.assertThat (org.hamcrest.MatcherAssert.assertThat), KeyedPCollectionTuple (org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple), TimestampedValue (org.apache.beam.sdk.values.TimestampedValue), KvCoder (org.apache.beam.sdk.coders.KvCoder), PAssert (org.apache.beam.sdk.testing.PAssert), FixedWindows (org.apache.beam.sdk.transforms.windowing.FixedWindows), Test (org.junit.Test), PCollection (org.apache.beam.sdk.values.PCollection), Category (org.junit.experimental.categories.Category), CoGroupByKey (org.apache.beam.sdk.transforms.join.CoGroupByKey), Rule (org.junit.Rule), Matchers.containsInAnyOrder (org.hamcrest.Matchers.containsInAnyOrder), Instant (org.joda.time.Instant), VarIntCoder (org.apache.beam.sdk.coders.VarIntCoder), SparkPipelineOptions (org.apache.beam.runners.spark.SparkPipelineOptions), Assert.assertEquals (org.junit.Assert.assertEquals)

Example 2 with CoGroupByKey

Use of org.apache.beam.sdk.transforms.join.CoGroupByKey in project component-runtime by Talend.

From the class BeamExecutor, method run:

@Override
public void run() {
    try {
        // Resolve a Mapper for every source component and a Processor for everything else.
        final Map<String, Mapper> mappers = delegate.getLevels().values().stream()
                .flatMap(Collection::stream)
                .filter(Job.Component::isSource)
                .collect(toMap(Job.Component::getId, e -> delegate.getManager()
                        .findMapper(e.getNode().getFamily(), e.getNode().getComponent(),
                                e.getNode().getVersion(), e.getNode().getConfiguration())
                        .orElseThrow(() -> new IllegalStateException("No mapper found for: " + e.getNode()))));
        final Map<String, Processor> processors = delegate.getLevels().values().stream()
                .flatMap(Collection::stream)
                .filter(component -> !component.isSource())
                .collect(toMap(Job.Component::getId, e -> delegate.getManager()
                        .findProcessor(e.getNode().getFamily(), e.getNode().getComponent(),
                                e.getNode().getVersion(), e.getNode().getConfiguration())
                        .orElseThrow(() -> new IllegalStateException("No processor found for: " + e.getNode()))));
        final Pipeline pipeline = Pipeline.create(createPipelineOptions());
        final Map<String, PCollection<JsonObject>> pCollections = new HashMap<>();
        delegate.getLevels().values().stream().flatMap(Collection::stream).forEach(component -> {
            if (component.isSource()) {
                final Mapper mapper = mappers.get(component.getId());
                // Sources become a TalendIO read followed by record normalization.
                pCollections.put(component.getId(), pipeline
                        .apply(toName("TalendIO", component), TalendIO.read(mapper))
                        .apply(toName("RecordNormalizer", component), RecordNormalizer.of(mapper.plugin())));
            } else {
                final Processor processor = processors.get(component.getId());
                final List<Job.Edge> joins = getEdges(delegate.getEdges(), component, e -> e.getTo().getNode());
                // For each incoming edge: filter the upstream branch, remap it if the
                // branch names differ, unwrap the record, and key it for the join.
                final Map<String, PCollection<KV<String, JsonObject>>> inputs = joins.stream()
                        .collect(toMap(e -> e.getTo().getBranch(), e -> {
                            final PCollection<JsonObject> pc = pCollections.get(e.getFrom().getNode().getId());
                            final PCollection<JsonObject> filteredInput = pc.apply(
                                    toName("RecordBranchFilter", component, e),
                                    RecordBranchFilter.of(processor.plugin(), e.getFrom().getBranch()));
                            final PCollection<JsonObject> mappedInput;
                            if (e.getFrom().getBranch().equals(e.getTo().getBranch())) {
                                mappedInput = filteredInput;
                            } else {
                                mappedInput = filteredInput.apply(
                                        toName("RecordBranchMapper", component, e),
                                        RecordBranchMapper.of(processor.plugin(), e.getFrom().getBranch(), e.getTo().getBranch()));
                            }
                            return mappedInput
                                    .apply(toName("RecordBranchUnwrapper", component, e),
                                            RecordBranchUnwrapper.of(processor.plugin(), e.getTo().getBranch()))
                                    .apply(toName("AutoKVWrapper", component, e),
                                            AutoKVWrapper.of(processor.plugin(), delegate.getKeyProvider(component.getId()),
                                                    component.getId(), e.getFrom().getBranch()));
                        }));
                KeyedPCollectionTuple<String> join = null;
                for (final Map.Entry<String, PCollection<KV<String, JsonObject>>> entry : inputs.entrySet()) {
                    final TupleTag<JsonObject> branch = new TupleTag<>(entry.getKey());
                    join = join == null ? KeyedPCollectionTuple.of(branch, entry.getValue()) : join.and(branch, entry.getValue());
                }
                // Join all keyed branches, then flatten the CoGbkResult back into records.
                final PCollection<JsonObject> preparedInput = join
                        .apply(toName("CoGroupByKey", component), CoGroupByKey.create())
                        .apply(toName("CoGroupByKeyResultMappingTransform", component),
                                new CoGroupByKeyResultMappingTransform<>(processor.plugin(), true));
                if (getEdges(delegate.getEdges(), component, e -> e.getFrom().getNode()).isEmpty()) {
                    final PTransform<PCollection<JsonObject>, PDone> write = TalendIO.write(processor);
                    preparedInput.apply(toName("Output", component), write);
                } else {
                    final PTransform<PCollection<JsonObject>, PCollection<JsonObject>> process = TalendFn.asFn(processor);
                    pCollections.put(component.getId(), preparedInput.apply(toName("Processor", component), process));
                }
            }
        });
        final PipelineResult result = pipeline.run();
        // waitUntilFinish() does not wait for the job to complete on the direct
        // runner, so poll below until the job leaves the RUNNING state.
        result.waitUntilFinish();
        while (PipelineResult.State.RUNNING.equals(result.getState())) {
            try {
                Thread.sleep(100L);
            } catch (final InterruptedException e) {
                throw new IllegalStateException("the job was aborted", e);
            }
        }
    } finally {
        delegate.getLevels().values().stream().flatMap(Collection::stream).map(Job.Component::getId).forEach(JobImpl.LocalSequenceHolder::clean);
    }
}
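
Example 2 delegates the unpacking of the per-branch values to CoGroupByKeyResultMappingTransform. As a rough sketch of what such a consumer does, here is a hypothetical MergeBranches DoFn (not Talend's implementation), using String values for brevity where the pipeline above uses JsonObject:

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TupleTag;

// Hypothetical consumer of the join output; branchA and branchB stand in for
// the per-edge tags built in the loop above.
class MergeBranches extends DoFn<KV<String, CoGbkResult>, String> {
    private final TupleTag<String> branchA;
    private final TupleTag<String> branchB;

    MergeBranches(TupleTag<String> branchA, TupleTag<String> branchB) {
        this.branchA = branchA;
        this.branchB = branchB;
    }

    @ProcessElement
    public void processElement(ProcessContext c) {
        CoGbkResult grouped = c.element().getValue();
        // getAll(tag) returns every value that arrived on that branch for this key.
        for (String a : grouped.getAll(branchA)) {
            for (String b : grouped.getAll(branchB)) {
                c.output(c.element().getKey() + ":" + a + "|" + b);
            }
        }
    }
}

It would be applied after the CoGroupByKey step, e.g. joined.apply(ParDo.of(new MergeBranches(tagA, tagB))).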
Also used: TalendIO (org.talend.sdk.component.runtime.beam.TalendIO), KV (org.apache.beam.sdk.values.KV), PipelineResult (org.apache.beam.sdk.PipelineResult), RecordBranchFilter (org.talend.sdk.component.runtime.beam.transform.RecordBranchFilter), HashMap (java.util.HashMap), PipelineOptionsFactory (org.apache.beam.sdk.options.PipelineOptionsFactory), Function (java.util.function.Function), PTransform (org.apache.beam.sdk.transforms.PTransform), RecordBranchMapper (org.talend.sdk.component.runtime.beam.transform.RecordBranchMapper), Collectors.toMap (java.util.stream.Collectors.toMap), TupleTag (org.apache.beam.sdk.values.TupleTag), Map (java.util.Map), RecordNormalizer (org.talend.sdk.component.runtime.beam.transform.RecordNormalizer), Pipeline (org.apache.beam.sdk.Pipeline), KeyedPCollectionTuple (org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple), PipelineOptions (org.apache.beam.sdk.options.PipelineOptions), JsonObject (javax.json.JsonObject), PDone (org.apache.beam.sdk.values.PDone), Collection (java.util.Collection), PCollection (org.apache.beam.sdk.values.PCollection), Processor (org.talend.sdk.component.runtime.output.Processor), RecordBranchUnwrapper (org.talend.sdk.component.runtime.beam.transform.RecordBranchUnwrapper), CoGroupByKey (org.apache.beam.sdk.transforms.join.CoGroupByKey), AutoKVWrapper (org.talend.sdk.component.runtime.beam.transform.AutoKVWrapper), Collectors.toList (java.util.stream.Collectors.toList), List (java.util.List), Mapper (org.talend.sdk.component.runtime.input.Mapper), CoGroupByKeyResultMappingTransform (org.talend.sdk.component.runtime.beam.transform.CoGroupByKeyResultMappingTransform), Job (org.talend.sdk.component.runtime.manager.chain.Job), JobImpl (org.talend.sdk.component.runtime.manager.chain.internal.JobImpl), AllArgsConstructor (lombok.AllArgsConstructor), TalendFn (org.talend.sdk.component.runtime.beam.TalendFn)

Aggregations

CoGroupByKey (org.apache.beam.sdk.transforms.join.CoGroupByKey): 2
KeyedPCollectionTuple (org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple): 2
KV (org.apache.beam.sdk.values.KV): 2
PCollection (org.apache.beam.sdk.values.PCollection): 2
TupleTag (org.apache.beam.sdk.values.TupleTag): 2
Collection (java.util.Collection): 1
HashMap (java.util.HashMap): 1
List (java.util.List): 1
Map (java.util.Map): 1
Function (java.util.function.Function): 1
Collectors.toList (java.util.stream.Collectors.toList): 1
Collectors.toMap (java.util.stream.Collectors.toMap): 1
JsonObject (javax.json.JsonObject): 1
AllArgsConstructor (lombok.AllArgsConstructor): 1
ReuseSparkContextRule (org.apache.beam.runners.spark.ReuseSparkContextRule): 1
SparkPipelineOptions (org.apache.beam.runners.spark.SparkPipelineOptions): 1
StreamingTest (org.apache.beam.runners.spark.StreamingTest): 1
CreateStream (org.apache.beam.runners.spark.io.CreateStream): 1
Pipeline (org.apache.beam.sdk.Pipeline): 1
PipelineResult (org.apache.beam.sdk.PipelineResult): 1