Use of io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata in the hetu-core project by openLooKeng.
Class ParametricAggregation, method buildParameterMetadata.
/**
 * Builds one {@link ParameterMetadata} for every entry of {@code parameterMetadataTypes}, preserving order.
 *
 * <p>{@code STATE} and {@code BLOCK_INDEX} entries carry no type. Each input-channel entry
 * ({@code INPUT_CHANNEL}, {@code BLOCK_INPUT_CHANNEL}, {@code NULLABLE_BLOCK_INPUT_CHANNEL}) is paired with
 * the next not-yet-used element of {@code inputTypes}.
 *
 * @param parameterMetadataTypes parameter kinds of the input function, in declaration order
 * @param inputTypes concrete argument types, consumed in order by the input-channel entries
 * @return an immutable list with one metadata entry per parameter kind
 * @throws IndexOutOfBoundsException if there are more input-channel entries than {@code inputTypes}
 */
private static List<ParameterMetadata> buildParameterMetadata(List<ParameterType> parameterMetadataTypes, List<Type> inputTypes) {
    ImmutableList.Builder<ParameterMetadata> result = ImmutableList.builder();
    int consumedInputs = 0;
    for (ParameterType kind : parameterMetadataTypes) {
        switch (kind) {
            case INPUT_CHANNEL:
            case BLOCK_INPUT_CHANNEL:
            case NULLABLE_BLOCK_INPUT_CHANNEL: {
                // Input-channel parameters consume argument types positionally.
                Type argumentType = inputTypes.get(consumedInputs);
                consumedInputs++;
                result.add(new ParameterMetadata(kind, argumentType));
                break;
            }
            case STATE:
            case BLOCK_INDEX:
                result.add(new ParameterMetadata(kind));
                break;
            default:
                // Any other kind contributes no metadata entry.
                break;
        }
    }
    return result.build();
}
Use of io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata in the hetu-core project by openLooKeng.
Class ParametricAggregation, method specialize.
/**
 * Specializes this parametric aggregation for the given bound variables and arity.
 *
 * <p>Binds the signature, picks the matching {@link AggregationImplementation}, resolves concrete input and
 * output types, builds the state serializer/factory, binds dependencies of the input, combine and output
 * method handles, and wraps everything into an {@link InternalAggregationFunction} whose accumulator factory
 * is created lazily.
 *
 * @param variables bound type variables
 * @param arity number of arguments of the call being specialized
 * @param functionAndTypeManager used to resolve types and bind dependencies
 * @return the specialized aggregation function
 */
@Override
public InternalAggregationFunction specialize(BoundVariables variables, int arity, FunctionAndTypeManager functionAndTypeManager) {
    Signature boundSignature = applyBoundVariables(getSignature(), variables, arity);
    AggregationImplementation implementation = findMatchingImplementation(boundSignature, variables, functionAndTypeManager);

    // Concrete argument and return types for this specialization.
    List<Type> inputTypes = boundSignature.getArgumentTypes()
            .stream()
            .map(functionAndTypeManager::getType)
            .collect(toImmutableList());
    Type outputType = functionAndTypeManager.getType(boundSignature.getReturnType());

    // Generated classes need to see both the implementation's definition class and this class.
    Class<?> definitionClass = implementation.getDefinitionClass();
    DynamicClassLoader classLoader = new DynamicClassLoader(definitionClass.getClassLoader(), getClass().getClassLoader());

    // State handling.
    Class<?> stateClass = implementation.getStateClass();
    AccumulatorStateSerializer<?> serializer = getAccumulatorStateSerializer(implementation, variables, functionAndTypeManager, stateClass, classLoader);
    AccumulatorStateFactory<?> factory = StateCompiler.generateStateFactory(stateClass, classLoader);

    // Method handles with their declared dependencies bound.
    MethodHandle input = bindDependencies(implementation.getInputFunction(), implementation.getInputDependencies(), variables, functionAndTypeManager);
    MethodHandle combine = bindDependencies(implementation.getCombineFunction(), implementation.getCombineDependencies(), variables, functionAndTypeManager);
    MethodHandle output = bindDependencies(implementation.getOutputFunction(), implementation.getOutputDependencies(), variables, functionAndTypeManager);

    List<ParameterMetadata> parameters = buildParameterMetadata(implementation.getInputParameterMetadataTypes(), inputTypes);

    String functionName = getSignature().getNameSuffix();
    String aggregationName = generateAggregationName(functionName, outputType.getTypeSignature(), signaturesFromTypes(inputTypes));
    AccumulatorStateDescriptor stateDescriptor = new AccumulatorStateDescriptor(stateClass, serializer, factory);
    AggregationMetadata metadata = new AggregationMetadata(aggregationName, parameters, input, combine, output, ImmutableList.of(stateDescriptor), outputType);

    return new InternalAggregationFunction(
            functionName,
            inputTypes,
            ImmutableList.of(serializer.getSerializedType()),
            outputType,
            details.isDecomposable(),
            details.isOrderSensitive(),
            new LazyAccumulatorFactoryBinder(metadata, classLoader));
}
Use of io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata in the hetu-core project by openLooKeng.
Class ArrayAggregationFunction, method generateAggregation.
/**
 * Builds the {@code array_agg} aggregation for a single element type.
 *
 * @param type element type; the function takes one argument of this type and returns {@code array(type)}
 * @param groupMode group-state implementation handed to the state factory
 * @return an aggregation function marked decomposable and order sensitive, backed by a compiled accumulator
 */
