Search in sources:

Example 1 with ParameterMetadata

use of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in project presto by prestodb.

Found in the class AccumulatorCompiler, method generateInvokeInputFunction.

/**
 * Emits bytecode that pushes every argument of {@code inputFunction} onto the stack, in declaration
 * order, and then invokes it.
 *
 * <p>The source of each argument is chosen by its {@link ParameterMetadata} parameter type:
 * STATE loads {@code stateField} from {@code this}; BLOCK_INDEX loads {@code position};
 * BLOCK_INPUT_CHANNEL and NULLABLE_BLOCK_INPUT_CHANNEL load the next entry of
 * {@code parameterVariables} unchanged; INPUT_CHANNEL loads the next entry of
 * {@code parameterVariables} and hands it to {@code pushStackType} together with the SQL type
 * and the function's declared Java parameter type.
 *
 * @param grouped when true, {@code generateSetGroupIdFromGroupIdsBlock} is emitted first
 * @return the generated bytecode
 * @throws IllegalArgumentException if a parameter type is not one of the handled kinds
 */
private static BytecodeBlock generateInvokeInputFunction(Scope scope, FieldDefinition stateField, Variable position, List<Variable> parameterVariables, List<ParameterMetadata> parameterMetadatas, MethodHandle inputFunction, CallSiteBinder callSiteBinder, boolean grouped) {
    BytecodeBlock body = new BytecodeBlock();
    if (grouped) {
        generateSetGroupIdFromGroupIdsBlock(scope, stateField, body);
    }
    body.comment("Call input function with unpacked Block arguments");

    Class<?>[] javaTypes = inputFunction.type().parameterArray();
    // Index of the next parameterVariables entry; only channel parameters consume one.
    int channel = 0;
    for (int argument = 0; argument < javaTypes.length; argument++) {
        ParameterMetadata metadata = parameterMetadatas.get(argument);
        switch (metadata.getParameterType()) {
            case STATE:
                body.append(scope.getThis().getField(stateField));
                break;
            case BLOCK_INDEX:
                body.getVariable(position);
                break;
            case BLOCK_INPUT_CHANNEL:
            case NULLABLE_BLOCK_INPUT_CHANNEL:
                // The function receives the Block itself.
                body.getVariable(parameterVariables.get(channel++));
                break;
            case INPUT_CHANNEL:
                // The function receives an unpacked value; pushStackType emits the conversion.
                BytecodeBlock loadBlock = new BytecodeBlock().getVariable(parameterVariables.get(channel++));
                pushStackType(scope, body, metadata.getSqlType(), loadBlock, javaTypes[argument], callSiteBinder);
                break;
            default:
                throw new IllegalArgumentException("Unsupported parameter type: " + metadata.getParameterType());
        }
    }
    body.append(invoke(callSiteBinder.bind(inputFunction), "input"));
    return body;
}
Also used : BytecodeBlock(com.facebook.presto.bytecode.BytecodeBlock) CompilerUtils.defineClass(com.facebook.presto.bytecode.CompilerUtils.defineClass) ParameterMetadata(com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata)

Example 2 with ParameterMetadata

use of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in project presto by prestodb.

Found in the class BindableAggregationFunction, method getParameterMetadata.

/**
 * Derives the {@link ParameterMetadata} list describing each parameter of an aggregation method.
 *
 * <p>Each parameter is classified by its base type annotation:
 * <ul>
 *   <li>{@code @SqlType} parameters become input parameters and consume the next entry of
 *       {@code inputTypes};</li>
 *   <li>{@code @BlockIndex} becomes BLOCK_INDEX;</li>
 *   <li>{@code @AggregationState} becomes STATE.</li>
 * </ul>
 * For backward compatibility, an unannotated first parameter is treated as the state.
 *
 * @param method the aggregation method, or {@code null}
 * @param inputTypes SQL types matched, in order, to the {@code @SqlType} parameters of {@code method}
 * @return the metadata list, or {@code null} when {@code method} is {@code null}
 * @throws IllegalArgumentException if the method has no parameters or a parameter carries an
 *         unsupported base annotation
 */
