Use of com.facebook.presto.operator.aggregation.InternalAggregationFunction in project presto by prestodb.
The class TestLearnAggregations, method testLearnLibSvm.
@Test
public void testLearnLibSvm() {
    Type mapType = functionAndTypeManager.getParameterizedType(
            "map",
            ImmutableList.of(
                    TypeSignatureParameter.of(parseTypeSignature(StandardTypes.BIGINT)),
                    TypeSignatureParameter.of(parseTypeSignature(StandardTypes.DOUBLE))));
    InternalAggregationFunction aggregation = AggregationFromAnnotationsParser
            .parseFunctionDefinitionWithTypesConstraint(
                    LearnLibSvmClassifierAggregation.class,
                    ClassifierType.BIGINT_CLASSIFIER.getTypeSignature(),
                    ImmutableList.of(
                            BIGINT.getTypeSignature(),
                            mapType.getTypeSignature(),
                            VarcharType.getParametrizedVarcharSignature("x")))
            .specialize(BoundVariables.builder().setLongVariable("x", (long) Integer.MAX_VALUE).build(), 3, functionAndTypeManager);
    assertLearnClassifer(aggregation.bind(ImmutableList.of(0, 1, 2), Optional.empty()).createAccumulator(UpdateMemory.NOOP));
}
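The helper assertLearnClassifer is not reproduced on this page. A minimal sketch of such a check, assuming prestodb's Accumulator API (addInput, getFinalType, evaluateFinal) and a hypothetical getPage() helper that builds one page of training rows:

// Hedged sketch, not the project's actual helper: feed one page of training rows into the
// bound accumulator and verify that a non-null serialized model comes out.
private static void assertLearnClassifer(Accumulator accumulator) {
    accumulator.addInput(getPage());  // getPage() is a hypothetical helper producing the training page
    BlockBuilder finalOut = accumulator.getFinalType().createBlockBuilder(null, 1);
    accumulator.evaluateFinal(finalOut);
    assertFalse(finalOut.build().isNull(0));  // the single output position holds the serialized classifier
}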
Use of com.facebook.presto.operator.aggregation.InternalAggregationFunction in project presto by prestodb.
The class TestLearnAggregations, method testLearn.
@Test
public void testLearn() {
    Type mapType = functionAndTypeManager.getParameterizedType(
            "map",
            ImmutableList.of(
                    TypeSignatureParameter.of(parseTypeSignature(StandardTypes.BIGINT)),
                    TypeSignatureParameter.of(parseTypeSignature(StandardTypes.DOUBLE))));
    InternalAggregationFunction aggregation = generateInternalAggregationFunction(
            LearnClassifierAggregation.class,
            ClassifierType.BIGINT_CLASSIFIER.getTypeSignature(),
            ImmutableList.of(BIGINT.getTypeSignature(), mapType.getTypeSignature()),
            functionAndTypeManager);
    assertLearnClassifer(aggregation.bind(ImmutableList.of(0, 1), Optional.empty()).createAccumulator(UpdateMemory.NOOP));
}
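generateInternalAggregationFunction is a test utility whose body is not shown here. Assuming it simply wraps the parse-and-specialize sequence spelled out in testLearnLibSvm above (no bound type variables, arity taken from the argument list), it could look roughly like this:

// Hedged sketch of a helper equivalent to the explicit path in testLearnLibSvm above;
// the empty BoundVariables and the arity derived from the argument list are assumptions.
private static InternalAggregationFunction generateInternalAggregationFunction(
        Class<?> aggregationClass,
        TypeSignature returnType,
        List<TypeSignature> argumentTypes,
        FunctionAndTypeManager functionAndTypeManager) {
    return AggregationFromAnnotationsParser
            .parseFunctionDefinitionWithTypesConstraint(aggregationClass, returnType, argumentTypes)
            .specialize(BoundVariables.builder().build(), argumentTypes.size(), functionAndTypeManager);
}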
Use of com.facebook.presto.operator.aggregation.InternalAggregationFunction in project presto by prestodb.
The class TranslateExpressions, method createRewriter.
private static PlanRowExpressionRewriter createRewriter(Metadata metadata, SqlParser sqlParser) {
    return new PlanRowExpressionRewriter() {
        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context) {
            // special treatment of the CallExpression in Aggregation
            if (expression instanceof CallExpression && ((CallExpression) expression).getArguments().stream().anyMatch(OriginalExpressionUtils::isExpression)) {
                return removeOriginalExpressionArguments((CallExpression) expression, context.getSession(), context.getVariableAllocator());
            }
            return removeOriginalExpression(expression, context);
        }

        private RowExpression removeOriginalExpressionArguments(CallExpression callExpression, Session session, PlanVariableAllocator variableAllocator) {
            Map<NodeRef<Expression>, Type> types = analyzeCallExpressionTypes(callExpression, session, variableAllocator.getTypes());
            return new CallExpression(
                    callExpression.getSourceLocation(),
                    callExpression.getDisplayName(),
                    callExpression.getFunctionHandle(),
                    callExpression.getType(),
                    callExpression.getArguments().stream()
                            .map(expression -> removeOriginalExpression(expression, session, types))
                            .collect(toImmutableList()));
        }

        private Map<NodeRef<Expression>, Type> analyzeCallExpressionTypes(CallExpression callExpression, Session session, TypeProvider typeProvider) {
            List<LambdaExpression> lambdaExpressions = callExpression.getArguments().stream()
                    .filter(OriginalExpressionUtils::isExpression)
                    .map(OriginalExpressionUtils::castToExpression)
                    .filter(LambdaExpression.class::isInstance)
                    .map(LambdaExpression.class::cast)
                    .collect(toImmutableList());
            ImmutableMap.Builder<NodeRef<Expression>, Type> builder = ImmutableMap.<NodeRef<Expression>, Type>builder();
            if (!lambdaExpressions.isEmpty()) {
                List<FunctionType> functionTypes = metadata.getFunctionAndTypeManager().getFunctionMetadata(callExpression.getFunctionHandle()).getArgumentTypes().stream()
                        .filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME))
                        .map(typeSignature -> (FunctionType) (metadata.getFunctionAndTypeManager().getType(typeSignature)))
                        .collect(toImmutableList());
                InternalAggregationFunction internalAggregationFunction = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation(callExpression.getFunctionHandle());
                List<Class> lambdaInterfaces = internalAggregationFunction.getLambdaInterfaces();
                verify(lambdaExpressions.size() == functionTypes.size());
                verify(lambdaExpressions.size() == lambdaInterfaces.size());
                for (int i = 0; i < lambdaExpressions.size(); i++) {
                    LambdaExpression lambdaExpression = lambdaExpressions.get(i);
                    FunctionType functionType = functionTypes.get(i);
                    // To compile a lambda, a LambdaDefinitionExpression needs to be generated from the LambdaExpression,
                    // which requires the types of all sub-expressions.
                    //
                    // In project and filter expression compilation, ExpressionAnalyzer.getExpressionTypesFromInput
                    // is used to generate the types of all sub-expressions. (see visitScanFilterAndProject and visitFilter)
                    //
                    // This does not work here, since the function call representation in the final aggregation node
                    // is currently a hack: it takes the intermediate type as input and may not be a valid
                    // function call in Presto.
                    //
                    // TODO: Once the final aggregation function call representation is fixed,
                    // the same mechanism used for project and filter expressions should be applied here.
                    verify(lambdaExpression.getArguments().size() == functionType.getArgumentTypes().size());
                    Map<NodeRef<Expression>, Type> lambdaArgumentExpressionTypes = new HashMap<>();
                    Map<String, Type> lambdaArgumentSymbolTypes = new HashMap<>();
                    for (int j = 0; j < lambdaExpression.getArguments().size(); j++) {
                        LambdaArgumentDeclaration argument = lambdaExpression.getArguments().get(j);
                        Type type = functionType.getArgumentTypes().get(j);
                        lambdaArgumentExpressionTypes.put(NodeRef.of(argument), type);
                        lambdaArgumentSymbolTypes.put(argument.getName().getValue(), type);
                    }
                    // the lambda expression itself
                    builder.put(NodeRef.of(lambdaExpression), functionType)
                            .putAll(lambdaArgumentExpressionTypes)
                            .putAll(getExpressionTypes(session, metadata, sqlParser, TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody(), emptyList(), NOOP));
                }
            }
            for (RowExpression argument : callExpression.getArguments()) {
                if (!isExpression(argument) || castToExpression(argument) instanceof LambdaExpression) {
                    continue;
                }
                builder.putAll(analyze(castToExpression(argument), session, typeProvider));
            }
            return builder.build();
        }

        // Resolve the types of all sub-expressions of the given AST expression.
        private Map<NodeRef<Expression>, Type> analyze(Expression expression, Session session, TypeProvider typeProvider) {
            return getExpressionTypes(session, metadata, sqlParser, typeProvider, expression, emptyList(), NOOP);
        }

        // Translate a legacy AST expression into a RowExpression, using the previously resolved types.
        private RowExpression toRowExpression(Expression expression, Session session, Map<NodeRef<Expression>, Type> types) {
            return SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionAndTypeManager(), session);
        }

        private RowExpression removeOriginalExpression(RowExpression expression, Rule.Context context) {
            if (isExpression(expression)) {
                return toRowExpression(
                        castToExpression(expression),
                        context.getSession(),
                        analyze(castToExpression(expression), context.getSession(), context.getVariableAllocator().getTypes()));
            }
            return expression;
        }

        private RowExpression removeOriginalExpression(RowExpression rowExpression, Session session, Map<NodeRef<Expression>, Type> types) {
            if (isExpression(rowExpression)) {
                Expression expression = castToExpression(rowExpression);
                return toRowExpression(expression, session, types);
            }
            return rowExpression;
        }
    };
}
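For orientation, the rewriter's whole purpose is to strip OriginalExpression wrappers: a legacy AST Expression rides inside a RowExpression until its types are resolved, and is then translated into a genuine RowExpression. A hedged, free-standing illustration of that round trip (the LongLiteral and the helper method are invented for this example; castToRowExpression, isExpression, and castToExpression come from OriginalExpressionUtils):

// Hedged illustration (not part of the rule): how an AST expression gets wrapped into and
// unwrapped from a RowExpression via OriginalExpressionUtils. The literal is arbitrary.
private static void illustrateOriginalExpressionRoundTrip() {
    RowExpression wrapped = castToRowExpression(new LongLiteral("1"));
    verify(isExpression(wrapped));               // still carrying the legacy AST node
    Expression original = castToExpression(wrapped);
    verify(original instanceof LongLiteral);     // back to the AST form
    // removeOriginalExpression(...) above analyzes such a node and hands it to
    // SqlToRowExpressionTranslator.translate(...) to produce a genuine RowExpression.
}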
Use of com.facebook.presto.operator.aggregation.InternalAggregationFunction in project presto by prestodb.
The class TestHashAggregationOperator, method testHashAggregationMemoryReservation.
@Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues")
public void testHashAggregationMemoryReservation(
        boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages,
        long memoryLimitForMerge, long memoryLimitForMergeWithMemory) {
    InternalAggregationFunction arrayAggColumn = getAggregation("array_agg", BIGINT);
    List<Integer> hashChannels = Ints.asList(1);
    RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, hashChannels, BIGINT, BIGINT);
    List<Page> input = rowPagesBuilder
            .addSequencePage(10, 100, 0).addSequencePage(10, 200, 0).addSequencePage(10, 300, 0).build();
    DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION, new DataSize(10, Unit.MEGABYTE))
            .addPipelineContext(0, true, true, false)
            .addDriverContext();
    HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory(
            0, new PlanNodeId("test"), ImmutableList.of(BIGINT), hashChannels, ImmutableList.of(),
            Step.SINGLE, true,
            ImmutableList.of(arrayAggColumn.bind(ImmutableList.of(0), Optional.empty())),
            rowPagesBuilder.getHashChannel(), Optional.empty(),
            100_000, Optional.of(new DataSize(16, MEGABYTE)),
            spillEnabled, succinctBytes(memoryLimitForMerge), succinctBytes(memoryLimitForMergeWithMemory),
            spillerFactory, joinCompiler, false);
    Operator operator = operatorFactory.createOperator(driverContext);
    toPages(operator, input.iterator(), revokeMemoryWhenAddingPages);
    assertEquals(operator.getOperatorContext().getOperatorStats().getUserMemoryReservation().toBytes(), 0);
}
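getAggregation is a private helper of the test class and is not reproduced here. Assuming it simply resolves the named aggregate through the function manager, much like the getAggregateFunctionImplementation call in the TranslateExpressions example above, a sketch might be (the FUNCTION_AND_TYPE_MANAGER constant and TypeSignatureProvider.fromTypes are assumptions about the test's plumbing):

// Hedged sketch of a lookup helper: resolve the function by name and argument types,
// then fetch its aggregate implementation.
private static InternalAggregationFunction getAggregation(String name, Type... arguments) {
    return FUNCTION_AND_TYPE_MANAGER.getAggregateFunctionImplementation(
            FUNCTION_AND_TYPE_MANAGER.lookupFunction(name, fromTypes(arguments)));
}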
Use of com.facebook.presto.operator.aggregation.InternalAggregationFunction in project presto by prestodb.
The class TestHashAggregationOperator, method testSpillerFailure.
@Test
public void testSpillerFailure() {
    InternalAggregationFunction maxVarcharColumn = getAggregation("max", VARCHAR);
    List<Integer> hashChannels = Ints.asList(1);
    ImmutableList<Type> types = ImmutableList.of(VARCHAR, BIGINT, VARCHAR, BIGINT);
    RowPagesBuilder rowPagesBuilder = rowPagesBuilder(false, hashChannels, types);
    List<Page> input = rowPagesBuilder
            .addSequencePage(10, 100, 0, 100, 0).addSequencePage(10, 100, 0, 200, 0).addSequencePage(10, 100, 0, 300, 0).build();
    DriverContext driverContext = TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION)
            .setQueryMaxMemory(DataSize.valueOf("7MB"))
            .setMemoryPoolSize(DataSize.valueOf("1GB"))
            .build()
            .addPipelineContext(0, true, true, false)
            .addDriverContext();
    HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory(
            0, new PlanNodeId("test"), ImmutableList.of(BIGINT), hashChannels, ImmutableList.of(),
            Step.SINGLE, false,
            ImmutableList.of(
                    COUNT.bind(ImmutableList.of(0), Optional.empty()),
                    LONG_SUM.bind(ImmutableList.of(3), Optional.empty()),
                    LONG_AVERAGE.bind(ImmutableList.of(3), Optional.empty()),
                    maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty())),
            rowPagesBuilder.getHashChannel(), Optional.empty(),
            100_000, Optional.of(new DataSize(16, MEGABYTE)),
            true, succinctBytes(8), succinctBytes(Integer.MAX_VALUE),
            new FailingSpillerFactory(), joinCompiler, false);
    try {
        toPages(operatorFactory, driverContext, input);
        fail("An exception was expected");
    } catch (RuntimeException expected) {
        if (!nullToEmpty(expected.getMessage()).matches(".* Failed to spill")) {
            fail("Exception other than expected was thrown", expected);
        }
    }
}
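FailingSpillerFactory is a private class of the test and is not shown on this page. A plausible sketch, assuming prestodb's SpillerFactory and Spiller interfaces and Guava's Futures.immediateFailedFuture; the exact interface signatures are an assumption:

// Hedged sketch of a spiller factory whose spill attempt always fails, so the operator
// surfaces a "Failed to spill" error matching the assertion above. Interface shapes are assumed.
private static class FailingSpillerFactory implements SpillerFactory {
    @Override
    public Spiller create(List<Type> types, SpillContext spillContext, AggregatedMemoryContext memoryContext) {
        return new Spiller() {
            @Override
            public ListenableFuture<?> spill(Iterator<Page> pageIterator) {
                // immediateFailedFuture is Guava's Futures.immediateFailedFuture
                return immediateFailedFuture(new IOException("Failed to spill"));
            }

            @Override
            public List<Iterator<Page>> getSpills() {
                return ImmutableList.of();
            }

            @Override
            public void close() {
            }
        };
    }
}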