Search in sources:

Example 1 with PythonProcessOperator

Use of org.apache.flink.streaming.api.operators.python.PythonProcessOperator in project flink by apache.

Class PythonOperatorChainingOptimizer, method createChainedTransformation.

/**
 * Fuses two adjacent Python DataStream transformations into a single transformation whose
 * operator runs the upstream Python function chain followed by the downstream one.
 *
 * <p>The downstream operator's function chain is re-rooted on a copy of the upstream function
 * info, and the combined chain is installed in a copy of the upstream operator.
 *
 * <p>The fused transformation takes the following from the upstream transformation:
 * <ul>
 *   <li>its inputs and parallelism;
 *   <li>its uid and uid hash;
 *   <li>its state key selectors and key type;
 *   <li>its max parallelism and chaining strategy;
 *   <li>its co-location group key and slot sharing group.
 * </ul>
 *
 * <p>It combines the following from both sides:
 * <ul>
 *   <li>names and descriptions;
 *   <li>minimum and preferred resources;
 *   <li>managed memory declarations;
 *   <li>buffer timeouts.
 * </ul>
 *
 * @param upTransform the upstream transformation; its operator factory must be a
 *     {@link SimpleOperatorFactory} wrapping an {@link AbstractDataStreamPythonFunctionOperator}
 * @param downTransform the downstream transformation; its operator factory must be a
 *     {@link SimpleOperatorFactory} wrapping a {@link PythonProcessOperator}
 * @return the fused transformation (a {@link OneInputTransformation} when the upstream operator
 *     is an {@link AbstractOneInputPythonFunctionOperator}, otherwise a
 *     {@link TwoInputTransformation})
 * @throws ClassCastException if either transformation does not wrap the expected operator types
 */
