Usage of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in the prestodb/presto project.
Class: AccumulatorCompiler; method: generateInputForLoop.
/**
 * Generates the row loop of an accumulator input method.
 *
 * <p>The emitted bytecode stores {@code page.getPositionCount()} into the {@code rows} variable. It then walks
 * {@code position} from 0 while it is less than {@code rows}, running the code produced by
 * {@code generateInvokeInputFunction} for the current position. That code runs only when both of these hold:
 * <ul>
 * <li>{@code CompilerOperations.testMask(masksBlock, position)} returns true.</li>
 * <li>Every input block whose parameter kind is {@code INPUT_CHANNEL} or {@code BLOCK_INPUT_CHANNEL} is not null
 * at that position.</li>
 * </ul>
 * Blocks of kind {@code NULLABLE_BLOCK_INPUT_CHANNEL} are not null-checked. Other parameter kinds are not input
 * channels and are ignored here.
 *
 * @param parameterVariables one block variable per input-channel parameter, in the order those parameters appear
 *     in {@code parameterMetadatas}
 * @return the prologue followed by the generated for-loop
 * @throws IllegalStateException if the number of input-channel parameters differs from
 *     {@code parameterVariables.size()}
 */
private static BytecodeBlock generateInputForLoop(FieldDefinition stateField, List<ParameterMetadata> parameterMetadatas, MethodHandle inputFunction, Scope scope, List<Variable> parameterVariables, Variable masksBlock, CallSiteBinder callSiteBinder, boolean grouped) {
    Variable page = scope.getVariable("page");
    Variable position = scope.declareVariable(int.class, "position");
    Variable rowCount = scope.declareVariable(int.class, "rows");

    // Prologue: rows = page.getPositionCount(); position starts zero-initialized.
    BytecodeBlock block = new BytecodeBlock()
            .append(page)
            .invokeVirtual(Page.class, "getPositionCount", int.class)
            .putVariable(rowCount)
            .initializeVariable(position);

    // Innermost node: the call to the input function for the current position.
    BytecodeNode loopBody = generateInvokeInputFunction(scope, stateField, position, parameterVariables, parameterMetadatas, inputFunction, callSiteBinder, grouped);

    // One flag per input-channel parameter, aligned index-by-index with parameterVariables.
    List<Boolean> nullableFlags = new ArrayList<>();
    for (ParameterMetadata parameter : parameterMetadatas) {
        switch (parameter.getParameterType()) {
            case NULLABLE_BLOCK_INPUT_CHANNEL:
                nullableFlags.add(true);
                break;
            case INPUT_CHANNEL:
            case BLOCK_INPUT_CHANNEL:
                nullableFlags.add(false);
                break;
            default:
                // Not an input channel, so there is no matching entry in parameterVariables.
                break;
        }
    }
    checkState(nullableFlags.size() == parameterVariables.size(), "Number of parameters does not match");

    // Skip the row when any non-nullable input block holds a null at this position.
    for (int channel = 0; channel < parameterVariables.size(); channel++) {
        if (nullableFlags.get(channel)) {
            continue;
        }
        Variable inputBlock = parameterVariables.get(channel);
        BytecodeBlock isNull = new BytecodeBlock()
                .getVariable(inputBlock)
                .getVariable(position)
                .invokeInterface(Block.class, "isNull", boolean.class, int.class);
        loopBody = new IfStatement("if(!%s.isNull(position))", inputBlock.getName())
                .condition(isNull)
                .ifFalse(loopBody);
    }

    // Outermost guard: only rows accepted by testMask on the mask block are processed.
    BytecodeBlock maskTest = new BytecodeBlock()
            .getVariable(masksBlock)
            .getVariable(position)
            .invokeStatic(CompilerOperations.class, "testMask", boolean.class, Block.class, int.class);
    loopBody = new IfStatement("if(testMask(%s, position))", masksBlock.getName())
            .condition(maskTest)
            .ifTrue(loopBody);

    ForLoop rowLoop = new ForLoop()
            .initialize(new BytecodeBlock().putVariable(position, 0))
            .condition(new BytecodeBlock()
                    .getVariable(position)
                    .getVariable(rowCount)
                    .invokeStatic(CompilerOperations.class, "lessThan", boolean.class, int.class, int.class))
            .update(new BytecodeBlock().incrementVariable(position, (byte) 1))
            .body(loopBody);
    block.append(rowLoop);
    return block;
}
Usage of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in the prestodb/presto project.
Class: ArbitraryAggregationFunction; method: generateAggregation.
/**
 * Builds the {@code NAME} aggregation over a single argument of {@code type}.
 *
 * <p>The state implementation is chosen from the type's Java carrier class:
 * <ul>
 * <li>{@code long}, {@code double} and {@code boolean} use the matching nullable primitive state with a serializer
 * produced by {@code StateCompiler}.</li>
 * <li>Any other carrier (e.g. Slice or Block) is held in a {@code BlockPositionState} and serialized by
 * {@code BlockPositionStateSerializer}.</li>
 * </ul>
 * The input and output handles are bound to {@code type}. The final type is {@code type} itself, and the
 * intermediate type is whatever the chosen serializer reports.
 */
