use of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in project presto by prestodb.
the class AccumulatorCompiler method generateInvokeInputFunction.
/**
 * Emits the bytecode that loads each argument of the aggregation input function
 * and then invokes the bound {@code inputFunction}.
 *
 * <p>One argument is emitted per entry of the method handle's parameter array; the
 * kind of each argument is taken from the metadata entry at the same index.
 *
 * @param scope scope supplying {@code this} for field access
 * @param stateField field holding the accumulator state
 * @param position variable loaded for {@code BLOCK_INDEX} parameters
 * @param parameterVariables input channel variables, consumed in order by channel parameters
 * @param parameterMetadatas describes the kind of every parameter of {@code inputFunction}
 * @param inputFunction method handle to invoke
 * @param callSiteBinder binder used to bind the handle (and passed to {@code pushStackType})
 * @param grouped when true, the group-id setup block is emitted first
 * @return the block containing the argument loads followed by the invocation
 * @throws IllegalArgumentException if a parameter kind is not handled here
 */
private static BytecodeBlock generateInvokeInputFunction(Scope scope, FieldDefinition stateField, Variable position, List<Variable> parameterVariables, List<ParameterMetadata> parameterMetadatas, MethodHandle inputFunction, CallSiteBinder callSiteBinder, boolean grouped) {
    BytecodeBlock invocation = new BytecodeBlock();
    if (grouped) {
        generateSetGroupIdFromGroupIdsBlock(scope, stateField, invocation);
    }
    invocation.comment("Call input function with unpacked Block arguments");

    Class<?>[] argumentTypes = inputFunction.type().parameterArray();
    // Index of the next unconsumed entry in parameterVariables.
    int nextChannel = 0;
    for (int argument = 0; argument < argumentTypes.length; argument++) {
        ParameterMetadata metadata = parameterMetadatas.get(argument);
        switch (metadata.getParameterType()) {
            case STATE: {
                invocation.append(scope.getThis().getField(stateField));
                break;
            }
            case BLOCK_INDEX: {
                invocation.getVariable(position);
                break;
            }
            case BLOCK_INPUT_CHANNEL:
            case NULLABLE_BLOCK_INPUT_CHANNEL: {
                // The channel variable itself is passed to the input function.
                invocation.getVariable(parameterVariables.get(nextChannel++));
                break;
            }
            case INPUT_CHANNEL: {
                // The channel variable is handed to pushStackType together with the SQL type
                // and the Java parameter type expected by the input function.
                BytecodeBlock loadChannel = new BytecodeBlock().getVariable(parameterVariables.get(nextChannel++));
                pushStackType(scope, invocation, metadata.getSqlType(), loadChannel, argumentTypes[argument], callSiteBinder);
                break;
            }
            default:
                throw new IllegalArgumentException("Unsupported parameter type: " + metadata.getParameterType());
        }
    }

    invocation.append(invoke(callSiteBinder.bind(inputFunction), "input"));
    return invocation;
}
use of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in project presto by prestodb.
the class BindableAggregationFunction method getParameterMetadata.
/**
 * Derives the {@link ParameterMetadata} list for an aggregation method from the
 * annotations on its parameters.
 *
 * <p>A first parameter with no annotations is treated as the state argument (backward
 * compatibility). Every remaining parameter is classified by its base type annotation:
 * {@code @SqlType} becomes an input-channel parameter, {@code @BlockIndex} a block index,
 * and {@code @AggregationState} a state parameter.
 *
 * <p>SQL input types are matched to {@code @SqlType} parameters by counting only the
 * {@code @SqlType} parameters seen so far, so the mapping does not depend on how many
 * state or block-index parameters precede them.
 *
 * @param method the method to inspect; may be {@code null}
 * @param inputTypes SQL types of the aggregation inputs, in {@code @SqlType} parameter order
 * @return the metadata list, or {@code null} when {@code method} is {@code null}
 * @throws IllegalArgumentException if the method has no parameters or a parameter carries
 *         an unsupported base annotation
 */