private static Transformation<?> createChainedTransformation(
        Transformation<?> upTransform, Transformation<?> downTransform) {
    final AbstractDataStreamPythonFunctionOperator<?> upOperator =
            (AbstractDataStreamPythonFunctionOperator<?>)
                    ((SimpleOperatorFactory<?>) getOperatorFactory(upTransform)).getOperator();
    final PythonProcessOperator<?, ?> downOperator =
            (PythonProcessOperator<?, ?>)
                    ((SimpleOperatorFactory<?>) getOperatorFactory(downTransform)).getOperator();

    // Rewire copies of the function infos rather than the operators' own instances
    // (presumably so the original operators stay unchanged -- depends on copy() semantics).
    final DataStreamPythonFunctionInfo upPythonFunctionInfo =
            upOperator.getPythonFunctionInfo().copy();
    final DataStreamPythonFunctionInfo downPythonFunctionInfo =
            downOperator.getPythonFunctionInfo().copy();

    // Follow the first input of each function in the downstream chain until reaching the
    // function that has no inputs, then make the upstream function its only input.
    DataStreamPythonFunctionInfo headPythonFunctionInfoOfDownOperator = downPythonFunctionInfo;
    while (headPythonFunctionInfoOfDownOperator.getInputs().length != 0) {
        headPythonFunctionInfoOfDownOperator =
                (DataStreamPythonFunctionInfo) headPythonFunctionInfoOfDownOperator.getInputs()[0];
    }
    headPythonFunctionInfoOfDownOperator.setInputs(
            new DataStreamPythonFunctionInfo[] {upPythonFunctionInfo});

    // The fused operator is a copy of the upstream operator. It runs the combined function chain
    // and produces the downstream operator's output type.
    final AbstractDataStreamPythonFunctionOperator<?> chainedOperator =
            upOperator.copy(downPythonFunctionInfo, downOperator.getProducedType());
    // Set the partition-custom flag if either side contains a custom partition.
    chainedOperator.setContainsPartitionCustom(
            downOperator.containsPartitionCustom() || upOperator.containsPartitionCustom());

    final String chainedName = upTransform.getName() + ", " + downTransform.getName();
    final PhysicalTransformation<?> chainedTransformation;
    if (upOperator instanceof AbstractOneInputPythonFunctionOperator) {
        final OneInputTransformation<?, ?> oneInputTransformation =
                new OneInputTransformation(
                        upTransform.getInputs().get(0),
                        chainedName,
                        (OneInputStreamOperator<?, ?>) chainedOperator,
                        downTransform.getOutputType(),
                        upTransform.getParallelism());
        // A raw cast is needed: the key selector's input type is captured separately for each
        // wildcard-typed transformation, so it cannot be passed across with generics intact.
        oneInputTransformation.setStateKeySelector(
                ((OneInputTransformation) upTransform).getStateKeySelector());
        oneInputTransformation.setStateKeyType(
                ((OneInputTransformation<?, ?>) upTransform).getStateKeyType());
        chainedTransformation = oneInputTransformation;
    } else {
        final TwoInputTransformation<?, ?, ?> twoInputTransformation =
                new TwoInputTransformation(
                        upTransform.getInputs().get(0),
                        upTransform.getInputs().get(1),
                        chainedName,
                        (TwoInputStreamOperator<?, ?, ?>) chainedOperator,
                        downTransform.getOutputType(),
                        upTransform.getParallelism());
        // Raw type for the same wildcard-capture reason as in the one-input branch.
        final TwoInputTransformation rawUpTwoInputTransformation =
                (TwoInputTransformation) upTransform;
        twoInputTransformation.setStateKeySelectors(
                rawUpTwoInputTransformation.getStateKeySelector1(),
                rawUpTwoInputTransformation.getStateKeySelector2());
        twoInputTransformation.setStateKeyType(
                ((TwoInputTransformation<?, ?, ?>) upTransform).getStateKeyType());
        chainedTransformation = twoInputTransformation;
    }

    chainedTransformation.setUid(upTransform.getUid());
    if (upTransform.getUserProvidedNodeHash() != null) {
        chainedTransformation.setUidHash(upTransform.getUserProvidedNodeHash());
    }

    // Slot-scope managed memory use cases of both sides are declared on the fused transformation.
    for (ManagedMemoryUseCase useCase : upTransform.getManagedMemorySlotScopeUseCases()) {
        chainedTransformation.declareManagedMemoryUseCaseAtSlotScope(useCase);
    }
    for (ManagedMemoryUseCase useCase : downTransform.getManagedMemorySlotScopeUseCases()) {
        chainedTransformation.declareManagedMemoryUseCaseAtSlotScope(useCase);
    }
    // Operator-scope weights: upstream weights first, then downstream weights are added on top
    // of any weight already declared for the same use case.
    for (Map.Entry<ManagedMemoryUseCase, Integer> useCase :
            upTransform.getManagedMemoryOperatorScopeUseCaseWeights().entrySet()) {
        chainedTransformation.declareManagedMemoryUseCaseAtOperatorScope(
                useCase.getKey(), useCase.getValue());
    }
    for (Map.Entry<ManagedMemoryUseCase, Integer> useCase :
            downTransform.getManagedMemoryOperatorScopeUseCaseWeights().entrySet()) {
        chainedTransformation.declareManagedMemoryUseCaseAtOperatorScope(
                useCase.getKey(),
                useCase.getValue()
                        + chainedTransformation
                                .getManagedMemoryOperatorScopeUseCaseWeights()
                                .getOrDefault(useCase.getKey(), 0));
    }

    // A negative timeout means "unset"; it must not override an explicit timeout on the other
    // side (see mergeBufferTimeout).
    chainedTransformation.setBufferTimeout(
            mergeBufferTimeout(upTransform.getBufferTimeout(), downTransform.getBufferTimeout()));
    if (upTransform.getMaxParallelism() > 0) {
        chainedTransformation.setMaxParallelism(upTransform.getMaxParallelism());
    }
    chainedTransformation.setChainingStrategy(getOperatorFactory(upTransform).getChainingStrategy());
    chainedTransformation.setCoLocationGroupKey(upTransform.getCoLocationGroupKey());
    chainedTransformation.setResources(
            upTransform.getMinResources().merge(downTransform.getMinResources()),
            upTransform.getPreferredResources().merge(downTransform.getPreferredResources()));
    if (upTransform.getSlotSharingGroup().isPresent()) {
        chainedTransformation.setSlotSharingGroup(upTransform.getSlotSharingGroup().get());
    }

    // Join both descriptions with ", " when present; otherwise keep whichever one exists.
    final String upDescription = upTransform.getDescription();
    final String downDescription = downTransform.getDescription();
    if (upDescription != null && downDescription != null) {
        chainedTransformation.setDescription(upDescription + ", " + downDescription);
    } else if (upDescription != null) {
        chainedTransformation.setDescription(upDescription);
    } else if (downDescription != null) {
        chainedTransformation.setDescription(downDescription);
    }
    return chainedTransformation;
}

/**
 * Combines the buffer timeouts of the two transformations being fused.
 *
 * <p>A negative timeout (Flink uses {@code -1}) means "not set, use the default buffer timeout".
 * A plain {@code Math.min} would let that sentinel replace an explicit timeout configured on the
 * other side. Instead, an explicit value always wins over an unset one. When both are set, the
 * smaller timeout is used, so the fused transformation flushes at least as often as either side.
 *
 * @param upBufferTimeout buffer timeout of the upstream transformation
 * @param downBufferTimeout buffer timeout of the downstream transformation
 * @return the buffer timeout for the fused transformation; negative only if both are unset
 */
private static long mergeBufferTimeout(long upBufferTimeout, long downBufferTimeout) {
    if (upBufferTimeout < 0) {
        return downBufferTimeout;
    }
    if (downBufferTimeout < 0) {
        return upBufferTimeout;
    }
    return Math.min(upBufferTimeout, downBufferTimeout);
}
Also used: AbstractOneInputPythonFunctionOperator (org.apache.flink.streaming.api.operators.python.AbstractOneInputPythonFunctionOperator), DataStreamPythonFunctionInfo (org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo), TwoInputStreamOperator (org.apache.flink.streaming.api.operators.TwoInputStreamOperator), OneInputStreamOperator (org.apache.flink.streaming.api.operators.OneInputStreamOperator), TwoInputTransformation (org.apache.flink.streaming.api.transformations.TwoInputTransformation), PythonProcessOperator (org.apache.flink.streaming.api.operators.python.PythonProcessOperator), OneInputTransformation (org.apache.flink.streaming.api.transformations.OneInputTransformation), HashMap (java.util.HashMap), Map (java.util.Map), ManagedMemoryUseCase (org.apache.flink.core.memory.ManagedMemoryUseCase), AbstractDataStreamPythonFunctionOperator (org.apache.flink.streaming.api.operators.python.AbstractDataStreamPythonFunctionOperator)