private static InternalAggregationFunction generateAggregation(Type type) {
    DynamicClassLoader classLoader = new DynamicClassLoader(ArbitraryAggregationFunction.class.getClassLoader());
    List<Type> inputTypes = ImmutableList.of(type);
    Class<?> carrier = type.getJavaType();

    Class<? extends AccumulatorState> stateInterface;
    AccumulatorStateSerializer<?> stateSerializer;
    MethodHandle unboundInput;
    MethodHandle combine;
    MethodHandle unboundOutput;
    if (carrier == long.class) {
        stateInterface = NullableLongState.class;
        stateSerializer = StateCompiler.generateStateSerializer(stateInterface, classLoader);
        unboundInput = LONG_INPUT_FUNCTION;
        combine = LONG_COMBINE_FUNCTION;
        unboundOutput = LONG_OUTPUT_FUNCTION;
    }
    else if (carrier == double.class) {
        stateInterface = NullableDoubleState.class;
        stateSerializer = StateCompiler.generateStateSerializer(stateInterface, classLoader);
        unboundInput = DOUBLE_INPUT_FUNCTION;
        combine = DOUBLE_COMBINE_FUNCTION;
        unboundOutput = DOUBLE_OUTPUT_FUNCTION;
    }
    else if (carrier == boolean.class) {
        stateInterface = NullableBooleanState.class;
        stateSerializer = StateCompiler.generateStateSerializer(stateInterface, classLoader);
        unboundInput = BOOLEAN_INPUT_FUNCTION;
        combine = BOOLEAN_COMBINE_FUNCTION;
        unboundOutput = BOOLEAN_OUTPUT_FUNCTION;
    }
    else {
        // Slice- or Block-backed values are kept as a (block, position) reference.
        stateInterface = BlockPositionState.class;
        stateSerializer = new BlockPositionStateSerializer(type);
        unboundInput = BLOCK_POSITION_INPUT_FUNCTION;
        combine = BLOCK_POSITION_COMBINE_FUNCTION;
        unboundOutput = BLOCK_POSITION_OUTPUT_FUNCTION;
    }

    MethodHandle input = unboundInput.bindTo(type);
    Type intermediateType = stateSerializer.getSerializedType();
    List<ParameterMetadata> inputParameterMetadata = createInputParameterMetadata(type);
    String aggregationName = generateAggregationName(NAME, type.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList()));
    MethodHandle output = unboundOutput.bindTo(type);
    AccumulatorStateDescriptor stateDescriptor = new AccumulatorStateDescriptor(stateInterface, stateSerializer, StateCompiler.generateStateFactory(stateInterface, classLoader));

    AggregationMetadata metadata = new AggregationMetadata(aggregationName, inputParameterMetadata, input, combine, output, ImmutableList.of(stateDescriptor), type);
    GenericAccumulatorFactoryBinder factory = AccumulatorCompiler.generateAccumulatorFactoryBinder(metadata, classLoader);
    return new InternalAggregationFunction(NAME, inputTypes, ImmutableList.of(intermediateType), type, true, false, factory);
}
Usage of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in the prestodb/presto project.
Class: AccumulatorCompiler; method: getInvokeFunctionOnWindowIndexParameters.
/**
 * Builds the argument expressions for calling an aggregation input function on one row of the window {@code index}.
 *
 * <p>One expression is produced per entry of {@code parameterMetadatas}, in order:
 * <ul>
 * <li>{@code STATE}: the next field from {@code stateField}, read from {@code this}.</li>
 * <li>{@code BLOCK_INDEX}: the constant {@code 0}.</li>
 * <li>{@code BLOCK_INPUT_CHANNEL} / {@code NULLABLE_BLOCK_INPUT_CHANNEL}:
 * {@code index.getSingleValueBlock(channel, position)}.</li>
 * <li>{@code INPUT_CHANNEL}: a typed getter on {@code index} chosen from the Java parameter type.</li>
 * </ul>
 * Channel numbers are read as {@code (int) channels.get(n)}, where {@code n} counts only channel parameters.
 * After the declared parameters, one expression is appended per lambda interface: the matching provider field's
 * {@code getLambda()} result, cast to that interface.
 *
 * @throws IllegalArgumentException if a parameter kind, or the Java type of an {@code INPUT_CHANNEL} parameter,
 *     is unsupported
 */