private static List<ParameterMetadata> getParameterMetadata(@Nullable Method method, List<Type> inputTypes) {
    if (method == null) {
        return null;
    }
    ImmutableList.Builder<ParameterMetadata> builder = ImmutableList.builder();
    Annotation[][] annotations = method.getParameterAnnotations();
    String methodName = method.getDeclaringClass() + "." + method.getName();
    checkArgument(annotations.length > 0, "At least @AggregationState argument is required for each of aggregation functions.");
    // Index of the next inputTypes entry. Only @SqlType parameters advance it, so state and
    // block-index parameters placed before or between inputs do not shift the type mapping
    // (the parameter index minus one is only correct when exactly one parameter precedes the inputs).
    int inputId = 0;
    int i = 0;
    if (annotations[0].length == 0) {
        // Backward compatibility - first argument without annotations is interpreted as State argument
        builder.add(new ParameterMetadata(STATE));
        i++;
    }
    for (; i < annotations.length; i++) {
        Annotation baseTypeAnnotation = baseTypeAnnotation(annotations[i], methodName);
        if (baseTypeAnnotation instanceof SqlType) {
            builder.add(fromSqlType(inputTypes.get(inputId++), isParameterBlock(annotations[i]), isParameterNullable(annotations[i]), methodName));
        } else if (baseTypeAnnotation instanceof BlockIndex) {
            builder.add(new ParameterMetadata(BLOCK_INDEX));
        } else if (baseTypeAnnotation instanceof AggregationState) {
            builder.add(new ParameterMetadata(STATE));
        } else {
            // Render the annotation array's contents; plain concatenation would print only its identity hash.
            throw new IllegalArgumentException("Unsupported annotation: " + java.util.Arrays.toString(annotations[i]));
        }
    }
    return builder.build();
}
Also used : ImmutableList(com.google.common.collect.ImmutableList) ImmutableCollectors.toImmutableList(com.facebook.presto.util.ImmutableCollectors.toImmutableList) ParameterMetadata.fromSqlType(com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.fromSqlType) SqlType(com.facebook.presto.spi.function.SqlType) ParameterMetadata(com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata) AggregationState(com.facebook.presto.spi.function.AggregationState) Annotation(java.lang.annotation.Annotation)

Example 3 with ParameterMetadata

use of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in project presto by prestodb.

Found in the class ArrayAggregationFunction, method generateAggregation.

/**
 * Builds the {@link InternalAggregationFunction} that aggregates values of {@code type} into an array.
 *
 * <p>The input and combine method handles are bound to the element type, and the output handle is
 * bound to the resulting {@code ArrayType}. The accumulator factory binder is obtained from
 * {@code AccumulatorCompiler.generateAccumulatorFactoryBinder} using a fresh
 * {@link DynamicClassLoader}.
 *
 * @param type the element type; also the single input type of the aggregation
 * @param legacyArrayAgg forwarded to {@code createInputParameterMetadata}
 * @return the aggregation function, whose output type is {@code new ArrayType(type)}
 */
private static InternalAggregationFunction generateAggregation(Type type, boolean legacyArrayAgg) {
    DynamicClassLoader loader = new DynamicClassLoader(ArrayAggregationFunction.class.getClassLoader());
    AccumulatorStateSerializer<?> serializer = new ArrayAggregationStateSerializer(type);
    AccumulatorStateFactory<?> accumulatorStateFactory = new ArrayAggregationStateFactory();

    // A single argument of the element type; the result is an array of that type.
    List<Type> argumentTypes = ImmutableList.of(type);
    Type resultType = new ArrayType(type);
    Type serializedStateType = serializer.getSerializedType();
    List<ParameterMetadata> parameterMetadata = createInputParameterMetadata(type, legacyArrayAgg);

    MethodHandle input = INPUT_FUNCTION.bindTo(type);
    MethodHandle combine = COMBINE_FUNCTION.bindTo(type);
    MethodHandle output = OUTPUT_FUNCTION.bindTo(resultType);
    Class<? extends AccumulatorState> stateClass = ArrayAggregationState.class;

    AggregationMetadata aggregationMetadata = new AggregationMetadata(
            generateAggregationName(
                    NAME,
                    type.getTypeSignature(),
                    argumentTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())),
            parameterMetadata,
            input,
            combine,
            output,
            stateClass,
            serializer,
            accumulatorStateFactory,
            resultType);
    GenericAccumulatorFactoryBinder factoryBinder = AccumulatorCompiler.generateAccumulatorFactoryBinder(aggregationMetadata, loader);
    return new InternalAggregationFunction(NAME, argumentTypes, serializedStateType, resultType, true, factoryBinder);
}
Also used: DynamicClassLoader(com.facebook.presto.bytecode.DynamicClassLoader) ArrayAggregationStateSerializer(com.facebook.presto.operator.aggregation.state.ArrayAggregationStateSerializer) ParameterMetadata(com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata) ArrayType(com.facebook.presto.type.ArrayType) Type(com.facebook.presto.spi.type.Type) ArrayAggregationState(com.facebook.presto.operator.aggregation.state.ArrayAggregationState) ArrayAggregationStateFactory(com.facebook.presto.operator.aggregation.state.ArrayAggregationStateFactory) MethodHandle(java.lang.invoke.MethodHandle)

