Search in sources :

Example 1 with AbstractDataStreamPythonFunctionOperator

Use of org.apache.flink.streaming.api.operators.python.AbstractDataStreamPythonFunctionOperator in the Apache Flink project.

From the class PythonOperatorChainingOptimizer, the method createChainedTransformation:

/**
 * Builds a single transformation that stands in for {@code upTransform} followed by
 * {@code downTransform}.
 *
 * <p>The Python function info of the downstream operator is copied, and the copied upstream
 * function info is attached as the input of the head of that chain (the first entry that has no
 * inputs). The upstream operator is then copied with the combined function info and the
 * downstream operator's produced type.
 *
 * <p>The resulting transformation takes its inputs, parallelism, state key selector(s), state key
 * type, uid, user-provided node hash, max parallelism, chaining strategy, co-location group key
 * and slot sharing group from {@code upTransform}. Name and description are the two originals
 * joined with {@code ", "}; managed-memory declarations and resources of both transformations are
 * combined, and the smaller of the two buffer timeouts is used.
 *
 * @param upTransform the upstream transformation; its operator is cast to
 *     {@link AbstractDataStreamPythonFunctionOperator}
 * @param downTransform the downstream transformation; its operator is cast to
 *     {@link PythonProcessOperator}
 * @return the newly created chained transformation
 */
