Usage of com.facebook.presto.operator.aggregation.InternalAggregationFunction in the prestodb/presto project.
Class TestKHyperLogLogAggregationFunction, method testAggregation:
/**
 * Checks the KHyperLogLog aggregation against a sketch built by hand from the same inputs.
 *
 * Pairs where either the value or the uii is null are skipped. If no pair survives,
 * the expected result is null; otherwise it is the serialized hand-built sketch.
 */
private void testAggregation(Type valueType, List<?> values, Type uiiType, List<?> uiis) {
    InternalAggregationFunction aggregationFunction = getAggregation(valueType, uiiType);

    // Lazily created on the first fully non-null pair, so "no usable input" stays null.
    KHyperLogLog expectedSketch = null;
    for (int index = 0; index < values.size(); index++) {
        Object rawValue = values.get(index);
        if (rawValue == null) {
            continue;
        }
        Object rawUii = uiis.get(index);
        if (rawUii == null) {
            continue;
        }
        if (expectedSketch == null) {
            expectedSketch = new KHyperLogLog();
        }
        long numericValue = toLong(rawValue, valueType);
        long numericUii = toLong(rawUii, uiiType);
        // VARCHAR values are added as Slices; every other type goes in as its long form.
        if (valueType != VARCHAR) {
            expectedSketch.add(numericValue, numericUii);
        } else {
            expectedSketch.add((Slice) rawValue, numericUii);
        }
    }

    SqlVarbinary expectedResult = expectedSketch == null
            ? null
            : new SqlVarbinary(expectedSketch.serialize().getBytes());
    assertAggregation(aggregationFunction, expectedResult, buildBlock(values, valueType), buildBlock(uiis, uiiType));
}
Usage of com.facebook.presto.operator.aggregation.InternalAggregationFunction in the prestodb/presto project.
Class TestHashAggregationOperator, method testHashAggregationWithGlobals:
/**
 * Verifies that global-aggregation group ids still produce a default output row
 * when the operator is given input built from a builder with no rows added.
 */
@Test(dataProvider = "hashEnabled")
public void testHashAggregationWithGlobals(boolean hashEnabled) throws Exception {
    MetadataManager metadata = MetadataManager.createTestMetadataManager();

    // Resolve the type-specific count/max implementations used below.
    Signature countOverVarchar = new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.VARCHAR));
    InternalAggregationFunction countVarcharColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation(countOverVarchar);
    Signature countOverBoolean = new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BOOLEAN));
    InternalAggregationFunction countBooleanColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation(countOverBoolean);
    Signature maxOverVarchar = new Signature("max", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR));
    InternalAggregationFunction maxVarcharColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation(maxOverVarchar);

    // Channel 1 carries the group id and, together with channel 2, forms the grouping key.
    // Group ids 42 and 49 are global aggregations.
    Optional<Integer> groupIdChannel = Optional.of(1);
    List<Integer> groupByChannels = Ints.asList(1, 2);
    List<Integer> globalAggregationGroupIds = Ints.asList(42, 49);

    // No rows are added to the builder before the input pages are built.
    RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, groupByChannels, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BOOLEAN);
    List<Page> input = rowPagesBuilder.build();

    HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory(
            0,
            new PlanNodeId("test"),
            ImmutableList.of(VARCHAR, BIGINT),
            groupByChannels,
            globalAggregationGroupIds,
            Step.SINGLE,
            ImmutableList.of(
                    COUNT.bind(ImmutableList.of(0), Optional.empty()),
                    LONG_SUM.bind(ImmutableList.of(4), Optional.empty()),
                    LONG_AVERAGE.bind(ImmutableList.of(4), Optional.empty()),
                    maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty()),
                    countVarcharColumn.bind(ImmutableList.of(0), Optional.empty()),
                    countBooleanColumn.bind(ImmutableList.of(5), Optional.empty())),
            rowPagesBuilder.getHashChannel(),
            groupIdChannel,
            100_000,
            new DataSize(16, MEGABYTE),
            joinCompiler);

    // One default row per global group id: counts are 0; sum, average and max are null.
    MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, BIGINT, BIGINT, BIGINT, DOUBLE, VARCHAR, BIGINT, BIGINT)
            .row(null, 42L, 0L, null, null, null, 0L, 0L)
            .row(null, 49L, 0L, null, null, null, 0L, 0L)
            .build();
    assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, Optional.of(groupByChannels.size()));
}
Usage of com.facebook.presto.operator.aggregation.InternalAggregationFunction in the prestodb/presto project.
Class DoubleSumAggregationBenchmark, method createOperatorFactories:
/**
 * Builds the benchmark pipeline: a scan of orders.totalprice feeding a single-step
 * {@code sum(double)} aggregation over that column.
 */