private static List<ParameterMetadata> getParameterMetadata(@Nullable Method method, List<Type> inputTypes) {
    if (method == null) {
        return null;
    }

    ImmutableList.Builder<ParameterMetadata> builder = ImmutableList.builder();
    Annotation[][] annotations = method.getParameterAnnotations();
    String methodName = method.getDeclaringClass() + "." + method.getName();
    checkArgument(annotations.length > 0, "At least @AggregationState argument is required for each of aggregation functions.");

    // Index into inputTypes; advanced only when a @SqlType parameter is consumed.
    int inputId = 0;
    int i = 0;
    if (annotations[0].length == 0) {
        // Backward compatibility - first argument without annotations is interpreted as State argument
        builder.add(new ParameterMetadata(STATE));
        i++;
    }

    for (; i < annotations.length; i++) {
        Annotation baseTypeAnnotation = baseTypeAnnotation(annotations[i], methodName);
        if (baseTypeAnnotation instanceof SqlType) {
            builder.add(fromSqlType(inputTypes.get(inputId), isParameterBlock(annotations[i]), isParameterNullable(annotations[i]), methodName));
            inputId++;
        } else if (baseTypeAnnotation instanceof BlockIndex) {
            builder.add(new ParameterMetadata(BLOCK_INDEX));
        } else if (baseTypeAnnotation instanceof AggregationState) {
            builder.add(new ParameterMetadata(STATE));
        } else {
            // Report the resolved annotation; concatenating the Annotation[] would only print an array identity.
            throw new IllegalArgumentException("Unsupported annotation: " + baseTypeAnnotation);
        }
    }
    return builder.build();
}
use of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in project presto by prestodb.
the class ArrayAggregationFunction method generateAggregation.
/**
 * Assembles the {@link InternalAggregationFunction} for the given element type.
 *
 * <p>The function takes a single input of {@code type}, produces an {@link ArrayType}
 * of {@code type}, and uses the serializer's serialized type as its intermediate type.
 * The accumulator factory is compiled from an {@link AggregationMetadata} built out of
 * the class-level input/combine/output method handles bound to the relevant types.
 *
 * @param type element type of the aggregated values
 * @param legacyArrayAgg forwarded to {@code createInputParameterMetadata}
 * @return the assembled aggregation function
 */
private static InternalAggregationFunction generateAggregation(Type type, boolean legacyArrayAgg) {
    DynamicClassLoader loader = new DynamicClassLoader(ArrayAggregationFunction.class.getClassLoader());

    AccumulatorStateSerializer<?> serializer = new ArrayAggregationStateSerializer(type);
    AccumulatorStateFactory<?> stateFactory = new ArrayAggregationStateFactory();

    List<Type> argumentTypes = ImmutableList.of(type);
    Type resultType = new ArrayType(type);
    Type serializedType = serializer.getSerializedType();

    List<ParameterMetadata> inputMetadata = createInputParameterMetadata(type, legacyArrayAgg);

    // The input and combine handles are bound to the element type, the output handle to the array type.
    MethodHandle input = INPUT_FUNCTION.bindTo(type);
    MethodHandle combine = COMBINE_FUNCTION.bindTo(type);
    MethodHandle output = OUTPUT_FUNCTION.bindTo(resultType);

    Class<? extends AccumulatorState> stateClass = ArrayAggregationState.class;

    AggregationMetadata aggregationMetadata = new AggregationMetadata(
            generateAggregationName(
                    NAME,
                    type.getTypeSignature(),
                    argumentTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())),
            inputMetadata,
            input,
            combine,
            output,
            stateClass,
            serializer,
            stateFactory,
            resultType);

    GenericAccumulatorFactoryBinder binder = AccumulatorCompiler.generateAccumulatorFactoryBinder(aggregationMetadata, loader);
    // NOTE(review): the meaning of the literal `true` is not visible here - confirm against the constructor.
    return new InternalAggregationFunction(NAME, argumentTypes, serializedType, resultType, true, binder);
}
use of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in project presto by prestodb.
the class AccumulatorCompiler method generateInvokeInputFunction.
/**
 * Emits the bytecode that loads every argument of the aggregation input function -
 * state fields, block index, input channels, then lambdas - and invokes the bound
 * {@code inputFunction}.
 *
 * <p>The method handle must declare exactly one parameter per metadata entry plus one
 * per lambda interface; this is verified before any argument is emitted.
 *
 * @param scope scope supplying {@code this} for field access
 * @param stateField state fields, consumed in order by {@code STATE} parameters
 * @param position variable loaded for {@code BLOCK_INDEX} parameters
 * @param parameterVariables input channel variables, consumed in order by channel parameters
 * @param parameterMetadatas describes the kind of each non-lambda parameter
 * @param lambdaInterfaces interface each lambda argument is cast to
 * @param lambdaProviderFields fields whose {@code getLambda} result supplies each lambda argument
 * @param inputFunction method handle to invoke
 * @param callSiteBinder binder used to bind the handle (and passed to {@code pushStackType})
 * @param grouped when true, the group-id setup block is emitted first
 * @return the block containing the argument loads followed by the invocation
 * @throws IllegalArgumentException if a parameter kind is not handled here
 */
