Usage of org.apache.flink.streaming.api.operators.python.AbstractDataStreamPythonFunctionOperator in the Apache Flink project.
Class PythonOperatorChainingOptimizer, method createChainedTransformation:
/**
 * Fuses two adjacent Python transformations into one physical transformation.
 *
 * <p>The downstream Python function chain is walked along its first input down to its head, and
 * that head is re-wired to take the upstream Python function as its single input. The upstream
 * operator is then copied with the combined function info and the downstream produced type.
 *
 * <p>The returned transformation:
 *
 * <ul>
 *   <li>reuses the upstream inputs, state key selectors/key type, uid, user hash, parallelism,
 *       max parallelism (when positive), chaining strategy, co-location key and slot sharing
 *       group;
 *   <li>merges managed-memory declarations from both sides, summing operator-scope weights;
 *   <li>combines buffer timeouts, resources and descriptions from both sides.
 * </ul>
 *
 * @param upTransform the upstream transformation; its factory is cast to {@link
 *     SimpleOperatorFactory} wrapping an {@link AbstractDataStreamPythonFunctionOperator}
 * @param downTransform the downstream transformation; its factory is cast to {@link
 *     SimpleOperatorFactory} wrapping a {@link PythonProcessOperator}
 * @return a {@code OneInputTransformation} when the upstream operator is an {@code
 *     AbstractOneInputPythonFunctionOperator}, otherwise a {@code TwoInputTransformation}
 */
private static Transformation<?> createChainedTransformation(
        Transformation<?> upTransform, Transformation<?> downTransform) {
    final AbstractDataStreamPythonFunctionOperator<?> upstreamOp =
            (AbstractDataStreamPythonFunctionOperator<?>)
                    ((SimpleOperatorFactory<?>) getOperatorFactory(upTransform)).getOperator();
    final PythonProcessOperator<?, ?> downstreamOp =
            (PythonProcessOperator<?, ?>)
                    ((SimpleOperatorFactory<?>) getOperatorFactory(downTransform)).getOperator();

    // Copies are re-wired below, presumably so the original operators' function infos are left
    // alone. NOTE(review): whether copy() also detaches nested inputs is not visible here.
    final DataStreamPythonFunctionInfo upstreamFn = upstreamOp.getPythonFunctionInfo().copy();
    final DataStreamPythonFunctionInfo downstreamFn = downstreamOp.getPythonFunctionInfo().copy();

    // Descend along the first input until reaching the function that has no inputs, then feed
    // it from the upstream function.
    DataStreamPythonFunctionInfo head = downstreamFn;
    while (head.getInputs().length > 0) {
        head = (DataStreamPythonFunctionInfo) head.getInputs()[0];
    }
    head.setInputs(new DataStreamPythonFunctionInfo[] {upstreamFn});

    final AbstractDataStreamPythonFunctionOperator<?> fusedOp =
            upstreamOp.copy(downstreamFn, downstreamOp.getProducedType());
    // The fused operator carries a custom partitioning if either side did.
    final boolean anyPartitionCustom =
            downstreamOp.containsPartitionCustom() || upstreamOp.containsPartitionCustom();
    fusedOp.setContainsPartitionCustom(anyPartitionCustom);

    final String fusedName = upTransform.getName() + ", " + downTransform.getName();
    final PhysicalTransformation<?> fused;
    if (!(upstreamOp instanceof AbstractOneInputPythonFunctionOperator)) {
        // Raw types are needed: the wildcard captures of inputs, operator and selectors differ.
        final TwoInputTransformation<?, ?, ?> twoInput =
                new TwoInputTransformation(
                        upTransform.getInputs().get(0),
                        upTransform.getInputs().get(1),
                        fusedName,
                        (TwoInputStreamOperator<?, ?, ?>) fusedOp,
                        downTransform.getOutputType(),
                        upTransform.getParallelism());
        twoInput.setStateKeySelectors(
                ((TwoInputTransformation) upTransform).getStateKeySelector1(),
                ((TwoInputTransformation) upTransform).getStateKeySelector2());
        twoInput.setStateKeyType(
                ((TwoInputTransformation<?, ?, ?>) upTransform).getStateKeyType());
        fused = twoInput;
    } else {
        final OneInputTransformation<?, ?> oneInput =
                new OneInputTransformation(
                        upTransform.getInputs().get(0),
                        fusedName,
                        (OneInputStreamOperator<?, ?>) fusedOp,
                        downTransform.getOutputType(),
                        upTransform.getParallelism());
        oneInput.setStateKeySelector(((OneInputTransformation) upTransform).getStateKeySelector());
        oneInput.setStateKeyType(((OneInputTransformation<?, ?>) upTransform).getStateKeyType());
        fused = oneInput;
    }

    // Identity is inherited from the upstream transformation.
    fused.setUid(upTransform.getUid());
    final String userProvidedHash = upTransform.getUserProvidedNodeHash();
    if (userProvidedHash != null) {
        fused.setUidHash(userProvidedHash);
    }

    // Managed memory: union of slot-scope use cases; operator-scope weights of the downstream
    // side are added on top of whatever the upstream side already declared.
    upTransform
            .getManagedMemorySlotScopeUseCases()
            .forEach(fused::declareManagedMemoryUseCaseAtSlotScope);
    downTransform
            .getManagedMemorySlotScopeUseCases()
            .forEach(fused::declareManagedMemoryUseCaseAtSlotScope);
    upTransform
            .getManagedMemoryOperatorScopeUseCaseWeights()
            .forEach(
                    (useCase, weight) ->
                            fused.declareManagedMemoryUseCaseAtOperatorScope(useCase, weight));
    downTransform
            .getManagedMemoryOperatorScopeUseCaseWeights()
            .forEach(
                    (useCase, weight) ->
                            fused.declareManagedMemoryUseCaseAtOperatorScope(
                                    useCase,
                                    weight
                                            + fused.getManagedMemoryOperatorScopeUseCaseWeights()
                                                    .getOrDefault(useCase, 0)));

    // NOTE(review): if a negative buffer timeout means "not set" on Transformation, min() lets an
    // unset side override an explicit timeout on the other side — confirm this is intended.
    fused.setBufferTimeout(
            Math.min(upTransform.getBufferTimeout(), downTransform.getBufferTimeout()));

    final int upMaxParallelism = upTransform.getMaxParallelism();
    if (upMaxParallelism > 0) {
        fused.setMaxParallelism(upMaxParallelism);
    }

    fused.setChainingStrategy(getOperatorFactory(upTransform).getChainingStrategy());
    fused.setCoLocationGroupKey(upTransform.getCoLocationGroupKey());
    fused.setResources(
            upTransform.getMinResources().merge(downTransform.getMinResources()),
            upTransform.getPreferredResources().merge(downTransform.getPreferredResources()));
    upTransform.getSlotSharingGroup().ifPresent(fused::setSlotSharingGroup);

    // Join whichever descriptions exist; leave the description untouched if neither side has one.
    final String upDescription = upTransform.getDescription();
    final String downDescription = downTransform.getDescription();
    if (upDescription == null) {
        if (downDescription != null) {
            fused.setDescription(downDescription);
        }
    } else if (downDescription == null) {
        fused.setDescription(upDescription);
    } else {
        fused.setDescription(upDescription + ", " + downDescription);
    }

    return fused;
}
Aggregations