Example 4 with ParameterMetadata

use of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in project presto by prestodb.

Found in the class AccumulatorCompiler, method generateInvokeInputFunction (variant with multiple state fields and lambda arguments).

/**
 * Emits bytecode that pushes every argument of {@code inputFunction} and then invokes it.
 *
 * <p>The first {@code parameterMetadatas.size()} arguments are pushed according to their parameter
 * type. STATE arguments consume the {@code stateField} entries in order, and channel arguments
 * consume the {@code parameterVariables} entries in order. One trailing argument follows per entry
 * of {@code lambdaInterfaces}: the matching {@code lambdaProviderFields} entry's
 * {@code getLambda()} result, cast to that interface.
 *
 * @param grouped when true, {@code generateSetGroupIdFromGroupIdsBlock} is emitted first
 * @return the generated bytecode
 * @throws IllegalArgumentException if a parameter type is not one of the handled kinds
 */
private static BytecodeBlock generateInvokeInputFunction(Scope scope, List<FieldDefinition> stateField, Variable position, List<Variable> parameterVariables, List<ParameterMetadata> parameterMetadatas, List<Class> lambdaInterfaces, List<FieldDefinition> lambdaProviderFields, MethodHandle inputFunction, CallSiteBinder callSiteBinder, boolean grouped) {
    BytecodeBlock body = new BytecodeBlock();
    if (grouped) {
        generateSetGroupIdFromGroupIdsBlock(scope, stateField, body);
    }
    body.comment("Call input function with unpacked Block arguments");

    Class<?>[] javaTypes = inputFunction.type().parameterArray();
    int channel = 0;
    int nextState = 0;
    // Regular parameters come first, followed by exactly one argument per lambda.
    verify(javaTypes.length == parameterMetadatas.size() + lambdaInterfaces.size());

    for (int argument = 0; argument < parameterMetadatas.size(); argument++) {
        ParameterMetadata metadata = parameterMetadatas.get(argument);
        switch (metadata.getParameterType()) {
            case STATE:
                body.append(scope.getThis().getField(stateField.get(nextState++)));
                break;
            case BLOCK_INDEX:
                body.getVariable(position);
                break;
            case BLOCK_INPUT_CHANNEL:
            case NULLABLE_BLOCK_INPUT_CHANNEL:
                // The function receives the Block itself.
                body.getVariable(parameterVariables.get(channel++));
                break;
            case INPUT_CHANNEL:
                // The function receives an unpacked value; pushStackType emits the conversion.
                BytecodeBlock loadBlock = new BytecodeBlock().getVariable(parameterVariables.get(channel++));
                pushStackType(scope, body, metadata.getSqlType(), loadBlock, javaTypes[argument], callSiteBinder);
                break;
            default:
                throw new IllegalArgumentException("Unsupported parameter type: " + metadata.getParameterType());
        }
    }

    for (int lambda = 0; lambda < lambdaInterfaces.size(); lambda++) {
        body.append(scope.getThis()
                .getField(lambdaProviderFields.get(lambda))
                .invoke("getLambda", Object.class)
                .cast(lambdaInterfaces.get(lambda)));
    }
    body.append(invoke(callSiteBinder.bind(inputFunction), "input"));
    return body;
}
Also used : BytecodeBlock(com.facebook.presto.bytecode.BytecodeBlock) CompilerUtils.defineClass(com.facebook.presto.util.CompilerUtils.defineClass) ParameterMetadata(com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata)