private static Transformation<?> createChainedTransformation(Transformation<?> upTransform, Transformation<?> downTransform) {
    final SimpleOperatorFactory<?> upstreamFactory = (SimpleOperatorFactory<?>) getOperatorFactory(upTransform);
    final AbstractDataStreamPythonFunctionOperator<?> upstreamOp =
            (AbstractDataStreamPythonFunctionOperator<?>) upstreamFactory.getOperator();
    final SimpleOperatorFactory<?> downstreamFactory = (SimpleOperatorFactory<?>) getOperatorFactory(downTransform);
    final PythonProcessOperator<?, ?> downstreamOp =
            (PythonProcessOperator<?, ?>) downstreamFactory.getOperator();

    // Work on copies so the function infos held by the original operators stay untouched.
    final DataStreamPythonFunctionInfo upstreamInfo = upstreamOp.getPythonFunctionInfo().copy();
    final DataStreamPythonFunctionInfo downstreamInfo = downstreamOp.getPythonFunctionInfo().copy();

    // Descend along the first input until reaching the entry without inputs, then hook the
    // upstream function info in underneath it.
    DataStreamPythonFunctionInfo downstreamHead = downstreamInfo;
    while (downstreamHead.getInputs().length > 0) {
        downstreamHead = (DataStreamPythonFunctionInfo) downstreamHead.getInputs()[0];
    }
    downstreamHead.setInputs(new DataStreamPythonFunctionInfo[] { upstreamInfo });

    final AbstractDataStreamPythonFunctionOperator<?> chainedOperator =
            upstreamOp.copy(downstreamInfo, downstreamOp.getProducedType());
    // The chained operator carries the partition-custom flag if either side does.
    final boolean anyPartitionCustom =
            downstreamOp.containsPartitionCustom() || upstreamOp.containsPartitionCustom();
    chainedOperator.setContainsPartitionCustom(anyPartitionCustom);

    final String chainedName = upTransform.getName() + ", " + downTransform.getName();
    final PhysicalTransformation<?> chainedTransformation;
    if (upstreamOp instanceof AbstractOneInputPythonFunctionOperator) {
        final OneInputTransformation<?, ?> oneInput = new OneInputTransformation(
                upTransform.getInputs().get(0),
                chainedName,
                (OneInputStreamOperator<?, ?>) chainedOperator,
                downTransform.getOutputType(),
                upTransform.getParallelism());
        oneInput.setStateKeySelector(((OneInputTransformation) upTransform).getStateKeySelector());
        oneInput.setStateKeyType(((OneInputTransformation<?, ?>) upTransform).getStateKeyType());
        chainedTransformation = oneInput;
    } else {
        final TwoInputTransformation<?, ?, ?> twoInput = new TwoInputTransformation(
                upTransform.getInputs().get(0),
                upTransform.getInputs().get(1),
                chainedName,
                (TwoInputStreamOperator<?, ?, ?>) chainedOperator,
                downTransform.getOutputType(),
                upTransform.getParallelism());
        twoInput.setStateKeySelectors(
                ((TwoInputTransformation) upTransform).getStateKeySelector1(),
                ((TwoInputTransformation) upTransform).getStateKeySelector2());
        twoInput.setStateKeyType(((TwoInputTransformation<?, ?, ?>) upTransform).getStateKeyType());
        chainedTransformation = twoInput;
    }

    // Identity is taken from the upstream transformation.
    chainedTransformation.setUid(upTransform.getUid());
    final String userProvidedHash = upTransform.getUserProvidedNodeHash();
    if (userProvidedHash != null) {
        chainedTransformation.setUidHash(userProvidedHash);
    }

    // Slot-scope managed memory use cases: declare everything either side declared.
    upTransform.getManagedMemorySlotScopeUseCases()
            .forEach(chainedTransformation::declareManagedMemoryUseCaseAtSlotScope);
    downTransform.getManagedMemorySlotScopeUseCases()
            .forEach(chainedTransformation::declareManagedMemoryUseCaseAtSlotScope);

    // Operator-scope managed memory weights: start from the upstream weights, then add the
    // downstream weight on top of whatever is already declared for the same use case.
    upTransform.getManagedMemoryOperatorScopeUseCaseWeights().forEach(
            (useCase, weight) ->
                    chainedTransformation.declareManagedMemoryUseCaseAtOperatorScope(useCase, weight));
    downTransform.getManagedMemoryOperatorScopeUseCaseWeights().forEach(
            (useCase, weight) -> {
                final int alreadyDeclared = chainedTransformation
                        .getManagedMemoryOperatorScopeUseCaseWeights()
                        .getOrDefault(useCase, 0);
                chainedTransformation.declareManagedMemoryUseCaseAtOperatorScope(
                        useCase, weight + alreadyDeclared);
            });

    chainedTransformation.setBufferTimeout(
            Math.min(upTransform.getBufferTimeout(), downTransform.getBufferTimeout()));

    final int upstreamMaxParallelism = upTransform.getMaxParallelism();
    if (upstreamMaxParallelism > 0) {
        chainedTransformation.setMaxParallelism(upstreamMaxParallelism);
    }

    chainedTransformation.setChainingStrategy(getOperatorFactory(upTransform).getChainingStrategy());
    chainedTransformation.setCoLocationGroupKey(upTransform.getCoLocationGroupKey());
    chainedTransformation.setResources(
            upTransform.getMinResources().merge(downTransform.getMinResources()),
            upTransform.getPreferredResources().merge(downTransform.getPreferredResources()));
    upTransform.getSlotSharingGroup().ifPresent(chainedTransformation::setSlotSharingGroup);

    // Description: join both when both exist, otherwise use whichever one is present.
    final String upstreamDescription = upTransform.getDescription();
    final String downstreamDescription = downTransform.getDescription();
    if (upstreamDescription == null) {
        if (downstreamDescription != null) {
            chainedTransformation.setDescription(downstreamDescription);
        }
    } else if (downstreamDescription == null) {
        chainedTransformation.setDescription(upstreamDescription);
    } else {
        chainedTransformation.setDescription(upstreamDescription + ", " + downstreamDescription);
    }
    return chainedTransformation;
}
Also used : AbstractOneInputPythonFunctionOperator(org.apache.flink.streaming.api.operators.python.AbstractOneInputPythonFunctionOperator) DataStreamPythonFunctionInfo(org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo) TwoInputStreamOperator(org.apache.flink.streaming.api.operators.TwoInputStreamOperator) OneInputStreamOperator(org.apache.flink.streaming.api.operators.OneInputStreamOperator) TwoInputTransformation(org.apache.flink.streaming.api.transformations.TwoInputTransformation) PythonProcessOperator(org.apache.flink.streaming.api.operators.python.PythonProcessOperator) OneInputTransformation(org.apache.flink.streaming.api.transformations.OneInputTransformation) HashMap(java.util.HashMap) Map(java.util.Map) ManagedMemoryUseCase(org.apache.flink.core.memory.ManagedMemoryUseCase) AbstractDataStreamPythonFunctionOperator(org.apache.flink.streaming.api.operators.python.AbstractDataStreamPythonFunctionOperator)

Aggregations

HashMap (java.util.HashMap)1 Map (java.util.Map)1 ManagedMemoryUseCase (org.apache.flink.core.memory.ManagedMemoryUseCase)1 DataStreamPythonFunctionInfo (org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo)1 OneInputStreamOperator (org.apache.flink.streaming.api.operators.OneInputStreamOperator)1 TwoInputStreamOperator (org.apache.flink.streaming.api.operators.TwoInputStreamOperator)1 AbstractDataStreamPythonFunctionOperator (org.apache.flink.streaming.api.operators.python.AbstractDataStreamPythonFunctionOperator)1 AbstractOneInputPythonFunctionOperator (org.apache.flink.streaming.api.operators.python.AbstractOneInputPythonFunctionOperator)1 PythonProcessOperator (org.apache.flink.streaming.api.operators.python.PythonProcessOperator)1 OneInputTransformation (org.apache.flink.streaming.api.transformations.OneInputTransformation)1 TwoInputTransformation (org.apache.flink.streaming.api.transformations.TwoInputTransformation)1