@Override
protected List<? extends OperatorFactory> createOperatorFactories() {
    PlanNodeId scanNodeId = new PlanNodeId("test");
    OperatorFactory scanFactory = createTableScanOperator(0, scanNodeId, "orders", "totalprice");

    FunctionAndTypeManager manager = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
    InternalAggregationFunction sumOfDoubles = manager.getAggregateFunctionImplementation(
            manager.lookupFunction("sum", fromTypes(DOUBLE)));

    // The scan emits a single column, so the aggregation reads channel 0.
    List<Integer> totalPriceChannel = ImmutableList.of(0);
    PlanNodeId aggregationNodeId = new PlanNodeId("test");
    AggregationOperatorFactory aggregationFactory = new AggregationOperatorFactory(
            1,
            aggregationNodeId,
            Step.SINGLE,
            ImmutableList.of(sumOfDoubles.bind(totalPriceChannel, Optional.empty())),
            false);

    return ImmutableList.of(scanFactory, aggregationFactory);
}
Usage of com.facebook.presto.operator.aggregation.InternalAggregationFunction in the prestodb/presto project.
Class PushPartialAggregationThroughExchange, method split:
/**
 * Splits {@code node} into a PARTIAL aggregation over the original source and a FINAL
 * aggregation on top of it that consumes the partial step's intermediate values.
 *
 * NOTE(review): no exchange node is created here; the FINAL node is returned with the
 * PARTIAL node as its direct source. Presumably the exchange is placed between them
 * elsewhere (e.g. by the caller) — confirm.
 *
 * @param node the aggregation to split; none of its aggregations may have an ORDER BY
 * @param context rule context supplying the variable/id allocators and the session
 * @return the FINAL aggregation node, which reuses {@code node}'s id
 * @throws IllegalStateException if any aggregation has an ORDER BY clause
 */
