use of org.apache.flink.api.common.functions.RichGroupReduceFunction in project flink by apache.
the class GroupReduceCompilationTest method testGroupedReduceWithSelectorFunctionKeyNoncombinable.
/**
 * Verifies the optimized plan of a non-combinable grouped reduce whose key is given by a
 * {@link KeySelector}.
 *
 * <p>The selector-key translation puts a key-extracting operator in front of the reducer and
 * hands the reducer's output directly to the sink. The test checks the wiring
 * source -> key extractor -> reducer -> sink, the reducer's driver and local strategy, the
 * key fields (field 0 of the key-extended record) and the parallelism of every node.
 *
 * <p>Exceptions propagate so that JUnit reports them with their full stack trace.
 */
@Test
public void testGroupedReduceWithSelectorFunctionKeyNoncombinable() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(8);

    // the source is never read; only the plan is compiled
    DataSet<Tuple2<String, Double>> data = env.readCsvFile("file:///will/never/be/read")
            .types(String.class, Double.class)
            .name("source")
            .setParallelism(6);

    data.groupBy(new KeySelector<Tuple2<String, Double>, String>() {
        @Override
        public String getKey(Tuple2<String, Double> value) {
            return value.f0;
        }
    }).reduceGroup(new RichGroupReduceFunction<Tuple2<String, Double>, Tuple2<String, Double>>() {
        @Override
        public void reduce(Iterable<Tuple2<String, Double>> values, Collector<Tuple2<String, Double>> out) {
        }
    }).name("reducer")
            .output(new DiscardingOutputFormat<Tuple2<String, Double>>())
            .name("sink");

    Plan p = env.createProgramPlan();
    OptimizedPlan op = compileNoStats(p);
    OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(op);

    // get the original nodes
    SourcePlanNode sourceNode = resolver.getNode("source");
    SingleInputPlanNode reduceNode = resolver.getNode("reducer");
    SinkPlanNode sinkNode = resolver.getNode("sink");

    // the key extractor is inserted between the source and the reducer
    SingleInputPlanNode keyExtractor = (SingleInputPlanNode) reduceNode.getInput().getSource();

    // check wiring: source -> key extractor -> reducer -> sink
    // (the reducer unwraps the key itself, so no key-removing projector follows it)
    assertEquals(sourceNode, keyExtractor.getInput().getSource());
    assertEquals(reduceNode, sinkNode.getInput().getSource());

    // check the driver strategy of the (non-combinable) reducer
    assertEquals(DriverStrategy.SORTED_GROUP_REDUCE, reduceNode.getDriverStrategy());

    // check the keys: the extracted key sits at field 0 of the key-extended record
    assertEquals(new FieldList(0), reduceNode.getKeys(0));
    assertEquals(new FieldList(0), reduceNode.getInput().getLocalStrategyKeys());

    // check parallelism: the extractor runs with the source, the rest with the environment default
    assertEquals(6, sourceNode.getParallelism());
    assertEquals(6, keyExtractor.getParallelism());
    assertEquals(8, reduceNode.getParallelism());
    assertEquals(8, sinkNode.getParallelism());
}
use of org.apache.flink.api.common.functions.RichGroupReduceFunction in project flink by apache.
the class AggregateOperator method translateToDataFlow.
/**
 * Translates this aggregation into a combinable group-reduce operator that runs an
 * {@code AggregatingUdf} over the configured (aggregation function, field) pairs.
 *
 * <p>Without a grouping, the operator aggregates all records (no key fields). With an
 * expression-key grouping:
 * <ul>
 *   <li>the logical key positions become the operator's keys;</li>
 *   <li>the grouping's custom partitioner is applied;</li>
 *   <li>every key field that is not itself aggregated is declared as forwarded unchanged.</li>
 * </ul>
 * Grouping by KeySelector functions is rejected.
 *
 * <p>If no explicit name was set, the operator is named
 * {@code fn(field),fn(field),... at location}, where each {@code fn} is the aggregation
 * function's {@code toString()} and {@code location} is {@code aggregateLocationName}.
 *
 * @param input the operator producing this aggregation's input
 * @return the translated group-reduce operator
 * @throws IllegalStateException if no aggregation function is configured, or the number of
 *     aggregation functions differs from the number of fields
 * @throws UnsupportedOperationException if the grouping uses KeySelector functions or an
 *     unrecognized key type
 */