Example 5 with ParameterMetadata

use of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in project presto by prestodb.

Found in the class AccumulatorCompiler, method generateInputForLoop.

/**
 * Emits the loop that feeds every position of the current page into the input function.
 *
 * <p>The per-position call comes from {@code generateInvokeInputFunction}. Its guards, from
 * innermost to outermost, are:
 * <ul>
 *   <li>for each non-nullable channel, the position is skipped when that channel's block is null there;</li>
 *   <li>the position is skipped unless {@code CompilerOperations.testMask(masksBlock, position)} is true.</li>
 * </ul>
 * The whole loop is further wrapped, from innermost to outermost, so that it is skipped when:
 * <ul>
 *   <li>any non-nullable channel's block is a {@code RunLengthEncodedBlock} that is null at position 0;</li>
 *   <li>{@code masksBlock} is a {@code RunLengthEncodedBlock} and {@code testMask(masksBlock, 0)} is false;</li>
 *   <li>the page has no positions ({@code page.getPositionCount()} is not greater than 0).</li>
 * </ul>
 * The {@code page} variable must already be declared in {@code scope}; {@code position} and
 * {@code rows} are declared here.
 *
 * @return the generated bytecode
 * @throws IllegalStateException if the number of channel parameters in {@code parameterMetadatas}
 *         differs from {@code parameterVariables.size()}
 */
private static BytecodeBlock generateInputForLoop(List<FieldDefinition> stateField, List<ParameterMetadata> parameterMetadatas, MethodHandle inputFunction, Scope scope, List<Variable> parameterVariables, List<Class> lambdaInterfaces, List<FieldDefinition> lambdaProviderFields, Variable masksBlock, CallSiteBinder callSiteBinder, boolean grouped) {
    // For-loop over rows
    Variable page = scope.getVariable("page");
    Variable positionVariable = scope.declareVariable(int.class, "position");
    Variable rowsVariable = scope.declareVariable(int.class, "rows");
    // rows = page.getPositionCount(); position is initialized as well
    BytecodeBlock block = new BytecodeBlock().append(page).invokeVirtual(Page.class, "getPositionCount", int.class).putVariable(rowsVariable).initializeVariable(positionVariable);
    BytecodeNode loopBody = generateInvokeInputFunction(scope, stateField, positionVariable, parameterVariables, parameterMetadatas, lambdaInterfaces, lambdaProviderFields, inputFunction, callSiteBinder, grouped);
    // Wrap with null checks
    // nullable.get(k) records whether the k-th channel parameter (in metadata order) accepts nulls,
    // so it lines up index-for-index with parameterVariables.
    List<Boolean> nullable = new ArrayList<>();
    for (ParameterMetadata metadata : parameterMetadatas) {
        switch(metadata.getParameterType()) {
            case INPUT_CHANNEL:
            case BLOCK_INPUT_CHANNEL:
                nullable.add(false);
                break;
            case NULLABLE_BLOCK_INPUT_CHANNEL:
                nullable.add(true);
                break;
            default:
                // do nothing: other parameter types (e.g. STATE, BLOCK_INDEX) have no parameterVariables entry
        }
    }
    checkState(nullable.size() == parameterVariables.size(), "Number of parameters does not match");
    // Each wrap nests the previous body, so later channels end up as the outer checks.
    for (int i = 0; i < parameterVariables.size(); i++) {
        if (!nullable.get(i)) {
            Variable variableDefinition = parameterVariables.get(i);
            loopBody = new IfStatement("if(!%s.isNull(position))", variableDefinition.getName()).condition(new BytecodeBlock().getVariable(variableDefinition).getVariable(positionVariable).invokeInterface(Block.class, "isNull", boolean.class, int.class)).ifFalse(loopBody);
        }
    }
    // The mask test is the outermost per-position guard.
    loopBody = new IfStatement("if(testMask(%s, position))", masksBlock.getName()).condition(new BytecodeBlock().getVariable(masksBlock).getVariable(positionVariable).invokeStatic(CompilerOperations.class, "testMask", boolean.class, Block.class, int.class)).ifTrue(loopBody);
    // for (position = 0; position < rows; position++) { loopBody }
    BytecodeNode forLoop = new ForLoop().initialize(new BytecodeBlock().putVariable(positionVariable, 0)).condition(new BytecodeBlock().getVariable(positionVariable).getVariable(rowsVariable).invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class)).update(new BytecodeBlock().incrementVariable(positionVariable, (byte) 1)).body(loopBody);
    // A run-length-encoded block that is null at position 0 is presumably null at every position,
    // so the whole loop can be skipped for a non-nullable channel.
    for (int i = 0; i < parameterVariables.size(); i++) {
        if (!nullable.get(i)) {
            Variable variableDefinition = parameterVariables.get(i);
            forLoop = new IfStatement("if(!(%s instanceof RunLengthEncodedBlock && %s.isNull(0)))", variableDefinition.getName(), variableDefinition.getName()).condition(and(variableDefinition.instanceOf(RunLengthEncodedBlock.class), variableDefinition.invoke("isNull", boolean.class, constantInt(0)))).ifFalse(forLoop);
        }
    }
    // Skip input blocks that eliminate all input positions
    forLoop = new IfStatement("if(!(%s instanceof RunLengthEncodedBlock && !testMask(%s, 0)))", masksBlock.getName(), masksBlock.getName()).condition(and(masksBlock.instanceOf(RunLengthEncodedBlock.class), not(invokeStatic(type(CompilerOperations.class), "testMask", type(boolean.class), ImmutableList.of(type(Block.class), type(int.class)), masksBlock, constantInt(0))))).ifFalse(forLoop);
    // Only run any of the above when rows > 0.
    block.append(new IfStatement("if(%s > 0)", rowsVariable.getName()).condition(new BytecodeBlock().getVariable(rowsVariable).push(0).invokeStatic(CompilerOperations.class, "greaterThan", boolean.class, int.class, int.class)).ifTrue(forLoop));
    return block;
}
Also used : IfStatement(com.facebook.presto.bytecode.control.IfStatement) Variable(com.facebook.presto.bytecode.Variable) ForLoop(com.facebook.presto.bytecode.control.ForLoop) CompilerOperations(com.facebook.presto.sql.gen.CompilerOperations) BytecodeBlock(com.facebook.presto.bytecode.BytecodeBlock) ArrayList(java.util.ArrayList) RunLengthEncodedBlock(com.facebook.presto.common.block.RunLengthEncodedBlock) GroupByIdBlock(com.facebook.presto.operator.GroupByIdBlock) BytecodeBlock(com.facebook.presto.bytecode.BytecodeBlock) Block(com.facebook.presto.common.block.Block) BytecodeNode(com.facebook.presto.bytecode.BytecodeNode) ParameterMetadata(com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata)

