Search in sources:

Example 1 with ParameterMetadata

use of io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata in project hetu-core by openlookeng.

The class ParametricAggregation, method buildParameterMetadata.

/**
 * Converts a list of parameter kinds into {@link ParameterMetadata} entries.
 *
 * <p>Channel-style kinds (INPUT_CHANNEL, BLOCK_INPUT_CHANNEL, NULLABLE_BLOCK_INPUT_CHANNEL) each take
 * the next type from {@code inputTypes}, in order. STATE and BLOCK_INDEX produce an entry without a
 * type. Any other kind produces no entry.
 *
 * @param parameterMetadataTypes kinds of the input function's parameters, in declaration order
 * @param inputTypes types consumed sequentially by the channel-style parameters
 * @return an immutable list of parameter metadata
 */
private static List<ParameterMetadata> buildParameterMetadata(List<ParameterType> parameterMetadataTypes, List<Type> inputTypes) {
    ImmutableList.Builder<ParameterMetadata> result = ImmutableList.builder();
    // Index of the next unused entry in inputTypes; only channel-style parameters advance it
    int nextInput = 0;
    for (ParameterType kind : parameterMetadataTypes) {
        switch (kind) {
            case INPUT_CHANNEL:
            case BLOCK_INPUT_CHANNEL:
            case NULLABLE_BLOCK_INPUT_CHANNEL:
                // List.get throws if there are more channel parameters than input types
                Type channelType = inputTypes.get(nextInput);
                nextInput++;
                result.add(new ParameterMetadata(kind, channelType));
                break;
            case STATE:
            case BLOCK_INDEX:
                result.add(new ParameterMetadata(kind));
                break;
            default:
                // Kinds not listed above contribute nothing
                break;
        }
    }
    return result.build();
}
Also used : ParameterType(io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ParameterMetadata(io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata)

Example 2 with ParameterMetadata

use of io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata in project hetu-core by openlookeng.

The class ParametricAggregation, method specialize.

/**
 * Specializes this parametric aggregation for the given bound type variables and arity.
 *
 * <p>The specialized signature selects a matching implementation. This method then builds the state
 * serializer/factory, binds the implementation's dependencies into its input/combine/output handles,
 * and wraps everything in an {@link InternalAggregationFunction}.
 */
@Override
public InternalAggregationFunction specialize(BoundVariables variables, int arity, FunctionAndTypeManager functionAndTypeManager) {
    // Bind the generic signature to concrete types, then pick the implementation that matches it
    Signature specializedSignature = applyBoundVariables(getSignature(), variables, arity);
    AggregationImplementation implementation = findMatchingImplementation(specializedSignature, variables, functionAndTypeManager);

    // Resolve the concrete argument and return types of the specialized signature
    List<Type> argumentTypes = specializedSignature.getArgumentTypes()
            .stream()
            .map(functionAndTypeManager::getType)
            .collect(toImmutableList());
    Type returnType = functionAndTypeManager.getType(specializedSignature.getReturnType());

    // Class loader whose parents are the implementation's definition class loader and this class's loader
    Class<?> definition = implementation.getDefinitionClass();
    DynamicClassLoader loader = new DynamicClassLoader(definition.getClassLoader(), getClass().getClassLoader());

    // State serializer and generated state factory for the implementation's state class
    Class<?> stateType = implementation.getStateClass();
    AccumulatorStateSerializer<?> serializer = getAccumulatorStateSerializer(implementation, variables, functionAndTypeManager, stateType, loader);
    AccumulatorStateFactory<?> factory = StateCompiler.generateStateFactory(stateType, loader);

    // Input, combine and output handles with their declared dependencies bound in
    MethodHandle input = bindDependencies(implementation.getInputFunction(), implementation.getInputDependencies(), variables, functionAndTypeManager);
    MethodHandle combine = bindDependencies(implementation.getCombineFunction(), implementation.getCombineDependencies(), variables, functionAndTypeManager);
    MethodHandle output = bindDependencies(implementation.getOutputFunction(), implementation.getOutputDependencies(), variables, functionAndTypeManager);

    // Per-parameter metadata for the input function
    List<ParameterMetadata> parameters = buildParameterMetadata(implementation.getInputParameterMetadataTypes(), argumentTypes);

    // Assemble the aggregation metadata under a generated name
    String aggregationName = generateAggregationName(getSignature().getNameSuffix(), returnType.getTypeSignature(), signaturesFromTypes(argumentTypes));
    AccumulatorStateDescriptor stateDescriptor = new AccumulatorStateDescriptor(stateType, serializer, factory);
    AggregationMetadata aggregation = new AggregationMetadata(aggregationName, parameters, input, combine, output, ImmutableList.of(stateDescriptor), returnType);

    // The specialized InternalAggregationFunction; accumulator generation is deferred to LazyAccumulatorFactoryBinder
    return new InternalAggregationFunction(
            getSignature().getNameSuffix(),
            argumentTypes,
            ImmutableList.of(serializer.getSerializedType()),
            returnType,
            details.isDecomposable(),
            details.isOrderSensitive(),
            new LazyAccumulatorFactoryBinder(aggregation, loader));
}
Also used : DynamicClassLoader(io.airlift.bytecode.DynamicClassLoader) AccumulatorStateDescriptor(io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor) ParameterMetadata(io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata) Type(io.prestosql.spi.type.Type) ParameterType(io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType) Signature(io.prestosql.spi.function.Signature) TypeSignature(io.prestosql.spi.type.TypeSignature) MethodHandle(java.lang.invoke.MethodHandle)

Example 3 with ParameterMetadata

use of io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata in project hetu-core by openlookeng.

The class ArrayAggregationFunction, method generateAggregation.

/**
 * Builds the array aggregation function for element type {@code type}.
 *
 * <p>The function takes a single argument of {@code type} and returns {@code array(type)}. Its
 * accumulator is compiled eagerly through {@link AccumulatorCompiler}.
 *
 * @param type element type being aggregated
 * @param groupMode group implementation passed to the state factory
 */
private static InternalAggregationFunction generateAggregation(Type type, ArrayAggGroupImplementation groupMode) {
    DynamicClassLoader loader = new DynamicClassLoader(ArrayAggregationFunction.class.getClassLoader());
    AccumulatorStateSerializer<?> serializer = new ArrayAggregationStateSerializer(type);
    AccumulatorStateFactory<?> factory = new ArrayAggregationStateFactory(type, groupMode);

    List<Type> argumentTypes = ImmutableList.of(type);
    Type returnType = new ArrayType(type);
    Type serializedType = serializer.getSerializedType();
    List<ParameterMetadata> parameters = createInputParameterMetadata(type);

    // Bind the element type as the leading argument of each generic handle
    MethodHandle input = INPUT_FUNCTION.bindTo(type);
    MethodHandle combine = COMBINE_FUNCTION.bindTo(type);
    MethodHandle output = OUTPUT_FUNCTION.bindTo(type);
    Class<? extends AccumulatorState> stateClass = ArrayAggregationState.class;

    String aggregationName = generateAggregationName(
            NAME,
            type.getTypeSignature(),
            argumentTypes.stream().map(Type::getTypeSignature).collect(toImmutableList()));
    AccumulatorStateDescriptor descriptor = new AccumulatorStateDescriptor(stateClass, serializer, factory);
    AggregationMetadata metadata = new AggregationMetadata(aggregationName, parameters, input, combine, output, ImmutableList.of(descriptor), returnType);

    GenericAccumulatorFactoryBinder binder = AccumulatorCompiler.generateAccumulatorFactoryBinder(metadata, loader);
    return new InternalAggregationFunction(
            NAME,
            argumentTypes,
            ImmutableList.of(serializedType),
            returnType,
            /* decomposable */ true,
            /* orderSensitive */ true,
            binder);
}
Also used : DynamicClassLoader(io.airlift.bytecode.DynamicClassLoader) AccumulatorStateDescriptor(io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor) GenericAccumulatorFactoryBinder(io.prestosql.operator.aggregation.GenericAccumulatorFactoryBinder) InternalAggregationFunction(io.prestosql.operator.aggregation.InternalAggregationFunction) ParameterMetadata(io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata) ArrayType(io.prestosql.spi.type.ArrayType) Type(io.prestosql.spi.type.Type) AggregationMetadata(io.prestosql.operator.aggregation.AggregationMetadata) MethodHandle(java.lang.invoke.MethodHandle)

Example 4 with ParameterMetadata

use of io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata in project hetu-core by openlookeng.

The class AccumulatorCompiler, method anyParametersAreNull.

/**
 * Builds a bytecode expression that is true when any non-nullable input channel is null at {@code position}.
 *
 * <p>INPUT_CHANNEL and BLOCK_INPUT_CHANNEL parameters contribute an {@code isNull} check on
 * {@code index}. NULLABLE_BLOCK_INPUT_CHANNEL parameters occupy a channel slot but are not checked.
 * Other parameter kinds are ignored.
 */
private static BytecodeExpression anyParametersAreNull(List<ParameterMetadata> parameterMetadatas, Variable index, Variable channels, Variable position) {
    // Running OR of per-channel null checks, seeded with constant false
    BytecodeExpression anyNull = constantFalse();
    // Position in `channels` of the next input channel
    int channel = 0;
    for (ParameterMetadata metadata : parameterMetadatas) {
        switch (metadata.getParameterType()) {
            case BLOCK_INPUT_CHANNEL:
            case INPUT_CHANNEL:
                BytecodeExpression channelId = channels.invoke("get", Object.class, constantInt(channel)).cast(int.class);
                anyNull = BytecodeExpressions.or(anyNull, index.invoke("isNull", boolean.class, channelId, position));
                // fall through: these parameters also consume a channel slot
            case NULLABLE_BLOCK_INPUT_CHANNEL:
                channel++;
                break;
            default:
                // Non-channel parameters neither get a null check nor consume a channel
                break;
        }
    }
    return anyNull;
}
Also used : BytecodeExpression(io.airlift.bytecode.expression.BytecodeExpression) ParameterMetadata(io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata)

Example 5 with ParameterMetadata

use of io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata in project hetu-core by openlookeng.

The class AccumulatorCompiler, method generateInputForLoop.

/**
 * Generates the per-row loop that feeds each position of {@code page} to the input function.
 *
 * <p>The call to the input function is wrapped in nested guards. It runs only when the mask test
 * passes and no non-nullable input block is null at the current position. Parameters of kind
 * NULLABLE_BLOCK_INPUT_CHANNEL are passed through without a null guard.
 *
 * @throws IllegalStateException (via checkState) if the number of channel parameters differs from
 *         {@code parameterVariables.size()}
 */
private static BytecodeBlock generateInputForLoop(List<FieldDefinition> stateField, List<ParameterMetadata> parameterMetadatas, MethodHandle inputFunction, Scope scope, List<Variable> parameterVariables, List<Class<?>> lambdaInterfaces, List<FieldDefinition> lambdaProviderFields, Variable masksBlock, CallSiteBinder callSiteBinder, boolean grouped) {
    Variable page = scope.getVariable("page");
    Variable positionVariable = scope.declareVariable(int.class, "position");
    Variable rowsVariable = scope.declareVariable(int.class, "rows");

    // rows = page.getPositionCount(); then initialize position
    BytecodeBlock block = new BytecodeBlock()
            .append(page)
            .invokeVirtual(Page.class, "getPositionCount", int.class)
            .putVariable(rowsVariable)
            .initializeVariable(positionVariable);

    // Innermost body: invoke the input function for the current position
    BytecodeNode body = generateInvokeInputFunction(scope, stateField, positionVariable, parameterVariables, parameterMetadatas, lambdaInterfaces, lambdaProviderFields, inputFunction, callSiteBinder, grouped);

    // One flag per channel parameter: true when nulls are passed through to the input function
    List<Boolean> acceptsNull = new ArrayList<>();
    for (ParameterMetadata parameter : parameterMetadatas) {
        switch (parameter.getParameterType()) {
            case NULLABLE_BLOCK_INPUT_CHANNEL:
                acceptsNull.add(true);
                break;
            case INPUT_CHANNEL:
            case BLOCK_INPUT_CHANNEL:
                acceptsNull.add(false);
                break;
            default:
                // Non-channel parameters have no entry in parameterVariables
                break;
        }
    }
    checkState(acceptsNull.size() == parameterVariables.size(), "Number of parameters does not match");

    // Guard the body with an isNull check for every parameter that does not accept nulls
    for (int i = 0; i < parameterVariables.size(); i++) {
        if (acceptsNull.get(i)) {
            continue;
        }
        Variable parameterBlock = parameterVariables.get(i);
        BytecodeBlock isNullAtPosition = new BytecodeBlock()
                .getVariable(parameterBlock)
                .getVariable(positionVariable)
                .invokeInterface(Block.class, "isNull", boolean.class, int.class);
        body = new IfStatement("if(!%s.isNull(position))", parameterBlock.getName())
                .condition(isNullAtPosition)
                .ifFalse(body);
    }

    // Outermost guard: run only when CompilerOperations.testMask(masksBlock, position) is true
    BytecodeBlock maskAllowsRow = new BytecodeBlock()
            .getVariable(masksBlock)
            .getVariable(positionVariable)
            .invokeStatic(CompilerOperations.class, "testMask", boolean.class, Block.class, int.class);
    body = new IfStatement("if(testMask(%s, position))", masksBlock.getName())
            .condition(maskAllowsRow)
            .ifTrue(body);

    // for (position = 0; position < rows; position++) { body }
    BytecodeBlock loopCondition = new BytecodeBlock()
            .getVariable(positionVariable)
            .getVariable(rowsVariable)
            .invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class);
    ForLoop rowLoop = new ForLoop()
            .initialize(new BytecodeBlock().putVariable(positionVariable, 0))
            .condition(loopCondition)
            .update(new BytecodeBlock().incrementVariable(positionVariable, (byte) 1))
            .body(body);
    block.append(rowLoop);
    return block;
}
Also used : IfStatement(io.airlift.bytecode.control.IfStatement) Variable(io.airlift.bytecode.Variable) ForLoop(io.airlift.bytecode.control.ForLoop) CompilerOperations(io.prestosql.sql.gen.CompilerOperations) BytecodeBlock(io.airlift.bytecode.BytecodeBlock) ArrayList(java.util.ArrayList) GroupByIdBlock(io.prestosql.operator.GroupByIdBlock) Block(io.prestosql.spi.block.Block) BytecodeNode(io.airlift.bytecode.BytecodeNode) ParameterMetadata(io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata)

Aggregations

ParameterMetadata (io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata)10 DynamicClassLoader (io.airlift.bytecode.DynamicClassLoader)5 AccumulatorStateDescriptor (io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor)5 Type (io.prestosql.spi.type.Type)4 BytecodeBlock (io.airlift.bytecode.BytecodeBlock)3 MethodHandle (java.lang.invoke.MethodHandle)3 BytecodeExpression (io.airlift.bytecode.expression.BytecodeExpression)2 GroupByIdBlock (io.prestosql.operator.GroupByIdBlock)2 ParameterType (io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType)2 Block (io.prestosql.spi.block.Block)2 ArrayType (io.prestosql.spi.type.ArrayType)2 ArrayList (java.util.ArrayList)2 ImmutableList (com.google.common.collect.ImmutableList)1 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)1 BytecodeNode (io.airlift.bytecode.BytecodeNode)1 Variable (io.airlift.bytecode.Variable)1 ForLoop (io.airlift.bytecode.control.ForLoop)1 IfStatement (io.airlift.bytecode.control.IfStatement)1 Slice (io.airlift.slice.Slice)1 AggregationMetadata (io.prestosql.operator.aggregation.AggregationMetadata)1