@SuppressWarnings("unchecked")
@Override
@Internal
protected org.apache.flink.api.common.operators.base.GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> translateToDataFlow(Operator<IN> input) {
    // sanity check: every aggregation function must be paired with exactly one field
    if (this.aggregationFunctions.isEmpty() || this.aggregationFunctions.size() != this.fields.size()) {
        throw new IllegalStateException("Aggregation requires at least one aggregation function and exactly one field per function, but found "
                + this.aggregationFunctions.size() + " function(s) and " + this.fields.size() + " field(s).");
    }

    // construct the aggregation function and the generated operator name
    AggregationFunction<Object>[] aggFunctions = new AggregationFunction[this.aggregationFunctions.size()];
    int[] fields = new int[this.fields.size()];
    StringBuilder genName = new StringBuilder();

    for (int i = 0; i < fields.length; i++) {
        aggFunctions[i] = (AggregationFunction<Object>) this.aggregationFunctions.get(i);
        fields[i] = this.fields.get(i);

        genName.append(aggFunctions[i].toString()).append('(').append(fields[i]).append(')').append(',');
    }
    // Drop the trailing ',' BEFORE appending the location. Trimming after the append would
    // leave the comma in place and cut off the last character of the location instead.
    genName.setLength(genName.length() - 1);
    genName.append(" at ").append(aggregateLocationName);

    @SuppressWarnings("rawtypes")
    RichGroupReduceFunction<IN, IN> function = new AggregatingUdf(aggFunctions, fields);

    String name = getName() != null ? getName() : genName.toString();

    // distinguish between grouped reduce and non-grouped reduce
    if (this.grouping == null) {
        // non grouped aggregation
        UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType());
        GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> po =
                new GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>>(function, operatorInfo, new int[0], name);

        po.setCombinable(true);
        // set input
        po.setInput(input);
        // set parallelism
        po.setParallelism(this.getParallelism());

        return po;
    }

    if (this.grouping.getKeys() instanceof Keys.ExpressionKeys) {
        // grouped aggregation
        int[] logicalKeyPositions = this.grouping.getKeys().computeLogicalKeyPositions();
        UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType());
        GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> po =
                new GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>>(function, operatorInfo, logicalKeyPositions, name);

        po.setCombinable(true);
        po.setInput(input);
        po.setParallelism(this.getParallelism());
        po.setCustomPartitioner(grouping.getCustomPartitioner());

        // key fields that are not overwritten by an aggregation are forwarded unchanged
        SingleInputSemanticProperties props = new SingleInputSemanticProperties();

        for (int keyField : logicalKeyPositions) {
            boolean keyFieldUsedInAgg = false;
            for (int aggField : fields) {
                if (keyField == aggField) {
                    keyFieldUsedInAgg = true;
                    break;
                }
            }

            if (!keyFieldUsedInAgg) {
                props.addForwardedField(keyField, keyField);
            }
        }

        po.setSemanticProperties(props);

        return po;
    } else if (this.grouping.getKeys() instanceof Keys.SelectorFunctionKeys) {
        throw new UnsupportedOperationException("Aggregate does not support grouping with KeySelector functions, yet.");
    } else {
        throw new UnsupportedOperationException("Unrecognized key type.");
    }
}
use of org.apache.flink.api.common.functions.RichGroupReduceFunction in project flink by apache.
the class ScalaAggregateOperator method translateToDataFlow.
/**
 * Translates this Scala aggregation into a combinable group-reduce operator that runs an
 * {@code AggregatingUdf} over the configured (aggregation function, field) pairs.
 *
 * <p>The key layout depends on the grouping:
 * <ul>
 *   <li>No grouping: the operator has no key fields and aggregates all records.</li>
 *   <li>Expression-key grouping: the logical key positions become the operator's keys, the
 *       grouping's custom partitioner is applied, and every key field that is not itself
 *       aggregated is declared as forwarded unchanged.</li>
 *   <li>KeySelector grouping: rejected.</li>
 * </ul>
 *
 * <p>Without an explicit name, the operator is named {@code fn(field),fn(field),...}, where
 * each {@code fn} is the aggregation function's {@code toString()}.
 *
 * @param input the operator producing this aggregation's input
 * @return the translated group-reduce operator
 * @throws IllegalStateException if no aggregation function is configured, or the number of
 *     aggregation functions differs from the number of fields
 * @throws UnsupportedOperationException if the grouping uses KeySelector functions or an
 *     unrecognized key type
 */