Aggregations

ParameterMetadata (com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata)16 DynamicClassLoader (com.facebook.presto.bytecode.DynamicClassLoader)8 AccumulatorStateDescriptor (com.facebook.presto.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor)7 Type (com.facebook.presto.common.type.Type)6 BytecodeBlock (com.facebook.presto.bytecode.BytecodeBlock)5 ArrayType (com.facebook.presto.common.type.ArrayType)4 MethodHandle (java.lang.invoke.MethodHandle)4 GroupByIdBlock (com.facebook.presto.operator.GroupByIdBlock)3 AggregationMetadata (com.facebook.presto.operator.aggregation.AggregationMetadata)3 GenericAccumulatorFactoryBinder (com.facebook.presto.operator.aggregation.GenericAccumulatorFactoryBinder)3 InternalAggregationFunction (com.facebook.presto.operator.aggregation.InternalAggregationFunction)3 ArrayList (java.util.ArrayList)3 BytecodeNode (com.facebook.presto.bytecode.BytecodeNode)2 Variable (com.facebook.presto.bytecode.Variable)2 ForLoop (com.facebook.presto.bytecode.control.ForLoop)2 IfStatement (com.facebook.presto.bytecode.control.IfStatement)2 BytecodeExpression (com.facebook.presto.bytecode.expression.BytecodeExpression)2 Block (com.facebook.presto.common.block.Block)2 RunLengthEncodedBlock (com.facebook.presto.common.block.RunLengthEncodedBlock)2 ParameterType (com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType)2