private static InternalAggregationFunction generateAggregation(Type type, ArrayAggGroupImplementation groupMode) {
    DynamicClassLoader loader = new DynamicClassLoader(ArrayAggregationFunction.class.getClassLoader());

    // State plumbing for ArrayAggregationState.
    Class<? extends AccumulatorState> stateInterface = ArrayAggregationState.class;
    AccumulatorStateSerializer<?> serializer = new ArrayAggregationStateSerializer(type);
    AccumulatorStateFactory<?> stateFactory = new ArrayAggregationStateFactory(type, groupMode);

    List<Type> argumentTypes = ImmutableList.of(type);
    Type resultType = new ArrayType(type);
    Type serializedType = serializer.getSerializedType();
    List<ParameterMetadata> parameters = createInputParameterMetadata(type);

    // Every handle is specialized on the element type.
    MethodHandle input = INPUT_FUNCTION.bindTo(type);
    MethodHandle combine = COMBINE_FUNCTION.bindTo(type);
    MethodHandle output = OUTPUT_FUNCTION.bindTo(type);

    String aggregationName = generateAggregationName(
            NAME,
            type.getTypeSignature(),
            argumentTypes.stream().map(Type::getTypeSignature).collect(toImmutableList()));
    AggregationMetadata metadata = new AggregationMetadata(
            aggregationName,
            parameters,
            input,
            combine,
            output,
            ImmutableList.of(new AccumulatorStateDescriptor(stateInterface, serializer, stateFactory)),
            resultType);

    GenericAccumulatorFactoryBinder binder = AccumulatorCompiler.generateAccumulatorFactoryBinder(metadata, loader);
    return new InternalAggregationFunction(NAME, argumentTypes, ImmutableList.of(serializedType), resultType, true, true, binder);
}
Use of io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata in the hetu-core project by openLooKeng.
Class AccumulatorCompiler, method anyParametersAreNull.
/**
 * Builds a boolean bytecode expression that is true when any non-nullable input channel is null at
 * {@code position}.
 *
 * <p>Input-channel parameters ({@code INPUT_CHANNEL}, {@code BLOCK_INPUT_CHANNEL},
 * {@code NULLABLE_BLOCK_INPUT_CHANNEL}) each occupy the next channel slot. Only the non-nullable kinds add an
 * {@code index.isNull(channels.get(slot), position)} term. Other parameter kinds are ignored. With no such
 * terms the result is constant false.
 *
 * @param parameterMetadatas parameter metadata of the input function, in order
 * @param index variable on which {@code isNull(int, int)} is invoked
 * @param channels variable whose {@code get(int)} returns the channel number for a slot
 * @param position variable holding the row position to test
 * @return the OR of all null checks, starting from {@code constantFalse()}
 */