private static List<BytecodeExpression> getInvokeFunctionOnWindowIndexParameters(Scope scope, Class<?>[] parameterTypes, List<ParameterMetadata> parameterMetadatas, List<Class> lambdaInterfaces, List<FieldDefinition> lambdaProviderFields, List<FieldDefinition> stateField, Variable index, Variable channels, Variable position) {
    // Declared parameters plus lambda parameters must account for every Java parameter type.
    verify(parameterTypes.length == parameterMetadatas.size() + lambdaInterfaces.size());

    List<BytecodeExpression> arguments = new ArrayList<>();
    int nextChannel = 0;
    int nextState = 0;
    for (int parameter = 0; parameter < parameterMetadatas.size(); parameter++) {
        ParameterMetadata metadata = parameterMetadatas.get(parameter);
        switch (metadata.getParameterType()) {
            case STATE:
                arguments.add(scope.getThis().getField(stateField.get(nextState)));
                nextState++;
                break;
            case BLOCK_INDEX:
                // getSingleValueBlock(channel, position) always yields a block with a single position,
                // so the position inside that block is always 0.
                arguments.add(constantInt(0));
                break;
            case BLOCK_INPUT_CHANNEL:
            case NULLABLE_BLOCK_INPUT_CHANNEL: {
                BytecodeExpression channel = channels.invoke("get", Object.class, constantInt(nextChannel)).cast(int.class);
                arguments.add(index.invoke("getSingleValueBlock", Block.class, channel, position));
                nextChannel++;
                break;
            }
            case INPUT_CHANNEL: {
                Class<?> javaType = parameterTypes[parameter];
                String getter;
                Class<?> getterReturnType;
                if (javaType == long.class) {
                    getter = "getLong";
                    getterReturnType = long.class;
                }
                else if (javaType == double.class) {
                    getter = "getDouble";
                    getterReturnType = double.class;
                }
                else if (javaType == boolean.class) {
                    getter = "getBoolean";
                    getterReturnType = boolean.class;
                }
                else if (javaType == Slice.class) {
                    getter = "getSlice";
                    getterReturnType = Slice.class;
                }
                else if (javaType == Block.class) {
                    // Even though the method signature requires a Block parameter, we can pass an Object here.
                    // A runtime check will assert that the Object passed as a parameter is actually of type Block.
                    getter = "getObject";
                    getterReturnType = Object.class;
                }
                else {
                    throw new IllegalArgumentException(format("Unsupported parameter type: %s", javaType));
                }
                BytecodeExpression channel = channels.invoke("get", Object.class, constantInt(nextChannel)).cast(int.class);
                arguments.add(index.invoke(getter, getterReturnType, channel, position));
                nextChannel++;
                break;
            }
            default:
                throw new IllegalArgumentException("Unsupported parameter type: " + metadata.getParameterType());
        }
    }

    // Lambda arguments follow the declared parameters, in lambdaInterfaces order.
    for (int lambda = 0; lambda < lambdaInterfaces.size(); lambda++) {
        BytecodeExpression provider = scope.getThis().getField(lambdaProviderFields.get(lambda));
        arguments.add(provider.invoke("getLambda", Object.class).cast(lambdaInterfaces.get(lambda)));
    }
    return arguments;
}
Usage of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in the prestodb/presto project.
Class: SetAggregationFunction; method: generateAggregation.
/**
 * Builds the {@code NAME} aggregation that collects values of {@code type} into an {@code outputType} array.
 *
 * <p>State is a {@code SetAggregationState}:
 * <ul>
 * <li>It is created by a {@code SetAggregationStateFactory} for {@code type}.</li>
 * <li>It is serialized by a {@code SetAggregationStateSerializer} for {@code outputType}; that serializer's
 * serialized type becomes the intermediate type.</li>
 * </ul>
 * Only the input handle is bound to {@code type}; the combine and output handles are used as-is.
 */
