Use of org.apache.flink.api.common.functions.RichGroupReduceFunction in the Apache Flink project.
Class GroupReduceCompilationTest, method testAllGroupReduceNoCombiner.
/**
 * Compiles source -> all-group-reduce -> sink and checks the optimized plan.
 *
 * <p>The environment parallelism is 8, but the test expects the {@code fromElements} source
 * to run with parallelism 1. For such an input, the all-group-reduce must be wired directly to
 * the source (no combiner node in between), use {@link DriverStrategy#ALL_GROUP_REDUCE}, and
 * keep parallelism 1 all the way to the sink.
 *
 * <p>Unexpected exceptions propagate so JUnit reports them with their full stack trace.
 */
@Test
public void testAllGroupReduceNoCombiner() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(8);

    DataSet<Double> data = env.fromElements(0.2, 0.3, 0.4, 0.5).name("source");

    data.reduceGroup(new RichGroupReduceFunction<Double, Double>() {
        @Override
        public void reduce(Iterable<Double> values, Collector<Double> out) {
            // Intentionally empty: only the shape of the compiled plan is under test.
        }
    }).name("reducer").output(new DiscardingOutputFormat<Double>()).name("sink");

    Plan p = env.createProgramPlan();
    OptimizedPlan op = compileNoStats(p);
    OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(op);

    // The all-reduce has no combiner when the parallelism of the input is one.
    SourcePlanNode sourceNode = resolver.getNode("source");
    SingleInputPlanNode reduceNode = resolver.getNode("reducer");
    SinkPlanNode sinkNode = resolver.getNode("sink");

    // Check wiring: the reducer consumes the source directly, so no combiner was inserted.
    assertEquals(sourceNode, reduceNode.getInput().getSource());
    assertEquals(reduceNode, sinkNode.getInput().getSource());

    // Check that the reducer has the right strategy.
    assertEquals(DriverStrategy.ALL_GROUP_REDUCE, reduceNode.getDriverStrategy());

    // Check parallelism: everything stays at the source's parallelism of 1.
    assertEquals(1, sourceNode.getParallelism());
    assertEquals(1, reduceNode.getParallelism());
    assertEquals(1, sinkNode.getParallelism());
}
Use of org.apache.flink.api.common.functions.RichGroupReduceFunction in the Apache Flink project.
Class GroupReduceOperatorTest, method testGroupReduceCollectionWithRuntimeContext.
/**
 * Executes a {@link RichGroupReduceFunction} through
 * {@code GroupReduceOperatorBase#executeOnCollections}, keyed on field 0.
 *
 * <p>It runs once with object reuse disabled and once with it enabled. Each run must produce
 * one tuple per key holding the sum of field 1. The test also checks that the function's
 * {@code open}/{@code close} hooks run and that {@code open} sees a runtime context with
 * subtask index 0, parallelism 1, and the expected task name.
 *
 * <p>Unexpected exceptions propagate so JUnit reports them with their full stack trace.
 */
@Test
public void testGroupReduceCollectionWithRuntimeContext() throws Exception {
    final String taskName = "Test Task";
    final AtomicBoolean opened = new AtomicBoolean();
    final AtomicBoolean closed = new AtomicBoolean();

    final RichGroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>> reducer =
            new RichGroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() {
                @Override
                public void reduce(
                        Iterable<Tuple2<String, Integer>> values,
                        Collector<Tuple2<String, Integer>> out)
                        throws Exception {
                    Iterator<Tuple2<String, Integer>> input = values.iterator();
                    // NOTE(review): the first tuple of the group is mutated and re-emitted.
                    // Whether this touches the caller's input tuples depends on the
                    // operator's object-reuse handling - confirm if the input list is
                    // ever inspected after execution.
                    Tuple2<String, Integer> result = input.next();
                    int sum = result.f1;
                    while (input.hasNext()) {
                        Tuple2<String, Integer> next = input.next();
                        sum += next.f1;
                    }
                    result.f1 = sum;
                    out.collect(result);
                }

                @Override
                public void open(Configuration parameters) throws Exception {
                    opened.set(true);
                    RuntimeContext ctx = getRuntimeContext();
                    assertEquals(0, ctx.getIndexOfThisSubtask());
                    assertEquals(1, ctx.getNumberOfParallelSubtasks());
                    assertEquals(taskName, ctx.getTaskName());
                }

                @Override
                public void close() throws Exception {
                    closed.set(true);
                }
            };

    GroupReduceOperatorBase<
                    Tuple2<String, Integer>,
                    Tuple2<String, Integer>,
                    GroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>>
            op =
                    new GroupReduceOperatorBase<>(
                            reducer,
                            new UnaryOperatorInformation<>(STRING_INT_TUPLE, STRING_INT_TUPLE),
                            new int[] {0},
                            "TestReducer");

    List<Tuple2<String, Integer>> input =
            new ArrayList<>(
                    asList(
                            new Tuple2<>("foo", 1),
                            new Tuple2<>("foo", 3),
                            new Tuple2<>("bar", 2),
                            new Tuple2<>("bar", 4)));

    final TaskInfo taskInfo = new TaskInfo(taskName, 1, 0, 1, 0);

    // First run: object reuse disabled.
    ExecutionConfig executionConfig = new ExecutionConfig();
    executionConfig.disableObjectReuse();
    List<Tuple2<String, Integer>> resultMutableSafe =
            op.executeOnCollections(
                    input,
                    new RuntimeUDFContext(
                            taskInfo,
                            null,
                            executionConfig,
                            new HashMap<>(),
                            new HashMap<>(),
                            UnregisteredMetricsGroup.createOperatorMetricGroup()),
                    executionConfig);

    // Second run: object reuse enabled, on the same config instance.
    executionConfig.enableObjectReuse();
    List<Tuple2<String, Integer>> resultRegular =
            op.executeOnCollections(
                    input,
                    new RuntimeUDFContext(
                            taskInfo,
                            null,
                            executionConfig,
                            new HashMap<>(),
                            new HashMap<>(),
                            UnregisteredMetricsGroup.createOperatorMetricGroup()),
                    executionConfig);

    // Group output order is not asserted, so compare as sets.
    Set<Tuple2<String, Integer>> resultSetMutableSafe = new HashSet<>(resultMutableSafe);
    Set<Tuple2<String, Integer>> resultSetRegular = new HashSet<>(resultRegular);

    Set<Tuple2<String, Integer>> expectedResult =
            new HashSet<>(asList(new Tuple2<>("foo", 4), new Tuple2<>("bar", 6)));

    assertEquals(expectedResult, resultSetMutableSafe);
    assertEquals(expectedResult, resultSetRegular);

    assertTrue(opened.get());
    assertTrue(closed.get());
}
Aggregations