private static BytecodeExpression anyParametersAreNull(List<ParameterMetadata> parameterMetadatas, Variable index, Variable channels, Variable position) {
    BytecodeExpression anyNull = constantFalse();
    int channelSlot = 0;
    for (ParameterMetadata metadata : parameterMetadatas) {
        switch (metadata.getParameterType()) {
            case BLOCK_INPUT_CHANNEL:
            case INPUT_CHANNEL:
                anyNull = BytecodeExpressions.or(
                        anyNull,
                        index.invoke("isNull", boolean.class, channels.invoke("get", Object.class, constantInt(channelSlot)).cast(int.class), position));
                // fall through: nullable and non-nullable inputs both occupy a channel slot
            case NULLABLE_BLOCK_INPUT_CHANNEL:
                channelSlot++;
                break;
            default:
                // State / block-index parameters have no channel.
                break;
        }
    }
    return anyNull;
}
Use of io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata in the hetu-core project by openLooKeng.
Class AccumulatorCompiler, method generateInputForLoop.
/**
 * Generates the bytecode for the per-row input loop of an accumulator.
 *
 * <p>The generated code reads the row count from the {@code page} variable in {@code scope}. It then iterates
 * {@code position} over {@code [0, rows)} and runs the input-function invocation produced by
 * {@code generateInvokeInputFunction} only when {@code CompilerOperations.testMask(masksBlock, position)} is
 * true and, for every parameter variable whose metadata is {@code INPUT_CHANNEL} or
 * {@code BLOCK_INPUT_CHANNEL}, the block is not null at that position. {@code NULLABLE_BLOCK_INPUT_CHANNEL}
 * parameters get no null check.
 *
 * @param stateField state fields, forwarded to {@code generateInvokeInputFunction}
 * @param parameterMetadatas metadata of the input function's parameters, in order
 * @param inputFunction the aggregation input function invoked per row
 * @param scope method scope; must already define a variable named {@code "page"}. This method declares
 *        {@code "position"} and {@code "rows"} in it
 * @param parameterVariables one block variable per input-channel parameter, in metadata order
 * @param lambdaInterfaces forwarded to {@code generateInvokeInputFunction}
 * @param lambdaProviderFields forwarded to {@code generateInvokeInputFunction}
 * @param masksBlock variable passed as the {@code Block} argument of {@code CompilerOperations.testMask}
 * @param callSiteBinder forwarded to {@code generateInvokeInputFunction}
 * @param grouped forwarded to {@code generateInvokeInputFunction}
 * @return a block that stores the page's position count into {@code rows} and then runs the generated loop
 */
private static BytecodeBlock generateInputForLoop(List<FieldDefinition> stateField, List<ParameterMetadata> parameterMetadatas, MethodHandle inputFunction, Scope scope, List<Variable> parameterVariables, List<Class<?>> lambdaInterfaces, List<FieldDefinition> lambdaProviderFields, Variable masksBlock, CallSiteBinder callSiteBinder, boolean grouped) {
// For-loop over rows
Variable page = scope.getVariable("page");
Variable positionVariable = scope.declareVariable(int.class, "position");
Variable rowsVariable = scope.declareVariable(int.class, "rows");
// rows = page.getPositionCount(); position is also initialized here and then reset to 0 by the ForLoop initializer below.
BytecodeBlock block = new BytecodeBlock().append(page).invokeVirtual(Page.class, "getPositionCount", int.class).putVariable(rowsVariable).initializeVariable(positionVariable);
// Innermost body: a single invocation of the input function for the current position.
BytecodeNode loopBody = generateInvokeInputFunction(scope, stateField, positionVariable, parameterVariables, parameterMetadatas, lambdaInterfaces, lambdaProviderFields, inputFunction, callSiteBinder, grouped);
// Wrap with null checks
// nullable.get(i) records whether the i-th input-channel parameter may receive null values.
List<Boolean> nullable = new ArrayList<>();
for (ParameterMetadata metadata : parameterMetadatas) {
switch(metadata.getParameterType()) {
case INPUT_CHANNEL:
case BLOCK_INPUT_CHANNEL:
nullable.add(false);
break;
case NULLABLE_BLOCK_INPUT_CHANNEL:
nullable.add(true);
break;
// Parameter kinds other than input channels have no corresponding parameter variable.
default:
}
}
// Fails (via checkState) when the input-channel count in the metadata differs from parameterVariables.size().
checkState(nullable.size() == parameterVariables.size(), "Number of parameters does not match");
// Each non-nullable parameter wraps the current body, so parameter 0's check ends up innermost.
for (int i = 0; i < parameterVariables.size(); i++) {
if (!nullable.get(i)) {
Variable variableDefinition = parameterVariables.get(i);
loopBody = new IfStatement("if(!%s.isNull(position))", variableDefinition.getName()).condition(new BytecodeBlock().getVariable(variableDefinition).getVariable(positionVariable).invokeInterface(Block.class, "isNull", boolean.class, int.class)).ifFalse(loopBody);
}
}
// Outermost guard: the mask test runs before any null check for the position.
loopBody = new IfStatement("if(testMask(%s, position))", masksBlock.getName()).condition(new BytecodeBlock().getVariable(masksBlock).getVariable(positionVariable).invokeStatic(CompilerOperations.class, "testMask", boolean.class, Block.class, int.class)).ifTrue(loopBody);
// for (position = 0; lessThan(position, rows); position++) { loopBody }
block.append(new ForLoop().initialize(new BytecodeBlock().putVariable(positionVariable, 0)).condition(new BytecodeBlock().getVariable(positionVariable).getVariable(rowsVariable).invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class)).update(new BytecodeBlock().incrementVariable(positionVariable, (byte) 1)).body(loopBody));
return block;
}
Aggregations