Usage of `org.apache.flink.streaming.api.transformations.AbstractMultipleInputTransformation` in the Apache Flink project.
Class `PythonOperatorChainingOptimizer`, method `replaceInput`:
/**
 * Rewires {@code transformation} so that the input slot currently referencing {@code oldInput}
 * references {@code newInput} instead.
 *
 * <p>The input fields of the supported transformation types are overwritten via reflection.
 * Fields are searched along the whole class hierarchy, not only on the concrete runtime class:
 * for the abstract-typed branches ({@link AbstractMultipleInputTransformation},
 * {@link AbstractBroadcastStateTransformation}) the runtime class is always a subclass, the field
 * may be declared in the abstract base, and {@link Class#getDeclaredField(String)} does not look
 * at superclasses.
 *
 * <p>For union and multiple-input transformations the replacement happens in place, so the
 * position of every input is preserved. For multiple-input transformations, that position is
 * what the translated stream graph uses to number the input edges.
 *
 * @param transformation the downstream transformation whose input should be replaced
 * @param oldInput the input currently referenced by {@code transformation}
 * @param newInput the input that should take the place of {@code oldInput}
 * @throws IllegalArgumentException if {@code transformation} is a union/multiple-input
 *     transformation and {@code oldInput} is not among its inputs
 * @throws RuntimeException if the transformation type is unsupported or reflective access fails
 */
private static void replaceInput(
        Transformation<?> transformation,
        Transformation<?> oldInput,
        Transformation<?> newInput) {
    try {
        if (transformation instanceof OneInputTransformation
                || transformation instanceof FeedbackTransformation
                || transformation instanceof SideOutputTransformation
                || transformation instanceof ReduceTransformation
                || transformation instanceof SinkTransformation
                || transformation instanceof LegacySinkTransformation
                || transformation instanceof TimestampsAndWatermarksTransformation
                || transformation instanceof PartitionTransformation) {
            setInputField(transformation, "input", newInput);
        } else if (transformation instanceof TwoInputTransformation) {
            // Identity check decides which side of the two-input operator is being rewired;
            // anything that is not input1 is assumed to be input2.
            final String fieldName =
                    ((TwoInputTransformation<?, ?, ?>) transformation).getInput1() == oldInput
                            ? "input1"
                            : "input2";
            setInputField(transformation, fieldName, newInput);
        } else if (transformation instanceof UnionTransformation
                || transformation instanceof AbstractMultipleInputTransformation) {
            final List<Transformation<?>> newInputs = Lists.newArrayList();
            newInputs.addAll(transformation.getInputs());
            // Replace in place (same equals-based match as List#remove) instead of
            // remove + append, which would move the new input to the last position.
            final int index = newInputs.indexOf(oldInput);
            if (index < 0) {
                throw new IllegalArgumentException(
                        "Transformation " + oldInput + " is not an input of " + transformation);
            }
            newInputs.set(index, newInput);
            setInputField(transformation, "inputs", newInputs);
        } else if (transformation instanceof AbstractBroadcastStateTransformation) {
            // Same identity-based side selection as for two-input transformations.
            final String fieldName =
                    ((AbstractBroadcastStateTransformation<?, ?, ?>) transformation)
                                            .getRegularInput()
                                    == oldInput
                            ? "regularInput"
                            : "broadcastInput";
            setInputField(transformation, fieldName, newInput);
        } else {
            throw new RuntimeException("Unsupported transformation: " + transformation);
        }
    } catch (NoSuchFieldException | IllegalAccessException e) {
        // Only reachable if a transformation class no longer has the expected field.
        throw new RuntimeException(e);
    }
}

/** Reflectively assigns {@code value} to the field {@code fieldName} of {@code target}. */
private static void setInputField(Object target, String fieldName, Object value)
        throws NoSuchFieldException, IllegalAccessException {
    final Field field = findFieldInHierarchy(target.getClass(), fieldName);
    field.setAccessible(true);
    field.set(target, value);
}

/**
 * Returns the field named {@code fieldName} declared by {@code clazz} or, failing that, by the
 * nearest superclass that declares it.
 */
private static Field findFieldInHierarchy(Class<?> clazz, String fieldName)
        throws NoSuchFieldException {
    for (Class<?> current = clazz; current != null; current = current.getSuperclass()) {
        try {
            return current.getDeclaredField(fieldName);
        } catch (NoSuchFieldException ignored) {
            // Not declared on this class; continue with its superclass.
        }
    }
    throw new NoSuchFieldException(
            "Field '" + fieldName + "' not found in " + clazz.getName() + " or its superclasses");
}
Usage of `org.apache.flink.streaming.api.transformations.AbstractMultipleInputTransformation` in the Apache Flink project.
Class `MultiInputTransformationTranslator`, method `translateInternal`:
/**
 * Adds a multiple-input operator node for {@code transformation} to the stream graph and
 * connects every upstream node to it.
 *
 * <p>The node uses the transformation's parallelism. If that is
 * {@link ExecutionConfig#PARALLELISM_DEFAULT}, the execution config's parallelism is used
 * instead. For a {@link KeyedMultipleInputTransformation}, the state key selectors and a key
 * serializer are registered as well. The input at zero-based position {@code k} in
 * {@code getInputs()} is wired with edge number {@code k + 1}.
 *
 * @param transformation the multiple-input transformation to translate; must not be null and
 *     must have at least one input
 * @param context the translation context; must not be null
 * @return a singleton collection containing the id of the created node
 */
private Collection<Integer> translateInternal(
        final AbstractMultipleInputTransformation<OUT> transformation, final Context context) {
    checkNotNull(transformation);
    checkNotNull(context);

    final List<Transformation<?>> inputs = transformation.getInputs();
    checkArgument(
            !inputs.isEmpty(),
            "Empty inputs for MultipleInputTransformation. Did you forget to add inputs?");
    MultipleInputSelectionHandler.checkSupportedInputCount(inputs.size());

    final StreamGraph graph = context.getStreamGraph();
    final String sharingGroup = context.getSlotSharingGroup();
    final int nodeId = transformation.getId();
    final ExecutionConfig config = graph.getExecutionConfig();

    graph.addMultipleInputOperator(
            nodeId,
            sharingGroup,
            transformation.getCoLocationGroupKey(),
            transformation.getOperatorFactory(),
            transformation.getInputTypes(),
            transformation.getOutputType(),
            transformation.getName());

    // An unset (default) parallelism on the transformation defers to the execution config.
    int parallelism = transformation.getParallelism();
    if (parallelism == ExecutionConfig.PARALLELISM_DEFAULT) {
        parallelism = config.getParallelism();
    }
    graph.setParallelism(nodeId, parallelism);
    graph.setMaxParallelism(nodeId, transformation.getMaxParallelism());

    if (transformation instanceof KeyedMultipleInputTransformation) {
        final KeyedMultipleInputTransformation<OUT> keyed =
                (KeyedMultipleInputTransformation<OUT>) transformation;
        final TypeSerializer<?> stateKeySerializer =
                keyed.getStateKeyType().createSerializer(config);
        graph.setMultipleInputStateKey(nodeId, keyed.getStateKeySelectors(), stateKeySerializer);
    }

    // Edge numbers are 1-based and follow the order of getInputs().
    int edgeNumber = 1;
    for (final Transformation<?> input : inputs) {
        for (final Integer upstreamId : context.getStreamNodeIds(input)) {
            graph.addEdge(upstreamId, nodeId, edgeNumber);
        }
        edgeNumber++;
    }

    return Collections.singleton(nodeId);
}
Aggregations