private static BytecodeBlock generateInvokeInputFunction(Scope scope, List<FieldDefinition> stateField, Variable position, List<Variable> parameterVariables, List<ParameterMetadata> parameterMetadatas, List<Class> lambdaInterfaces, List<FieldDefinition> lambdaProviderFields, MethodHandle inputFunction, CallSiteBinder callSiteBinder, boolean grouped) {
    BytecodeBlock invocation = new BytecodeBlock();
    if (grouped) {
        generateSetGroupIdFromGroupIdsBlock(scope, stateField, invocation);
    }
    invocation.comment("Call input function with unpacked Block arguments");

    Class<?>[] argumentTypes = inputFunction.type().parameterArray();
    // Cursors into parameterVariables and stateField respectively.
    int nextChannel = 0;
    int nextState = 0;
    verify(argumentTypes.length == parameterMetadatas.size() + lambdaInterfaces.size());

    for (int argument = 0; argument < parameterMetadatas.size(); argument++) {
        ParameterMetadata metadata = parameterMetadatas.get(argument);
        switch (metadata.getParameterType()) {
            case STATE: {
                invocation.append(scope.getThis().getField(stateField.get(nextState++)));
                break;
            }
            case BLOCK_INDEX: {
                invocation.getVariable(position);
                break;
            }
            case BLOCK_INPUT_CHANNEL:
            case NULLABLE_BLOCK_INPUT_CHANNEL: {
                // The channel variable itself is passed to the input function.
                invocation.getVariable(parameterVariables.get(nextChannel++));
                break;
            }
            case INPUT_CHANNEL: {
                // The channel variable is handed to pushStackType together with the SQL type
                // and the Java parameter type expected by the input function.
                BytecodeBlock loadChannel = new BytecodeBlock().getVariable(parameterVariables.get(nextChannel++));
                pushStackType(scope, invocation, metadata.getSqlType(), loadChannel, argumentTypes[argument], callSiteBinder);
                break;
            }
            default:
                throw new IllegalArgumentException("Unsupported parameter type: " + metadata.getParameterType());
        }
    }

    // Lambda arguments follow the metadata-described parameters, one per interface.
    for (int lambda = 0; lambda < lambdaInterfaces.size(); lambda++) {
        invocation.append(
                scope.getThis()
                        .getField(lambdaProviderFields.get(lambda))
                        .invoke("getLambda", Object.class)
                        .cast(lambdaInterfaces.get(lambda)));
    }

    invocation.append(invoke(callSiteBinder.bind(inputFunction), "input"));
    return invocation;
}
use of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in project presto by prestodb.
the class AccumulatorCompiler method generateInputForLoop.
/**
 * Emits the per-page loop that feeds every position of the page to the aggregation
 * input function.
 *
 * <p>The emitted code reads the page's position count into {@code rows} and, when it is
 * greater than zero, runs a for-loop over {@code position}. Inside the loop the input
 * function call is guarded by the mask test and by an {@code isNull} check for every
 * input channel that is not declared nullable. The whole loop is additionally skipped
 * when the mask block, or a non-nullable input block, is a {@link RunLengthEncodedBlock}
 * whose position 0 already rules the rows out.
 *
 * @param stateField state fields forwarded to {@code generateInvokeInputFunction}
 * @param parameterMetadatas kinds of the input function parameters
 * @param inputFunction method handle of the input function
 * @param scope scope in which {@code page} exists and {@code position}/{@code rows} are declared
 * @param parameterVariables input channel variables
 * @param lambdaInterfaces forwarded to {@code generateInvokeInputFunction}
 * @param lambdaProviderFields forwarded to {@code generateInvokeInputFunction}
 * @param masksBlock variable holding the mask block
 * @param callSiteBinder forwarded to {@code generateInvokeInputFunction}
 * @param grouped forwarded to {@code generateInvokeInputFunction}
 * @return the block containing the row-count setup and the guarded loop
 * @throws IllegalStateException if the number of channel parameters in the metadata
 *         differs from the number of {@code parameterVariables}
 */
