Search in sources :

Example 1 with DeltaIteration

Use of org.apache.flink.api.java.operators.DeltaIteration in the Apache Flink project.

Class DeltaIterationTranslationTest, method testCorrectTranslation.

/**
 * Builds a small delta-iteration program with the DataSet API, translates it into a
 * {@code Plan} via {@code ExecutionEnvironment#createProgramPlan(String)}, and checks that
 * the plan reflects what was configured: job name, default parallelism, the iteration's
 * name / parallelism / maximum iteration count / solution set key fields, the user functions
 * inside the step function, and the registered aggregator.
 *
 * <p>The program attaches two sinks to the iteration result so that the test can also check
 * that both sinks end up with the same iteration operator as input.
 *
 * @throws Exception if building or translating the program fails unexpectedly; the exception
 *     is deliberately propagated so the test runner reports it with its original type and
 *     stack trace instead of a message-only failure
 */
@Test
public void testCorrectTranslation() throws Exception {
    final String jobName = "Test JobName";
    final String iterationName = "Test Name";
    final String beforeNextWorksetMap = "Some Mapper";
    final String aggregatorName = "AggregatorName";
    final int[] iterationKeys = new int[] { 2 };
    final int numIterations = 13;
    final int defaultParallelism = 133;
    final int iterationParallelism = 77;
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    // ------------ construct the test program ------------------
    // The braces scope the API-level variables (notably "iteration") so the same names can be
    // reused for the translated operators in the validation section below.
    {
        env.setParallelism(defaultParallelism);
        @SuppressWarnings("unchecked") DataSet<Tuple3<Double, Long, String>> initialSolutionSet = env.fromElements(new Tuple3<Double, Long, String>(3.44, 5L, "abc"));
        @SuppressWarnings("unchecked") DataSet<Tuple2<Double, String>> initialWorkSet = env.fromElements(new Tuple2<Double, String>(1.23, "abc"));
        DeltaIteration<Tuple3<Double, Long, String>, Tuple2<Double, String>> iteration = initialSolutionSet.iterateDelta(initialWorkSet, numIterations, iterationKeys);
        iteration.name(iterationName).parallelism(iterationParallelism);
        iteration.registerAggregator(aggregatorName, new LongSumAggregator());
        // test that multiple workset consumers are supported
        DataSet<Tuple2<Double, String>> worksetSelfJoin = iteration.getWorkset().map(new IdentityMapper<Tuple2<Double, String>>()).join(iteration.getWorkset()).where(1).equalTo(1).projectFirst(0, 1);
        DataSet<Tuple3<Double, Long, String>> joined = worksetSelfJoin.join(iteration.getSolutionSet()).where(1).equalTo(2).with(new SolutionWorksetJoin());
        // "joined" serves as the solution set delta; a named mapper on top of it produces the next workset
        DataSet<Tuple3<Double, Long, String>> result = iteration.closeWith(joined, joined.map(new NextWorksetMapper()).name(beforeNextWorksetMap));
        // two sinks on the same result, so the plan has two consumers of the iteration
        result.output(new DiscardingOutputFormat<Tuple3<Double, Long, String>>());
        result.writeAsText("/dev/null");
    }
    Plan p = env.createProgramPlan(jobName);
    // ------------- validate the plan ----------------
    assertEquals(jobName, p.getJobName());
    assertEquals(defaultParallelism, p.getDefaultParallelism());
    // validate the iteration
    GenericDataSinkBase<?> sink1, sink2;
    {
        Iterator<? extends GenericDataSinkBase<?>> sinks = p.getDataSinks().iterator();
        sink1 = sinks.next();
        sink2 = sinks.next();
    }
    DeltaIterationBase<?, ?> iteration = (DeltaIterationBase<?, ?>) sink1.getInput();
    // check that multi consumer translation works for iterations
    assertEquals(iteration, sink2.getInput());
    // check the basic iteration properties
    assertEquals(numIterations, iteration.getMaximumNumberOfIterations());
    assertArrayEquals(iterationKeys, iteration.getSolutionSetKeyFields());
    assertEquals(iterationParallelism, iteration.getParallelism());
    assertEquals(iterationName, iteration.getName());
    // walk the step function backwards: next workset mapper, solution set delta join,
    // workset self join, and finally the identity mapper on the workset
    MapOperatorBase<?, ?, ?> nextWorksetMapper = (MapOperatorBase<?, ?, ?>) iteration.getNextWorkset();
    InnerJoinOperatorBase<?, ?, ?, ?> solutionSetJoin = (InnerJoinOperatorBase<?, ?, ?, ?>) iteration.getSolutionSetDelta();
    InnerJoinOperatorBase<?, ?, ?, ?> worksetSelfJoin = (InnerJoinOperatorBase<?, ?, ?, ?>) solutionSetJoin.getFirstInput();
    MapOperatorBase<?, ?, ?> worksetMapper = (MapOperatorBase<?, ?, ?>) worksetSelfJoin.getFirstInput();
    assertEquals(IdentityMapper.class, worksetMapper.getUserCodeWrapper().getUserCodeClass());
    assertEquals(NextWorksetMapper.class, nextWorksetMapper.getUserCodeWrapper().getUserCodeClass());
    // the join function may be held inside a WrappingFunction; unwrap it before comparing classes
    if (solutionSetJoin.getUserCodeWrapper().getUserCodeObject() instanceof WrappingFunction) {
        WrappingFunction<?> wf = (WrappingFunction<?>) solutionSetJoin.getUserCodeWrapper().getUserCodeObject();
        assertEquals(SolutionWorksetJoin.class, wf.getWrappedFunction().getClass());
    } else {
        assertEquals(SolutionWorksetJoin.class, solutionSetJoin.getUserCodeWrapper().getUserCodeClass());
    }
    assertEquals(beforeNextWorksetMap, nextWorksetMapper.getName());
    assertEquals(aggregatorName, iteration.getAggregators().getAllRegisteredAggregators().iterator().next().getName());
}
Also used : ExecutionEnvironment(org.apache.flink.api.java.ExecutionEnvironment) GenericDataSinkBase(org.apache.flink.api.common.operators.GenericDataSinkBase) DataSet(org.apache.flink.api.java.DataSet) LongSumAggregator(org.apache.flink.api.common.aggregators.LongSumAggregator) DiscardingOutputFormat(org.apache.flink.api.java.io.DiscardingOutputFormat) MapOperatorBase(org.apache.flink.api.common.operators.base.MapOperatorBase) Iterator(java.util.Iterator) DeltaIterationBase(org.apache.flink.api.common.operators.base.DeltaIterationBase) DeltaIteration(org.apache.flink.api.java.operators.DeltaIteration) InnerJoinOperatorBase(org.apache.flink.api.common.operators.base.InnerJoinOperatorBase) Plan(org.apache.flink.api.common.Plan) InvalidProgramException(org.apache.flink.api.common.InvalidProgramException) Tuple2(org.apache.flink.api.java.tuple.Tuple2) Tuple3(org.apache.flink.api.java.tuple.Tuple3) Test(org.junit.Test)

Aggregations

Iterator (java.util.Iterator)1 InvalidProgramException (org.apache.flink.api.common.InvalidProgramException)1 Plan (org.apache.flink.api.common.Plan)1 LongSumAggregator (org.apache.flink.api.common.aggregators.LongSumAggregator)1 GenericDataSinkBase (org.apache.flink.api.common.operators.GenericDataSinkBase)1 DeltaIterationBase (org.apache.flink.api.common.operators.base.DeltaIterationBase)1 InnerJoinOperatorBase (org.apache.flink.api.common.operators.base.InnerJoinOperatorBase)1 MapOperatorBase (org.apache.flink.api.common.operators.base.MapOperatorBase)1 DataSet (org.apache.flink.api.java.DataSet)1 ExecutionEnvironment (org.apache.flink.api.java.ExecutionEnvironment)1 DiscardingOutputFormat (org.apache.flink.api.java.io.DiscardingOutputFormat)1 DeltaIteration (org.apache.flink.api.java.operators.DeltaIteration)1 Tuple2 (org.apache.flink.api.java.tuple.Tuple2)1 Tuple3 (org.apache.flink.api.java.tuple.Tuple3)1 Test (org.junit.Test)1