Example 2 with PythonProcessOperator

Use of org.apache.flink.streaming.api.operators.python.PythonProcessOperator in project flink by apache.

Class PythonOperatorChainingOptimizerTest, method testChainingNonKeyedOperators.

/**
 * Verifies that two consecutive non-keyed Python process transformations are fused by the
 * optimizer into a single chained transformation.
 *
 * <p>The chained operator is expected to run the functions "f2" and "f1" (in the order checked
 * by {@code validateChainedPythonFunctions}).
 */
@Test
public void testChainingNonKeyedOperators() {
    // Two Python process operators, "f1" feeding "f2".
    PythonProcessOperator<?, ?> upstreamOperator =
            createProcessOperator("f1", new RowTypeInfo(Types.INT(), Types.INT()), Types.STRING());
    PythonProcessOperator<?, ?> downstreamOperator =
            createProcessOperator("f2", Types.STRING(), Types.INT());

    // Pipeline: source -> Process1 -> process2, both process steps with parallelism 2.
    Transformation<?> source = mock(SourceTransformation.class);
    OneInputTransformation<?, ?> upstream =
            new OneInputTransformation(
                    source, "Process1", upstreamOperator, upstreamOperator.getProducedType(), 2);
    Transformation<?> downstream =
            new OneInputTransformation(
                    upstream, "process2", downstreamOperator, downstreamOperator.getProducedType(), 2);

    List<Transformation<?>> pipeline = new ArrayList<>();
    for (Transformation<?> step : new Transformation<?>[] {source, upstream, downstream}) {
        pipeline.add(step);
    }

    List<Transformation<?>> result = PythonOperatorChainingOptimizer.optimize(pipeline);

    // The two process steps collapse into one, leaving [source, chained].
    assertEquals(2, result.size());
    OneInputTransformation<?, ?> fused = (OneInputTransformation<?, ?>) result.get(1);
    // NOTE(review): `source` is a mock, so both sides of this comparison are likely null.
    assertEquals(source.getOutputType(), fused.getInputType());
    assertEquals(downstreamOperator.getProducedType(), fused.getOutputType());

    OneInputStreamOperator<?, ?> fusedOperator = fused.getOperator();
    assertTrue(fusedOperator instanceof PythonProcessOperator);
    PythonProcessOperator<?, ?> fusedProcessOperator = (PythonProcessOperator<?, ?>) fusedOperator;
    validateChainedPythonFunctions(fusedProcessOperator.getPythonFunctionInfo(), "f2", "f1");
}
Also used: SourceTransformation (org.apache.flink.streaming.api.transformations.SourceTransformation), TwoInputTransformation (org.apache.flink.streaming.api.transformations.TwoInputTransformation), OneInputTransformation (org.apache.flink.streaming.api.transformations.OneInputTransformation), Transformation (org.apache.flink.api.dag.Transformation), ArrayList (java.util.ArrayList), RowTypeInfo (org.apache.flink.api.java.typeutils.RowTypeInfo), PythonProcessOperator (org.apache.flink.streaming.api.operators.python.PythonProcessOperator), Test (org.junit.Test)

Aggregations

PythonProcessOperator (org.apache.flink.streaming.api.operators.python.PythonProcessOperator): 2; OneInputTransformation (org.apache.flink.streaming.api.transformations.OneInputTransformation): 2; TwoInputTransformation (org.apache.flink.streaming.api.transformations.TwoInputTransformation): 2; ArrayList (java.util.ArrayList): 1; HashMap (java.util.HashMap): 1; Map (java.util.Map): 1; Transformation (org.apache.flink.api.dag.Transformation): 1; RowTypeInfo (org.apache.flink.api.java.typeutils.RowTypeInfo): 1; ManagedMemoryUseCase (org.apache.flink.core.memory.ManagedMemoryUseCase): 1; DataStreamPythonFunctionInfo (org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo): 1; OneInputStreamOperator (org.apache.flink.streaming.api.operators.OneInputStreamOperator): 1; TwoInputStreamOperator (org.apache.flink.streaming.api.operators.TwoInputStreamOperator): 1; AbstractDataStreamPythonFunctionOperator (org.apache.flink.streaming.api.operators.python.AbstractDataStreamPythonFunctionOperator): 1; AbstractOneInputPythonFunctionOperator (org.apache.flink.streaming.api.operators.python.AbstractOneInputPythonFunctionOperator): 1; SourceTransformation (org.apache.flink.streaming.api.transformations.SourceTransformation): 1; Test (org.junit.Test): 1