@SuppressWarnings("unchecked")
@Override
protected org.apache.flink.api.common.operators.base.GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> translateToDataFlow(Operator<IN> input) {
    final int aggregationCount = this.aggregationFunctions.size();
    // every aggregation function must be paired with exactly one field
    if (aggregationCount == 0 || aggregationCount != this.fields.size()) {
        throw new IllegalStateException();
    }

    // collect (function, field) pairs; the default name joins "fn(field)" entries with ','
    AggregationFunction<Object>[] functions = new AggregationFunction[aggregationCount];
    int[] fieldPositions = new int[aggregationCount];
    StringBuilder defaultName = new StringBuilder();
    for (int idx = 0; idx < aggregationCount; idx++) {
        functions[idx] = (AggregationFunction<Object>) this.aggregationFunctions.get(idx);
        fieldPositions[idx] = this.fields.get(idx);
        if (idx > 0) {
            defaultName.append(',');
        }
        defaultName.append(functions[idx].toString()).append('(').append(fieldPositions[idx]).append(')');
    }

    @SuppressWarnings("rawtypes")
    RichGroupReduceFunction<IN, IN> function = new AggregatingUdf(getInputType(), functions, fieldPositions);

    String name = getName();
    if (name == null) {
        name = defaultName.toString();
    }

    // resolve the key positions first; unsupported key types fail before anything is built
    final int[] keyPositions;
    if (this.grouping == null) {
        keyPositions = new int[0];
    } else if (this.grouping.getKeys() instanceof Keys.ExpressionKeys) {
        keyPositions = this.grouping.getKeys().computeLogicalKeyPositions();
    } else if (this.grouping.getKeys() instanceof Keys.SelectorFunctionKeys) {
        throw new UnsupportedOperationException("Aggregate does not support grouping with KeySelector functions, yet.");
    } else {
        throw new UnsupportedOperationException("Unrecognized key type.");
    }

    UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType());
    GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> po =
            new GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>>(function, operatorInfo, keyPositions, name);
    po.setCombinable(true);
    po.setInput(input);
    po.setParallelism(this.getParallelism());

    if (this.grouping == null) {
        // full-data-set aggregation: no forwarded fields, no partitioner
        return po;
    }

    // key fields that are not the target of any aggregation pass through unchanged
    SingleInputSemanticProperties props = new SingleInputSemanticProperties();
    nextKey:
    for (int key : keyPositions) {
        for (int aggregated : fieldPositions) {
            if (aggregated == key) {
                continue nextKey;
            }
        }
        props.addForwardedField(key, key);
    }
    po.setSemanticProperties(props);
    po.setCustomPartitioner(grouping.getCustomPartitioner());
    return po;
}
use of org.apache.flink.api.common.functions.RichGroupReduceFunction in project flink by apache.
the class GroupReduceOperator method translateToDataFlow.
// --------------------------------------------------------------------------------------------
//  Translation
// --------------------------------------------------------------------------------------------

/**
 * Translates this group-reduce into a group-reduce operator of the common API.
 *
 * <p>If the operator is combinable and its function is a {@code CombineFunction}, the
 * function field is first replaced by a wrapper: the rich wrapper when the function is a
 * {@code RichGroupReduceFunction}, the plain wrapper otherwise. The operator shape depends
 * on the grouping:
 * <ul>
 *   <li>No grouping: a key-less operator with parallelism 1.</li>
 *   <li>KeySelector keys: translated by the selector-function helpers; a sorted grouping
 *       also passes its sort key selector and group ordering.</li>
 *   <li>Expression keys: the logical key positions become the keys; a sorted grouping
 *       additionally sets the group order from its sort positions and directions.</li>
 * </ul>
 * Grouped variants take this operator's parallelism and the grouping's custom partitioner.
 *
 * @param input the operator producing this reduce's input
 * @return the translated group-reduce operator
 * @throws UnsupportedOperationException if the grouping uses an unrecognized key type
 */