private static BytecodeBlock generateInputForLoop(List<FieldDefinition> stateField, List<ParameterMetadata> parameterMetadatas, MethodHandle inputFunction, Scope scope, List<Variable> parameterVariables, List<Class> lambdaInterfaces, List<FieldDefinition> lambdaProviderFields, Variable masksBlock, CallSiteBinder callSiteBinder, boolean grouped) {
    // For-loop over rows
    Variable pageVariable = scope.getVariable("page");
    Variable positionCursor = scope.declareVariable(int.class, "position");
    Variable rowCount = scope.declareVariable(int.class, "rows");

    BytecodeBlock result = new BytecodeBlock()
            .append(pageVariable)
            .invokeVirtual(Page.class, "getPositionCount", int.class)
            .putVariable(rowCount)
            .initializeVariable(positionCursor);

    BytecodeNode perRow = generateInvokeInputFunction(scope, stateField, positionCursor, parameterVariables, parameterMetadatas, lambdaInterfaces, lambdaProviderFields, inputFunction, callSiteBinder, grouped);

    // Wrap with null checks: record, per input channel, whether nulls may reach the input function.
    List<Boolean> acceptsNull = new ArrayList<>();
    for (ParameterMetadata parameter : parameterMetadatas) {
        switch (parameter.getParameterType()) {
            case NULLABLE_BLOCK_INPUT_CHANNEL:
                acceptsNull.add(true);
                break;
            case INPUT_CHANNEL:
            case BLOCK_INPUT_CHANNEL:
                acceptsNull.add(false);
                break;
            default:
                // parameters that are not input channels contribute no entry
        }
    }
    checkState(acceptsNull.size() == parameterVariables.size(), "Number of parameters does not match");

    for (int channel = 0; channel < parameterVariables.size(); channel++) {
        if (acceptsNull.get(channel)) {
            continue;
        }
        Variable input = parameterVariables.get(channel);
        perRow = new IfStatement("if(!%s.isNull(position))", input.getName())
                .condition(new BytecodeBlock()
                        .getVariable(input)
                        .getVariable(positionCursor)
                        .invokeInterface(Block.class, "isNull", boolean.class, int.class))
                .ifFalse(perRow);
    }

    // Only rows that pass the mask reach the (null-guarded) input call.
    perRow = new IfStatement("if(testMask(%s, position))", masksBlock.getName())
            .condition(new BytecodeBlock()
                    .getVariable(masksBlock)
                    .getVariable(positionCursor)
                    .invokeStatic(CompilerOperations.class, "testMask", boolean.class, Block.class, int.class))
            .ifTrue(perRow);

    BytecodeNode guardedLoop = new ForLoop()
            .initialize(new BytecodeBlock().putVariable(positionCursor, 0))
            .condition(new BytecodeBlock()
                    .getVariable(positionCursor)
                    .getVariable(rowCount)
                    .invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class))
            .update(new BytecodeBlock().incrementVariable(positionCursor, (byte) 1))
            .body(perRow);

    // Skip the loop entirely when a non-nullable input is a run-length-encoded block with a null at position 0.
    for (int channel = 0; channel < parameterVariables.size(); channel++) {
        if (acceptsNull.get(channel)) {
            continue;
        }
        Variable input = parameterVariables.get(channel);
        guardedLoop = new IfStatement("if(!(%s instanceof RunLengthEncodedBlock && %s.isNull(0)))", input.getName(), input.getName())
                .condition(and(
                        input.instanceOf(RunLengthEncodedBlock.class),
                        input.invoke("isNull", boolean.class, constantInt(0))))
                .ifFalse(guardedLoop);
    }

    // Skip input blocks that eliminate all input positions
    guardedLoop = new IfStatement("if(!(%s instanceof RunLengthEncodedBlock && !testMask(%s, 0)))", masksBlock.getName(), masksBlock.getName())
            .condition(and(
                    masksBlock.instanceOf(RunLengthEncodedBlock.class),
                    not(invokeStatic(type(CompilerOperations.class), "testMask", type(boolean.class), ImmutableList.of(type(Block.class), type(int.class)), masksBlock, constantInt(0)))))
            .ifFalse(guardedLoop);

    result.append(new IfStatement("if(%s > 0)", rowCount.getName())
            .condition(new BytecodeBlock()
                    .getVariable(rowCount)
                    .push(0)
                    .invokeStatic(CompilerOperations.class, "greaterThan", boolean.class, int.class, int.class))
            .ifTrue(guardedLoop));
    return result;
}
Aggregations