Use of org.apache.flink.streaming.api.transformations.TwoInputTransformation in project flink by apache.
The class PythonOperatorChainingOptimizer, method createChainedTransformation.
private static Transformation<?> createChainedTransformation(
        Transformation<?> upTransform, Transformation<?> downTransform) {
    final AbstractDataStreamPythonFunctionOperator<?> upOperator =
            (AbstractDataStreamPythonFunctionOperator<?>)
                    ((SimpleOperatorFactory<?>) getOperatorFactory(upTransform)).getOperator();
    final PythonProcessOperator<?, ?> downOperator =
            (PythonProcessOperator<?, ?>)
                    ((SimpleOperatorFactory<?>) getOperatorFactory(downTransform)).getOperator();

    final DataStreamPythonFunctionInfo upPythonFunctionInfo = upOperator.getPythonFunctionInfo().copy();
    final DataStreamPythonFunctionInfo downPythonFunctionInfo = downOperator.getPythonFunctionInfo().copy();

    // Walk to the head of the down operator's Python function chain and attach the up
    // operator's function info as its input, merging the two chains into one.
    DataStreamPythonFunctionInfo headPythonFunctionInfoOfDownOperator = downPythonFunctionInfo;
    while (headPythonFunctionInfoOfDownOperator.getInputs().length != 0) {
        headPythonFunctionInfoOfDownOperator =
                (DataStreamPythonFunctionInfo) headPythonFunctionInfoOfDownOperator.getInputs()[0];
    }
    headPythonFunctionInfoOfDownOperator.setInputs(
            new DataStreamPythonFunctionInfo[] {upPythonFunctionInfo});

    final AbstractDataStreamPythonFunctionOperator<?> chainedOperator =
            upOperator.copy(downPythonFunctionInfo, downOperator.getProducedType());

    // set partition custom flag
    chainedOperator.setContainsPartitionCustom(
            downOperator.containsPartitionCustom() || upOperator.containsPartitionCustom());

    PhysicalTransformation<?> chainedTransformation;
    if (upOperator instanceof AbstractOneInputPythonFunctionOperator) {
        chainedTransformation = new OneInputTransformation(
                upTransform.getInputs().get(0),
                upTransform.getName() + ", " + downTransform.getName(),
                (OneInputStreamOperator<?, ?>) chainedOperator,
                downTransform.getOutputType(),
                upTransform.getParallelism());
        ((OneInputTransformation<?, ?>) chainedTransformation)
                .setStateKeySelector(((OneInputTransformation) upTransform).getStateKeySelector());
        ((OneInputTransformation<?, ?>) chainedTransformation)
                .setStateKeyType(((OneInputTransformation<?, ?>) upTransform).getStateKeyType());
    } else {
        chainedTransformation = new TwoInputTransformation(
                upTransform.getInputs().get(0),
                upTransform.getInputs().get(1),
                upTransform.getName() + ", " + downTransform.getName(),
                (TwoInputStreamOperator<?, ?, ?>) chainedOperator,
                downTransform.getOutputType(),
                upTransform.getParallelism());
        ((TwoInputTransformation<?, ?, ?>) chainedTransformation).setStateKeySelectors(
                ((TwoInputTransformation) upTransform).getStateKeySelector1(),
                ((TwoInputTransformation) upTransform).getStateKeySelector2());
        ((TwoInputTransformation<?, ?, ?>) chainedTransformation)
                .setStateKeyType(((TwoInputTransformation<?, ?, ?>) upTransform).getStateKeyType());
    }

    // Propagate uid, uid hash, managed memory declarations, buffer timeout, max parallelism,
    // chaining strategy, co-location, resources, slot sharing group and description.
    chainedTransformation.setUid(upTransform.getUid());
    if (upTransform.getUserProvidedNodeHash() != null) {
        chainedTransformation.setUidHash(upTransform.getUserProvidedNodeHash());
    }
    for (ManagedMemoryUseCase useCase : upTransform.getManagedMemorySlotScopeUseCases()) {
        chainedTransformation.declareManagedMemoryUseCaseAtSlotScope(useCase);
    }
    for (ManagedMemoryUseCase useCase : downTransform.getManagedMemorySlotScopeUseCases()) {
        chainedTransformation.declareManagedMemoryUseCaseAtSlotScope(useCase);
    }
    for (Map.Entry<ManagedMemoryUseCase, Integer> useCase :
            upTransform.getManagedMemoryOperatorScopeUseCaseWeights().entrySet()) {
        chainedTransformation.declareManagedMemoryUseCaseAtOperatorScope(
                useCase.getKey(), useCase.getValue());
    }
    // Operator-scope weights of the two transformations are summed per use case.
    for (Map.Entry<ManagedMemoryUseCase, Integer> useCase :
            downTransform.getManagedMemoryOperatorScopeUseCaseWeights().entrySet()) {
        chainedTransformation.declareManagedMemoryUseCaseAtOperatorScope(
                useCase.getKey(),
                useCase.getValue()
                        + chainedTransformation
                                .getManagedMemoryOperatorScopeUseCaseWeights()
                                .getOrDefault(useCase.getKey(), 0));
    }
    chainedTransformation.setBufferTimeout(
            Math.min(upTransform.getBufferTimeout(), downTransform.getBufferTimeout()));
    if (upTransform.getMaxParallelism() > 0) {
        chainedTransformation.setMaxParallelism(upTransform.getMaxParallelism());
    }
    chainedTransformation.setChainingStrategy(getOperatorFactory(upTransform).getChainingStrategy());
    chainedTransformation.setCoLocationGroupKey(upTransform.getCoLocationGroupKey());
    chainedTransformation.setResources(
            upTransform.getMinResources().merge(downTransform.getMinResources()),
            upTransform.getPreferredResources().merge(downTransform.getPreferredResources()));
    if (upTransform.getSlotSharingGroup().isPresent()) {
        chainedTransformation.setSlotSharingGroup(upTransform.getSlotSharingGroup().get());
    }
    if (upTransform.getDescription() != null && downTransform.getDescription() != null) {
        chainedTransformation.setDescription(
                upTransform.getDescription() + ", " + downTransform.getDescription());
    } else if (upTransform.getDescription() != null) {
        chainedTransformation.setDescription(upTransform.getDescription());
    } else if (downTransform.getDescription() != null) {
        chainedTransformation.setDescription(downTransform.getDescription());
    }
    return chainedTransformation;
}
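For context, a minimal driver sketch showing where this method comes into play: the optimizer receives the environment's transformation list and replaces each chainable pair of Python operators with the single transformation built by createChainedTransformation. The getTransformations() accessor is an internal API and its availability is an assumption here.

StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// ... build a pipeline whose operators are Python function operators ...
List<Transformation<?>> transformations = env.getTransformations(); // internal accessor, assumed available
List<Transformation<?>> optimized = PythonOperatorChainingOptimizer.optimize(transformations);
// Each chainable up/down pair in the list is now a single chained transformation.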
Use of org.apache.flink.streaming.api.transformations.TwoInputTransformation in project flink by apache.
The class DataStreamBatchExecutionITCase, method batchMixedKeyedAndNonKeyedTwoInputOperator.
@Test
public void batchMixedKeyedAndNonKeyedTwoInputOperator() throws Exception {
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(1);
    env.setRuntimeMode(RuntimeExecutionMode.BATCH);

    // Non-keyed (broadcast) input.
    DataStream<Tuple2<String, Integer>> bcInput =
            env.fromElements(Tuple2.of("bc3", 3), Tuple2.of("bc2", 2), Tuple2.of("bc1", 1))
                    .assignTimestampsAndWatermarks(WatermarkStrategy.<Tuple2<String, Integer>>forMonotonousTimestamps().withTimestampAssigner((in, ts) -> in.f1))
                    .broadcast();

    // Keyed input.
    DataStream<Tuple2<String, Integer>> regularInput =
            env.fromElements(Tuple2.of("regular1", 1), Tuple2.of("regular1", 2), Tuple2.of("regular1", 3), Tuple2.of("regular1", 4), Tuple2.of("regular2", 3), Tuple2.of("regular2", 5), Tuple2.of("regular1", 3))
                    .assignTimestampsAndWatermarks(WatermarkStrategy.<Tuple2<String, Integer>>forMonotonousTimestamps().withTimestampAssigner((in, ts) -> in.f1))
                    .keyBy(input -> input.f0);

    // Construct the two-input transformation manually: only the first input carries a key selector.
    TwoInputTransformation<Tuple2<String, Integer>, Tuple2<String, Integer>, String> twoInputTransformation =
            new TwoInputTransformation<>(regularInput.getTransformation(), bcInput.getTransformation(), "operator", new TestMixedTwoInputOperator(), BasicTypeInfo.STRING_TYPE_INFO, 1);
    twoInputTransformation.setStateKeyType(BasicTypeInfo.STRING_TYPE_INFO);
    twoInputTransformation.setStateKeySelectors(input -> input.f0, null);
    DataStream<String> result = new DataStream<>(env, twoInputTransformation);

    try (CloseableIterator<String> resultIterator = result.executeAndCollect()) {
        List<String> results = CollectionUtil.iteratorToList(resultIterator);
        assertThat(results, equalTo(Arrays.asList("(regular1,1): [bc3, bc2, bc1]", "(regular1,2): [bc3, bc2, bc1]", "(regular1,3): [bc3, bc2, bc1]", "(regular1,3): [bc3, bc2, bc1]", "(regular1,4): [bc3, bc2, bc1]", "(regular2,3): [bc3, bc2, bc1]", "(regular2,5): [bc3, bc2, bc1]")));
    }
}
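The TestMixedTwoInputOperator used above is not shown on this page. A plausible sketch, assuming the common Flink test pattern of extending AbstractStreamOperator and implementing TwoInputStreamOperator (class and field names are illustrative): in BATCH mode the broadcast input is consumed in full before the keyed input, which is why every emitted record sees the complete broadcast list [bc3, bc2, bc1].

private static class TestMixedTwoInputOperator extends AbstractStreamOperator<String>
        implements TwoInputStreamOperator<Tuple2<String, Integer>, Tuple2<String, Integer>, String> {

    // Collected from the broadcast side; complete before processElement1 fires in BATCH mode.
    private final List<String> broadcastElements = new ArrayList<>();

    @Override
    public void processElement1(StreamRecord<Tuple2<String, Integer>> element) {
        // Keyed input: emit the element together with all broadcast elements seen so far.
        output.collect(new StreamRecord<>(element.getValue() + ": " + broadcastElements));
    }

    @Override
    public void processElement2(StreamRecord<Tuple2<String, Integer>> element) {
        // Broadcast input: remember the element name.
        broadcastElements.add(element.getValue().f0);
    }
}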
Use of org.apache.flink.streaming.api.transformations.TwoInputTransformation in project beam by apache.
The class FlinkStreamingPortablePipelineTranslator, method translateExecutableStage.
private <InputT, OutputT> void translateExecutableStage(
        String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) {
    // TODO: Fail on splittable DoFns.
    // TODO: Special-case single outputs to avoid multiplexing PCollections.
    RunnerApi.Components components = pipeline.getComponents();
    RunnerApi.PTransform transform = components.getTransformsOrThrow(id);
    Map<String, String> outputs = transform.getOutputsMap();

    final RunnerApi.ExecutableStagePayload stagePayload;
    try {
        stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transform.getSpec().getPayload());
    } catch (IOException e) {
        throw new RuntimeException(e);
    }

    String inputPCollectionId = stagePayload.getInput();
    final TransformedSideInputs transformedSideInputs;
    if (stagePayload.getSideInputsCount() > 0) {
        transformedSideInputs = transformSideInputs(stagePayload, components, context);
    } else {
        transformedSideInputs = new TransformedSideInputs(Collections.emptyMap(), null);
    }

    Map<TupleTag<?>, OutputTag<WindowedValue<?>>> tagsToOutputTags = Maps.newLinkedHashMap();
    Map<TupleTag<?>, Coder<WindowedValue<?>>> tagsToCoders = Maps.newLinkedHashMap();
    // TODO: does it matter which output we designate as "main"?
    final TupleTag<OutputT> mainOutputTag =
            outputs.isEmpty() ? null : new TupleTag(outputs.keySet().iterator().next());

    // Associate output tags with ids; the output manager uses these Integer ids to serialize state.
    BiMap<String, Integer> outputIndexMap = createOutputMap(outputs.keySet());
    Map<String, Coder<WindowedValue<?>>> outputCoders = Maps.newHashMap();
    Map<TupleTag<?>, Integer> tagsToIds = Maps.newHashMap();
    Map<String, TupleTag<?>> collectionIdToTupleTag = Maps.newHashMap();
    // Order output names for a deterministic mapping.
    for (String localOutputName : new TreeMap<>(outputIndexMap).keySet()) {
        String collectionId = outputs.get(localOutputName);
        Coder<WindowedValue<?>> windowCoder = (Coder) instantiateCoder(collectionId, components);
        outputCoders.put(localOutputName, windowCoder);
        TupleTag<?> tupleTag = new TupleTag<>(localOutputName);
        CoderTypeInformation<WindowedValue<?>> typeInformation =
                new CoderTypeInformation(windowCoder, context.getPipelineOptions());
        tagsToOutputTags.put(tupleTag, new OutputTag<>(localOutputName, typeInformation));
        tagsToCoders.put(tupleTag, windowCoder);
        tagsToIds.put(tupleTag, outputIndexMap.get(localOutputName));
        collectionIdToTupleTag.put(collectionId, tupleTag);
    }

    final SingleOutputStreamOperator<WindowedValue<OutputT>> outputStream;
    DataStream<WindowedValue<InputT>> inputDataStream = context.getDataStreamOrThrow(inputPCollectionId);
    CoderTypeInformation<WindowedValue<OutputT>> outputTypeInformation =
            !outputs.isEmpty()
                    ? new CoderTypeInformation(outputCoders.get(mainOutputTag.getId()), context.getPipelineOptions())
                    : null;
    ArrayList<TupleTag<?>> additionalOutputTags = Lists.newArrayList();
    for (TupleTag<?> tupleTag : tagsToCoders.keySet()) {
        if (!mainOutputTag.getId().equals(tupleTag.getId())) {
            additionalOutputTags.add(tupleTag);
        }
    }

    final Coder<WindowedValue<InputT>> windowedInputCoder = instantiateCoder(inputPCollectionId, components);
    final boolean stateful = stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0;
    final boolean hasSdfProcessFn =
            stagePayload.getComponents().getTransformsMap().values().stream()
                    .anyMatch(pTransform -> pTransform.getSpec().getUrn()
                            .equals(PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN));
    Coder keyCoder = null;
    KeySelector<WindowedValue<InputT>, ?> keySelector = null;
    if (stateful || hasSdfProcessFn) {
        // Stateful and SDF stages are only allowed for KV inputs.
        Coder valueCoder = ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
        if (!(valueCoder instanceof KvCoder)) {
            throw new IllegalStateException(String.format(Locale.ENGLISH, "The element coder for stateful DoFn '%s' must be KvCoder but is: %s", inputPCollectionId, valueCoder.getClass().getSimpleName()));
        }
        if (stateful) {
            keyCoder = ((KvCoder) valueCoder).getKeyCoder();
            keySelector = new KvToByteBufferKeySelector(keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
        } else {
            // For an SDF, the input element is expected to be KV(KV(element, restriction), size);
            // the element is used as the key.
            if (!(((KvCoder) valueCoder).getKeyCoder() instanceof KvCoder)) {
                throw new IllegalStateException(String.format(Locale.ENGLISH, "The element coder for splittable DoFn '%s' must be KVCoder(KvCoder, DoubleCoder) but is: %s", inputPCollectionId, valueCoder.getClass().getSimpleName()));
            }
            keyCoder = ((KvCoder) ((KvCoder) valueCoder).getKeyCoder()).getKeyCoder();
            keySelector = new SdfByteBufferKeySelector(keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
        }
        inputDataStream = inputDataStream.keyBy(keySelector);
    }

    DoFnOperator.MultiOutputOutputManagerFactory<OutputT> outputManagerFactory =
            new DoFnOperator.MultiOutputOutputManagerFactory<>(mainOutputTag, tagsToOutputTags, tagsToCoders, tagsToIds, new SerializablePipelineOptions(context.getPipelineOptions()));
    DoFnOperator<InputT, OutputT> doFnOperator =
            new ExecutableStageDoFnOperator<>(
                    transform.getUniqueName(),
                    windowedInputCoder,
                    Collections.emptyMap(),
                    mainOutputTag,
                    additionalOutputTags,
                    outputManagerFactory,
                    transformedSideInputs.unionTagToView,
                    new ArrayList<>(transformedSideInputs.unionTagToView.values()),
                    getSideInputIdToPCollectionViewMap(stagePayload, components),
                    context.getPipelineOptions(),
                    stagePayload,
                    context.getJobInfo(),
                    FlinkExecutableStageContextFactory.getInstance(),
                    collectionIdToTupleTag,
                    getWindowingStrategy(inputPCollectionId, components),
                    keyCoder,
                    keySelector);

    final String operatorName = generateNameFromStagePayload(stagePayload);
    if (transformedSideInputs.unionTagToView.isEmpty()) {
        outputStream = inputDataStream.transform(operatorName, outputTypeInformation, doFnOperator);
    } else {
        DataStream<RawUnionValue> sideInputStream = transformedSideInputs.unionedSideInputs.broadcast();
        if (stateful || hasSdfProcessFn) {
            // We have to manually construct the two-input transform because we're not
            // allowed to have only one input keyed, normally. Since Flink 1.5.0 it's
            // possible to use the Broadcast State Pattern which provides a more elegant
            // way to process keyed main input with broadcast state, but it's not feasible
            // here because it breaks the DoFnOperator abstraction.
            TwoInputTransformation<WindowedValue<KV<?, InputT>>, RawUnionValue, WindowedValue<OutputT>> rawFlinkTransform =
                    new TwoInputTransformation(
                            inputDataStream.getTransformation(),
                            sideInputStream.getTransformation(),
                            transform.getUniqueName(),
                            doFnOperator,
                            outputTypeInformation,
                            inputDataStream.getParallelism());
            rawFlinkTransform.setStateKeyType(((KeyedStream) inputDataStream).getKeyType());
            rawFlinkTransform.setStateKeySelectors(((KeyedStream) inputDataStream).getKeySelector(), null);
            // We have to cheat around the ctor being protected.
            outputStream = new SingleOutputStreamOperator(inputDataStream.getExecutionEnvironment(), rawFlinkTransform) {};
        } else {
            outputStream = inputDataStream.connect(sideInputStream).transform(operatorName, outputTypeInformation, doFnOperator);
        }
    }

    // Assign a unique but consistent id to re-map operator state.
    outputStream.uid(transform.getUniqueName());
    if (mainOutputTag != null) {
        context.addDataStream(outputs.get(mainOutputTag.getId()), outputStream);
    }
    for (TupleTag<?> tupleTag : additionalOutputTags) {
        context.addDataStream(outputs.get(tupleTag.getId()), outputStream.getSideOutput(tagsToOutputTags.get(tupleTag)));
    }
}
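For contrast, the Broadcast State Pattern mentioned in the comment above would look roughly like this sketch (the descriptor, the types, and variable names such as keyedMainInput and sideInputStream are illustrative, with String elements for brevity); it is avoided in the translator because it would route elements through a KeyedBroadcastProcessFunction instead of the DoFnOperator.

MapStateDescriptor<String, String> descriptor =
        new MapStateDescriptor<>("side-input", BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
BroadcastStream<String> broadcast = sideInputStream.broadcast(descriptor);
keyedMainInput
        .connect(broadcast)
        .process(
                new KeyedBroadcastProcessFunction<String, String, String, String>() {
                    @Override
                    public void processElement(String value, ReadOnlyContext ctx, Collector<String> out)
                            throws Exception {
                        // Keyed elements get read-only access to the broadcast state.
                        out.collect(value + ": " + ctx.getBroadcastState(descriptor).get("latest"));
                    }

                    @Override
                    public void processBroadcastElement(String value, Context ctx, Collector<String> out)
                            throws Exception {
                        // Broadcast elements update the shared state.
                        ctx.getBroadcastState(descriptor).put("latest", value);
                    }
                });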
Use of org.apache.flink.streaming.api.transformations.TwoInputTransformation in project flink by apache.
The class PythonOperatorChainingOptimizerTest, method testChainingTwoInputOperators.
@Test
public void testChainingTwoInputOperators() {
    PythonKeyedCoProcessOperator<?> keyedCoProcessOperator1 =
            createCoKeyedProcessOperator("f1", new RowTypeInfo(Types.INT(), Types.STRING()), new RowTypeInfo(Types.INT(), Types.INT()), Types.STRING());
    PythonProcessOperator<?, ?> processOperator1 =
            createProcessOperator("f2", new RowTypeInfo(Types.INT(), Types.INT()), Types.STRING());
    PythonProcessOperator<?, ?> processOperator2 =
            createProcessOperator("f3", new RowTypeInfo(Types.INT(), Types.INT()), Types.LONG());
    PythonKeyedProcessOperator<?> keyedProcessOperator2 =
            createKeyedProcessOperator("f4", new RowTypeInfo(Types.INT(), Types.INT()), Types.STRING());
    PythonProcessOperator<?, ?> processOperator3 =
            createProcessOperator("f5", new RowTypeInfo(Types.INT(), Types.INT()), Types.STRING());

    Transformation<?> sourceTransformation1 = mock(SourceTransformation.class);
    Transformation<?> sourceTransformation2 = mock(SourceTransformation.class);
    TwoInputTransformation<?, ?, ?> keyedCoProcessTransformation =
            new TwoInputTransformation(sourceTransformation1, sourceTransformation2, "keyedCoProcess", keyedCoProcessOperator1, keyedCoProcessOperator1.getProducedType(), 2);
    Transformation<?> processTransformation1 =
            new OneInputTransformation(keyedCoProcessTransformation, "process", processOperator1, processOperator1.getProducedType(), 2);
    Transformation<?> processTransformation2 =
            new OneInputTransformation(processTransformation1, "process", processOperator2, processOperator2.getProducedType(), 2);
    OneInputTransformation<?, ?> keyedProcessTransformation =
            new OneInputTransformation(processTransformation2, "keyedProcess", keyedProcessOperator2, keyedProcessOperator2.getProducedType(), 2);
    Transformation<?> processTransformation3 =
            new OneInputTransformation(keyedProcessTransformation, "process", processOperator3, processOperator3.getProducedType(), 2);

    List<Transformation<?>> transformations = new ArrayList<>();
    transformations.add(sourceTransformation1);
    transformations.add(sourceTransformation2);
    transformations.add(keyedCoProcessTransformation);
    transformations.add(processTransformation1);
    transformations.add(processTransformation2);
    transformations.add(keyedProcessTransformation);
    transformations.add(processTransformation3);

    // f1 + f2 + f3 chain into one two-input transformation and f4 + f5 into one one-input
    // transformation; together with the two sources, four transformations remain.
    List<Transformation<?>> optimized = PythonOperatorChainingOptimizer.optimize(transformations);
    assertEquals(4, optimized.size());

    TwoInputTransformation<?, ?, ?> chainedTransformation1 = (TwoInputTransformation<?, ?, ?>) optimized.get(2);
    assertEquals(sourceTransformation1.getOutputType(), chainedTransformation1.getInputType1());
    assertEquals(sourceTransformation2.getOutputType(), chainedTransformation1.getInputType2());
    assertEquals(processOperator2.getProducedType(), chainedTransformation1.getOutputType());

    OneInputTransformation<?, ?> chainedTransformation2 = (OneInputTransformation<?, ?>) optimized.get(3);
    assertEquals(processOperator2.getProducedType(), chainedTransformation2.getInputType());
    assertEquals(processOperator3.getProducedType(), chainedTransformation2.getOutputType());

    TwoInputStreamOperator<?, ?, ?> chainedOperator1 = chainedTransformation1.getOperator();
    assertTrue(chainedOperator1 instanceof PythonKeyedCoProcessOperator);
    validateChainedPythonFunctions(((PythonKeyedCoProcessOperator<?>) chainedOperator1).getPythonFunctionInfo(), "f3", "f2", "f1");

    OneInputStreamOperator<?, ?> chainedOperator2 = chainedTransformation2.getOperator();
    assertTrue(chainedOperator2 instanceof PythonKeyedProcessOperator);
    validateChainedPythonFunctions(((PythonKeyedProcessOperator<?>) chainedOperator2).getPythonFunctionInfo(), "f5", "f4");
}
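The validateChainedPythonFunctions helper is not shown on this page. A plausible sketch of what it verifies, assuming the test's convention that the function name doubles as the serialized Python function payload (the real implementation may differ):

private static void validateChainedPythonFunctions(
        DataStreamPythonFunctionInfo pythonFunctionInfo, String... expectedFunctionNames) {
    // Walk the chained function infos from tail to head, checking the payloads in order.
    for (String expected : expectedFunctionNames) {
        assertArrayEquals(
                expected.getBytes(),
                pythonFunctionInfo.getPythonFunction().getSerializedPythonFunction());
        Object[] inputs = pythonFunctionInfo.getInputs();
        pythonFunctionInfo = inputs.length > 0 ? (DataStreamPythonFunctionInfo) inputs[0] : null;
    }
}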
Use of org.apache.flink.streaming.api.transformations.TwoInputTransformation in project flink by apache.
The class PythonOperatorChainingOptimizer, method replaceInput.
private static void replaceInput(
        Transformation<?> transformation, Transformation<?> oldInput, Transformation<?> newInput) {
    try {
        if (transformation instanceof OneInputTransformation
                || transformation instanceof FeedbackTransformation
                || transformation instanceof SideOutputTransformation
                || transformation instanceof ReduceTransformation
                || transformation instanceof SinkTransformation
                || transformation instanceof LegacySinkTransformation
                || transformation instanceof TimestampsAndWatermarksTransformation
                || transformation instanceof PartitionTransformation) {
            final Field inputField = transformation.getClass().getDeclaredField("input");
            inputField.setAccessible(true);
            inputField.set(transformation, newInput);
        } else if (transformation instanceof TwoInputTransformation) {
            final Field inputField;
            if (((TwoInputTransformation<?, ?, ?>) transformation).getInput1() == oldInput) {
                inputField = transformation.getClass().getDeclaredField("input1");
            } else {
                inputField = transformation.getClass().getDeclaredField("input2");
            }
            inputField.setAccessible(true);
            inputField.set(transformation, newInput);
        } else if (transformation instanceof UnionTransformation
                || transformation instanceof AbstractMultipleInputTransformation) {
            final Field inputsField = transformation.getClass().getDeclaredField("inputs");
            inputsField.setAccessible(true);
            List<Transformation<?>> newInputs = Lists.newArrayList();
            newInputs.addAll(transformation.getInputs());
            newInputs.remove(oldInput);
            newInputs.add(newInput);
            inputsField.set(transformation, newInputs);
        } else if (transformation instanceof AbstractBroadcastStateTransformation) {
            final Field inputField;
            if (((AbstractBroadcastStateTransformation<?, ?, ?>) transformation).getRegularInput() == oldInput) {
                inputField = transformation.getClass().getDeclaredField("regularInput");
            } else {
                inputField = transformation.getClass().getDeclaredField("broadcastInput");
            }
            inputField.setAccessible(true);
            inputField.set(transformation, newInput);
        } else {
            throw new RuntimeException("Unsupported transformation: " + transformation);
        }
    } catch (NoSuchFieldException | IllegalAccessException e) {
        // This should never happen.
        throw new RuntimeException(e);
    }
}
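Reflection is needed here because Transformation subclasses keep their inputs in private final fields with no setters. The same trick in isolation, as a sketch for the one-input case (swapInput is an illustrative helper, not part of Flink):

static void swapInput(OneInputTransformation<?, ?> transformation, Transformation<?> newInput)
        throws ReflectiveOperationException {
    // setAccessible(true) bypasses the private modifier; the reflective write to the
    // final instance field works because it is not a compile-time constant.
    Field inputField = OneInputTransformation.class.getDeclaredField("input");
    inputField.setAccessible(true);
    inputField.set(transformation, newInput);
}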