@Override
@SuppressWarnings("unchecked")
protected GroupReduceOperatorBase<?, OUT, ?> translateToDataFlow(Operator<IN> input) {
    String name = getName();
    if (name == null) {
        name = "GroupReduce at " + defaultName;
    }

    // a combinable CombineFunction must also be usable as a GroupCombineFunction
    if (combinable && function instanceof CombineFunction<?, ?>) {
        if (function instanceof RichGroupReduceFunction<?, ?>) {
            this.function = new RichCombineToGroupCombineWrapper((RichGroupReduceFunction<?, ?>) function);
        } else {
            this.function = new CombineToGroupCombineWrapper((CombineFunction<?, ?>) function);
        }
    }

    if (grouper == null) {
        // reduce over the whole data set
        UnaryOperatorInformation<IN, OUT> operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType());
        GroupReduceOperatorBase<IN, OUT, GroupReduceFunction<IN, OUT>> po =
                new GroupReduceOperatorBase<>(function, operatorInfo, new int[0], name);
        po.setCombinable(combinable);
        po.setInput(input);
        // without keys there is a single group, so only one parallel instance makes sense
        po.setParallelism(1);
        return po;
    }

    if (grouper.getKeys() instanceof SelectorFunctionKeys) {
        SelectorFunctionKeys<IN, ?> selectorKeys = (SelectorFunctionKeys<IN, ?>) grouper.getKeys();
        GroupReduceOperatorBase<?, OUT, ?> po;
        if (grouper instanceof SortedGrouping) {
            SortedGrouping<IN> sorted = (SortedGrouping<IN>) grouper;
            po = translateSelectorFunctionSortedReducer(
                    selectorKeys, sorted.getSortSelectionFunctionKey(), sorted.getGroupOrdering(),
                    function, getResultType(), name, input, isCombinable());
        } else {
            po = translateSelectorFunctionReducer(
                    selectorKeys, function, getResultType(), name, input, isCombinable());
        }
        po.setParallelism(this.getParallelism());
        po.setCustomPartitioner(grouper.getCustomPartitioner());
        return po;
    }

    if (!(grouper.getKeys() instanceof ExpressionKeys)) {
        throw new UnsupportedOperationException("Unrecognized key type.");
    }

    // expression (field position / field name) keys
    int[] keyPositions = grouper.getKeys().computeLogicalKeyPositions();
    UnaryOperatorInformation<IN, OUT> operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType());
    GroupReduceOperatorBase<IN, OUT, GroupReduceFunction<IN, OUT>> po =
            new GroupReduceOperatorBase<>(function, operatorInfo, keyPositions, name);
    po.setCombinable(combinable);
    po.setInput(input);
    po.setParallelism(getParallelism());
    po.setCustomPartitioner(grouper.getCustomPartitioner());

    if (grouper instanceof SortedGrouping) {
        // translate the within-group sort into an Ordering
        SortedGrouping<IN> sorted = (SortedGrouping<IN>) grouper;
        int[] sortPositions = sorted.getGroupSortKeyPositions();
        Order[] directions = sorted.getGroupSortOrders();
        Ordering groupOrder = new Ordering();
        for (int k = 0; k < sortPositions.length; k++) {
            groupOrder.appendOrdering(sortPositions[k], null, directions[k]);
        }
        po.setGroupOrder(groupOrder);
    }
    return po;
}
use of org.apache.flink.api.common.functions.RichGroupReduceFunction in project flink by apache.
the class GroupReduceCompilationTest method testGroupedReduceWithFieldPositionKeyNonCombinable.
/**
 * Verifies the optimized plan of a non-combinable grouped reduce keyed by field position 1.
 *
 * <p>With expression keys no extra key-extraction node is needed. The test checks:
 * <ul>
 *   <li>the wiring source -> reducer -> sink;</li>
 *   <li>the reducer's driver strategy;</li>
 *   <li>the key field (1) for both the driver and the local strategy;</li>
 *   <li>the parallelism of every node.</li>
 * </ul>
 *
 * <p>Exceptions propagate so that JUnit reports them with their full stack trace.
 */
@Test
public void testGroupedReduceWithFieldPositionKeyNonCombinable() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(8);

    // the source is never read; only the plan is compiled
    DataSet<Tuple2<String, Double>> data = env.readCsvFile("file:///will/never/be/read")
            .types(String.class, Double.class)
            .name("source")
            .setParallelism(6);

    data.groupBy(1).reduceGroup(new RichGroupReduceFunction<Tuple2<String, Double>, Tuple2<String, Double>>() {
        @Override
        public void reduce(Iterable<Tuple2<String, Double>> values, Collector<Tuple2<String, Double>> out) {
        }
    }).name("reducer")
            .output(new DiscardingOutputFormat<Tuple2<String, Double>>())
            .name("sink");

    Plan p = env.createProgramPlan();
    OptimizedPlan op = compileNoStats(p);
    OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(op);

    // get the original nodes
    SourcePlanNode sourceNode = resolver.getNode("source");
    SingleInputPlanNode reduceNode = resolver.getNode("reducer");
    SinkPlanNode sinkNode = resolver.getNode("sink");

    // check wiring: source -> reducer -> sink
    assertEquals(sourceNode, reduceNode.getInput().getSource());
    assertEquals(reduceNode, sinkNode.getInput().getSource());

    // check the driver strategy of the (non-combinable) reducer
    assertEquals(DriverStrategy.SORTED_GROUP_REDUCE, reduceNode.getDriverStrategy());

    // check the keys
    assertEquals(new FieldList(1), reduceNode.getKeys(0));
    assertEquals(new FieldList(1), reduceNode.getInput().getLocalStrategyKeys());

    // check parallelism
    assertEquals(6, sourceNode.getParallelism());
    assertEquals(8, reduceNode.getParallelism());
    assertEquals(8, sinkNode.getParallelism());
}
Aggregations