private static InternalAggregationFunction generateAggregation(Type type, ArrayType outputType) {
    DynamicClassLoader classLoader = new DynamicClassLoader(SetAggregationFunction.class.getClassLoader());
    List<Type> inputTypes = ImmutableList.of(type);

    AccumulatorStateSerializer<?> serializer = new SetAggregationStateSerializer(outputType);
    AccumulatorStateFactory<?> factory = new SetAggregationStateFactory(type);
    Type intermediateType = serializer.getSerializedType();
    List<ParameterMetadata> inputParameterMetadata = createInputParameterMetadata(type);
    Class<? extends AccumulatorState> stateInterface = SetAggregationState.class;

    String aggregationName = generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList()));
    AggregationMetadata metadata = new AggregationMetadata(
            aggregationName,
            inputParameterMetadata,
            INPUT_FUNCTION.bindTo(type),
            COMBINE_FUNCTION,
            OUTPUT_FUNCTION,
            ImmutableList.of(new AccumulatorStateDescriptor(stateInterface, serializer, factory)),
            outputType);

    GenericAccumulatorFactoryBinder binder = AccumulatorCompiler.generateAccumulatorFactoryBinder(metadata, classLoader);
    return new InternalAggregationFunction(NAME, inputTypes, ImmutableList.of(intermediateType), outputType, true, true, binder);
}
Usage of com.facebook.presto.operator.aggregation.AggregationMetadata.ParameterMetadata in the prestodb/presto project.
Class: ParametricAggregation; method: buildParameterMetadata.
/**
 * Converts a list of parameter kinds into {@link ParameterMetadata}, attaching input types to channel parameters.
 *
 * <p>{@code STATE} and {@code BLOCK_INDEX} parameters carry no type. Each {@code INPUT_CHANNEL},
 * {@code BLOCK_INPUT_CHANNEL} and {@code NULLABLE_BLOCK_INPUT_CHANNEL} parameter takes the next entry of
 * {@code inputTypes}, in order.
 *
 * @param parameterMetadataTypes the parameter kinds, in declaration order
 * @param inputTypes the input types consumed, one per channel parameter, in order
 * @return an immutable list with one entry per element of {@code parameterMetadataTypes}
 * @throws IllegalArgumentException if a parameter kind is not handled here. Previously such a kind was silently
 *     dropped, which left the result shorter than the input and misaligned all following parameters.
 * @throws IndexOutOfBoundsException if there are more channel parameters than {@code inputTypes}
 */
private static List<ParameterMetadata> buildParameterMetadata(List<ParameterType> parameterMetadataTypes, List<Type> inputTypes) {
    ImmutableList.Builder<ParameterMetadata> builder = ImmutableList.builder();
    int inputId = 0;
    for (ParameterType parameterMetadataType : parameterMetadataTypes) {
        switch (parameterMetadataType) {
            case STATE:
            case BLOCK_INDEX:
                // Not bound to an input channel, so no input type is consumed.
                builder.add(new ParameterMetadata(parameterMetadataType));
                break;
            case INPUT_CHANNEL:
            case BLOCK_INPUT_CHANNEL:
            case NULLABLE_BLOCK_INPUT_CHANNEL:
                // Channel parameters consume input types in declaration order.
                builder.add(new ParameterMetadata(parameterMetadataType, inputTypes.get(inputId++)));
                break;
            default:
                // Fail loudly instead of silently producing a shorter, misaligned metadata list.
                throw new IllegalArgumentException("Unsupported parameter type: " + parameterMetadataType);
        }
    }
    return builder.build();
}
Aggregations