private PlanNode split(AggregationNode node, Context context) {
    Map<VariableReferenceExpression, AggregationNode.Aggregation> intermediateAggregation = new HashMap<>();
    Map<VariableReferenceExpression, AggregationNode.Aggregation> finalAggregation = new HashMap<>();
    for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
        AggregationNode.Aggregation originalAggregation = entry.getValue();
        // Validate first, so a rejected aggregation fails before any function lookup
        // or variable allocation happens.
        checkState(!originalAggregation.getOrderBy().isPresent(), "Aggregate with ORDER BY does not support partial aggregation");

        FunctionHandle functionHandle = originalAggregation.getFunctionHandle();
        String functionName = functionAndTypeManager.getFunctionMetadata(functionHandle).getName().getObjectName();
        InternalAggregationFunction function = functionAndTypeManager.getAggregateFunctionImplementation(functionHandle);

        // PARTIAL step: same arguments, filter and mask as the original, but it outputs the
        // function's intermediate type into a freshly allocated variable.
        VariableReferenceExpression intermediateVariable = context.getVariableAllocator().newVariable(
                originalAggregation.getCall().getSourceLocation(),
                functionName,
                function.getIntermediateType());
        intermediateAggregation.put(intermediateVariable, new AggregationNode.Aggregation(
                new CallExpression(
                        originalAggregation.getCall().getSourceLocation(),
                        functionName,
                        functionHandle,
                        function.getIntermediateType(),
                        originalAggregation.getArguments()),
                originalAggregation.getFilter(),
                originalAggregation.getOrderBy(),
                originalAggregation.isDistinct(),
                originalAggregation.getMask()));

        // FINAL step: consumes the intermediate variable plus any lambda arguments of the
        // original call. The filter and mask are not repeated because the PARTIAL step
        // already applied them, and DISTINCT is not set on the FINAL step.
        List<RowExpression> finalArguments = ImmutableList.<RowExpression>builder()
                .add(intermediateVariable)
                .addAll(originalAggregation.getArguments().stream()
                        .filter(PushPartialAggregationThroughExchange::isLambda)
                        .collect(toImmutableList()))
                .build();
        finalAggregation.put(entry.getKey(), new AggregationNode.Aggregation(
                new CallExpression(
                        originalAggregation.getCall().getSourceLocation(),
                        functionName,
                        functionHandle,
                        function.getFinalType(),
                        finalArguments),
                Optional.empty(),
                Optional.empty(),
                false,
                Optional.empty()));
    }

    // Streaming is always allowed for partial aggregations, but it may regress when the input is
    // not actually pre-grouped on the grouping keys. This session property forces it on when we
    // know the execution benefits. Deriving it from input table properties is left for later.
    List<VariableReferenceExpression> preGroupedSymbols = ImmutableList.of();
    if (isStreamingForPartialAggregationEnabled(context.getSession())) {
        preGroupedSymbols = ImmutableList.copyOf(node.getGroupingSets().getGroupingKeys());
    }

    PlanNode partial = new AggregationNode(
            node.getSourceLocation(),
            context.getIdAllocator().getNextId(),
            node.getSource(),
            intermediateAggregation,
            node.getGroupingSets(),
            // The PARTIAL step reads the original source directly, so it keeps preGroupedSymbols
            // (empty unless streaming partial aggregation is enabled).
            preGroupedSymbols,
            PARTIAL,
            node.getHashVariable(),
            node.getGroupIdVariable());
    return new AggregationNode(
            node.getSourceLocation(),
            node.getId(),
            partial,
            finalAggregation,
            node.getGroupingSets(),
            // Pre-grouped symbols describe the original input. Once the aggregation is split and
            // data moves between the two steps, that grouping may not be preserved, so the FINAL
            // step drops them.
            ImmutableList.of(),
            FINAL,
            node.getHashVariable(),
            node.getGroupIdVariable());
}
Usage of com.facebook.presto.operator.aggregation.InternalAggregationFunction in the prestodb/presto project.
Class TestEvaluateClassifierPredictions, method testEvaluateClassifierPredictions:
/**
 * Runs {@code evaluate_classifier_predictions(BIGINT, BIGINT)} over a test page. Checks that
 * the text report has seven non-empty lines and that its first line is the accuracy summary.
 */
@Test
public void testEvaluateClassifierPredictions() {
    // The function comes from the ML plugin, so register its functions before the lookup.
    metadata.registerBuiltInFunctions(extractFunctions(new MLPlugin().getFunctions()));
    InternalAggregationFunction evaluator = functionAndTypeManager.getAggregateFunctionImplementation(
            functionAndTypeManager.lookupFunction("evaluate_classifier_predictions", fromTypes(BIGINT, BIGINT)));

    // Channels 0 and 1 of the page supply the two BIGINT arguments.
    List<Integer> argumentChannels = ImmutableList.of(0, 1);
    Accumulator accumulator = evaluator.bind(argumentChannels, Optional.empty()).createAccumulator(UpdateMemory.NOOP);
    accumulator.addInput(getPage());

    // Write the single final value into a block and decode it as a UTF-8 string.
    BlockBuilder resultBuilder = accumulator.getFinalType().createBlockBuilder(null, 1);
    accumulator.evaluateFinal(resultBuilder);
    Block resultBlock = resultBuilder.build();
    String report = VARCHAR.getSlice(resultBlock, 0).toStringUtf8();

    List<String> reportLines = ImmutableList.copyOf(Splitter.on('\n').omitEmptyStrings().split(report));
    assertEquals(reportLines.size(), 7, report);
    assertEquals(reportLines.get(0), "Accuracy: 1/2 (50.00%)");
}
Aggregations