Use of com.facebook.presto.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor in project presto by prestodb.
From the class MultimapAggregationFunction, method generateAggregation.
/**
 * Builds the {@link InternalAggregationFunction} for this aggregation, specialized to the
 * supplied key, value and output types.
 *
 * <p>The two input types are {@code keyType} followed by {@code valueType}. The accumulator state
 * is serialized by a {@link MultimapAggregationStateSerializer}, and that serializer's serialized
 * type is reported as the single intermediate type of the returned function.
 *
 * @param keyType type of the first input argument
 * @param valueType type of the second input argument
 * @param outputType type passed to the metadata and the returned function as the output type
 * @return the aggregation function backed by the accumulator factory binder generated for it
 */
private InternalAggregationFunction generateAggregation(Type keyType, Type valueType, Type outputType) {
    DynamicClassLoader loader = new DynamicClassLoader(MultimapAggregationFunction.class.getClassLoader());
    List<Type> argumentTypes = ImmutableList.of(keyType, valueType);

    MultimapAggregationStateSerializer serializer = new MultimapAggregationStateSerializer(keyType, valueType);
    Type serializedStateType = serializer.getSerializedType();

    String aggregationName = generateAggregationName(
            NAME,
            outputType.getTypeSignature(),
            argumentTypes.stream().map(Type::getTypeSignature).collect(toImmutableList()));

    // The output handle takes the key and value types as its two leading bound arguments.
    AggregationMetadata aggregationMetadata = new AggregationMetadata(
            aggregationName,
            createInputParameterMetadata(keyType, valueType),
            INPUT_FUNCTION,
            COMBINE_FUNCTION,
            OUTPUT_FUNCTION.bindTo(keyType).bindTo(valueType),
            ImmutableList.of(new AccumulatorStateDescriptor(
                    MultimapAggregationState.class,
                    serializer,
                    new MultimapAggregationStateFactory(keyType, valueType, groupMode))),
            outputType);

    GenericAccumulatorFactoryBinder binder = AccumulatorCompiler.generateAccumulatorFactoryBinder(aggregationMetadata, loader);
    return new InternalAggregationFunction(
            NAME,
            argumentTypes,
            ImmutableList.of(serializedStateType),
            outputType,
            true,
            true,
            binder);
}
Use of com.facebook.presto.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor in project presto by prestodb.
From the class ArbitraryAggregationFunction, method generateAggregation.
/**
 * Builds the {@link InternalAggregationFunction} for this aggregation over a single input of
 * the given type; the same type is also used as the output type.
 *
 * <p>The state interface, state serializer and the input/combine/output method handles are
 * selected from the Java type backing {@code type}:
 * <ul>
 *   <li>{@code long} uses {@link NullableLongState} with a generated serializer;</li>
 *   <li>{@code double} uses {@link NullableDoubleState} with a generated serializer;</li>
 *   <li>{@code boolean} uses {@link NullableBooleanState} with a generated serializer;</li>
 *   <li>anything else uses {@link BlockPositionState} with a {@link BlockPositionStateSerializer}.</li>
 * </ul>
 *
 * @param type the input (and output) type of the aggregation
 * @return the aggregation function backed by the accumulator factory binder generated for it
 */
private static InternalAggregationFunction generateAggregation(Type type) {
    DynamicClassLoader loader = new DynamicClassLoader(ArbitraryAggregationFunction.class.getClassLoader());
    List<Type> argumentTypes = ImmutableList.of(type);
    Class<?> javaType = type.getJavaType();

    Class<? extends AccumulatorState> stateInterface;
    AccumulatorStateSerializer<?> stateSerializer;
    MethodHandle unboundInput;
    MethodHandle combine;
    MethodHandle unboundOutput;

    if (javaType == long.class) {
        stateInterface = NullableLongState.class;
        stateSerializer = StateCompiler.generateStateSerializer(stateInterface, loader);
        unboundInput = LONG_INPUT_FUNCTION;
        combine = LONG_COMBINE_FUNCTION;
        unboundOutput = LONG_OUTPUT_FUNCTION;
    }
    else if (javaType == double.class) {
        stateInterface = NullableDoubleState.class;
        stateSerializer = StateCompiler.generateStateSerializer(stateInterface, loader);
        unboundInput = DOUBLE_INPUT_FUNCTION;
        combine = DOUBLE_COMBINE_FUNCTION;
        unboundOutput = DOUBLE_OUTPUT_FUNCTION;
    }
    else if (javaType == boolean.class) {
        stateInterface = NullableBooleanState.class;
        stateSerializer = StateCompiler.generateStateSerializer(stateInterface, loader);
        unboundInput = BOOLEAN_INPUT_FUNCTION;
        combine = BOOLEAN_COMBINE_FUNCTION;
        unboundOutput = BOOLEAN_OUTPUT_FUNCTION;
    }
    else {
        // native container type is Slice or Block
        stateInterface = BlockPositionState.class;
        stateSerializer = new BlockPositionStateSerializer(type);
        unboundInput = BLOCK_POSITION_INPUT_FUNCTION;
        combine = BLOCK_POSITION_COMBINE_FUNCTION;
        unboundOutput = BLOCK_POSITION_OUTPUT_FUNCTION;
    }

    // Both the input and the output handle take the type as their leading bound argument.
    MethodHandle boundInput = unboundInput.bindTo(type);
    Type serializedStateType = stateSerializer.getSerializedType();
    List<ParameterMetadata> parameters = createInputParameterMetadata(type);

    String aggregationName = generateAggregationName(
            NAME,
            type.getTypeSignature(),
            argumentTypes.stream().map(Type::getTypeSignature).collect(toImmutableList()));

    AggregationMetadata aggregationMetadata = new AggregationMetadata(
            aggregationName,
            parameters,
            boundInput,
            combine,
            unboundOutput.bindTo(type),
            ImmutableList.of(new AccumulatorStateDescriptor(
                    stateInterface,
                    stateSerializer,
                    StateCompiler.generateStateFactory(stateInterface, loader))),
            type);

    GenericAccumulatorFactoryBinder binder = AccumulatorCompiler.generateAccumulatorFactoryBinder(aggregationMetadata, loader);
    return new InternalAggregationFunction(
            NAME,
            argumentTypes,
            ImmutableList.of(serializedStateType),
            type,
            true,
            false,
            binder);
}
Use of com.facebook.presto.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor in project presto by prestodb.
From the class AccumulatorCompiler, method generateAccumulatorClass.
/**
 * Generates and defines a class implementing {@code accumulatorInterface} for the aggregation
 * described by {@code metadata}.
 *
 * <p>The generated class declares the following private final fields:
 * <ul>
 *   <li>per accumulator state descriptor: {@code stateSerializer_i}, {@code stateFactory_i} and
 *       {@code state_i} (typed with the factory's grouped or single state class);</li>
 *   <li>per lambda interface: {@code lambdaProvider_i};</li>
 *   <li>{@code maskChannel};</li>
 *   <li>per input channel: {@code inputChannel<i>}.</li>
 * </ul>
 * A constructor and the accumulator methods are then generated. When {@code accumulatorInterface}
 * is {@link GroupedAccumulator}, the grouped variants of the evaluate methods and an additional
 * prepare-final method are generated instead of the non-grouped ones.
 *
 * @param accumulatorInterface the interface the generated class implements
 * @param metadata description of the aggregation (name, functions, state descriptors, output type)
 * @param classLoader class loader in which the generated class is defined
 * @return the newly defined class
 */
private static <T> Class<? extends T> generateAccumulatorClass(Class<T> accumulatorInterface, AggregationMetadata metadata, DynamicClassLoader classLoader) {
    boolean grouped = accumulatorInterface == GroupedAccumulator.class;

    String classBaseName = metadata.getName() + accumulatorInterface.getSimpleName();
    ClassDefinition definition = new ClassDefinition(
            a(PUBLIC, FINAL),
            makeClassName(classBaseName),
            type(Object.class),
            type(accumulatorInterface));
    CallSiteBinder callSiteBinder = new CallSiteBinder();

    // One (serializer, factory, state) field triple per accumulator state descriptor.
    List<AccumulatorStateDescriptor> stateDescriptors = metadata.getAccumulatorStateDescriptors();
    List<StateFieldAndDescriptor> stateFieldAndDescriptors = new ArrayList<>();
    for (int index = 0; index < stateDescriptors.size(); index++) {
        AccumulatorStateDescriptor descriptor = stateDescriptors.get(index);

        FieldDefinition serializerField = definition.declareField(a(PRIVATE, FINAL), "stateSerializer_" + index, AccumulatorStateSerializer.class);
        FieldDefinition factoryField = definition.declareField(a(PRIVATE, FINAL), "stateFactory_" + index, AccumulatorStateFactory.class);

        Class<?> stateClass;
        if (grouped) {
            stateClass = descriptor.getFactory().getGroupedStateClass();
        }
        else {
            stateClass = descriptor.getFactory().getSingleStateClass();
        }
        FieldDefinition stateField = definition.declareField(a(PRIVATE, FINAL), "state_" + index, stateClass);

        stateFieldAndDescriptors.add(new StateFieldAndDescriptor(serializerField, factoryField, stateField, descriptor));
    }
    List<FieldDefinition> stateFields = stateFieldAndDescriptors.stream()
            .map(StateFieldAndDescriptor::getStateField)
            .collect(toImmutableList());

    // One provider field per lambda interface declared by the aggregation.
    int lambdaCount = metadata.getLambdaInterfaces().size();
    List<FieldDefinition> lambdaProviderFields = new ArrayList<>(lambdaCount);
    for (int index = 0; index < lambdaCount; index++) {
        lambdaProviderFields.add(definition.declareField(a(PRIVATE, FINAL), "lambdaProvider_" + index, LambdaProvider.class));
    }

    FieldDefinition maskChannelField = definition.declareField(a(PRIVATE, FINAL), "maskChannel", int.class);

    // One int field per input channel.
    int inputChannelCount = countInputChannels(metadata.getValueInputMetadata());
    ImmutableList.Builder<FieldDefinition> inputChannelFieldsBuilder = ImmutableList.builderWithExpectedSize(inputChannelCount);
    for (int index = 0; index < inputChannelCount; index++) {
        inputChannelFieldsBuilder.add(definition.declareField(a(PRIVATE, FINAL), "inputChannel" + index, int.class));
    }
    List<FieldDefinition> inputChannelFields = inputChannelFieldsBuilder.build();

    // Generate constructor
    generateConstructor(definition, stateFieldAndDescriptors, inputChannelFields, maskChannelField, lambdaProviderFields, grouped);

    // Generate methods
    generateAddInput(
            definition,
            stateFields,
            inputChannelFields,
            maskChannelField,
            metadata.getValueInputMetadata(),
            metadata.getLambdaInterfaces(),
            lambdaProviderFields,
            metadata.getInputFunction(),
            callSiteBinder,
            grouped);
    generateAddInputWindowIndex(
            definition,
            stateFields,
            metadata.getValueInputMetadata(),
            metadata.getLambdaInterfaces(),
            lambdaProviderFields,
            metadata.getInputFunction(),
            callSiteBinder);
    generateGetEstimatedSize(definition, stateFields);
    generateGetIntermediateType(
            definition,
            callSiteBinder,
            stateDescriptors.stream()
                    .map(descriptor -> descriptor.getSerializer().getSerializedType())
                    .collect(toImmutableList()));
    generateGetFinalType(definition, callSiteBinder, metadata.getOutputType());
    generateAddIntermediateAsCombine(
            definition,
            stateFieldAndDescriptors,
            metadata.getLambdaInterfaces(),
            lambdaProviderFields,
            metadata.getCombineFunction(),
            callSiteBinder,
            grouped);

    // Evaluate-intermediate, then evaluate-final, then (grouped only) prepare-final.
    if (grouped) {
        generateGroupedEvaluateIntermediate(definition, stateFieldAndDescriptors);
        generateGroupedEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder);
        generatePrepareFinal(definition);
    }
    else {
        generateEvaluateIntermediate(definition, stateFieldAndDescriptors);
        generateEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder);
    }

    return defineClass(definition, accumulatorInterface, callSiteBinder.getBindings(), classLoader);
}
Use of com.facebook.presto.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor in project presto by prestodb.
From the class CountColumn, method generateAggregation.
/**
 * Builds the {@link InternalAggregationFunction} for this aggregation over a single input of
 * the given type, producing {@code BIGINT}.
 *
 * <p>The accumulator state is a {@link LongState} whose serializer and factory are both generated
 * by {@link StateCompiler}; the serializer's serialized type is the single intermediate type.
 *
 * @param type the type of the single input argument
 * @return the aggregation function backed by the accumulator factory binder generated for it
 */
private static InternalAggregationFunction generateAggregation(Type type) {
    DynamicClassLoader loader = new DynamicClassLoader(CountColumn.class.getClassLoader());
    List<Type> argumentTypes = ImmutableList.of(type);

    AccumulatorStateSerializer<LongState> longStateSerializer = StateCompiler.generateStateSerializer(LongState.class, loader);
    AccumulatorStateFactory<LongState> longStateFactory = StateCompiler.generateStateFactory(LongState.class, loader);
    Type serializedStateType = longStateSerializer.getSerializedType();

    String aggregationName = generateAggregationName(
            NAME,
            BIGINT.getTypeSignature(),
            argumentTypes.stream().map(Type::getTypeSignature).collect(toImmutableList()));

    AggregationMetadata aggregationMetadata = new AggregationMetadata(
            aggregationName,
            createInputParameterMetadata(type),
            INPUT_FUNCTION,
            COMBINE_FUNCTION,
            OUTPUT_FUNCTION,
            ImmutableList.of(new AccumulatorStateDescriptor(LongState.class, longStateSerializer, longStateFactory)),
            BIGINT);

    GenericAccumulatorFactoryBinder binder = AccumulatorCompiler.generateAccumulatorFactoryBinder(aggregationMetadata, loader);
    return new InternalAggregationFunction(
            NAME,
            argumentTypes,
            ImmutableList.of(serializedStateType),
            BIGINT,
            true,
            false,
            binder);
}
Use of com.facebook.presto.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor in project presto by prestodb.
From the class DecimalAverageAggregation, method generateAggregation.
/**
 * Builds the {@link InternalAggregationFunction} for this aggregation over a single decimal
 * input; the input type is also used as the output type.
 *
 * <p>{@code type} is checked to be a {@link DecimalType} before anything else happens. Short
 * decimals use the {@code SHORT_DECIMAL_*} input/output handles and all other decimals use the
 * {@code LONG_DECIMAL_*} ones; both share the same combine handle and the same
 * {@link LongDecimalWithOverflowAndLongState} state, serializer and factory.
 *
 * @param type the decimal input (and output) type
 * @return the aggregation function backed by the accumulator factory binder generated for it
 */
private static InternalAggregationFunction generateAggregation(Type type) {
    checkArgument(type instanceof DecimalType, "type must be Decimal");

    DynamicClassLoader loader = new DynamicClassLoader(DecimalAverageAggregation.class.getClassLoader());
    List<Type> argumentTypes = ImmutableList.of(type);

    Class<? extends AccumulatorState> stateInterface = LongDecimalWithOverflowAndLongState.class;
    AccumulatorStateSerializer<?> stateSerializer = new LongDecimalWithOverflowAndLongStateSerializer();

    boolean shortDecimal = ((DecimalType) type).isShort();
    MethodHandle inputFunction = shortDecimal ? SHORT_DECIMAL_INPUT_FUNCTION : LONG_DECIMAL_INPUT_FUNCTION;
    MethodHandle unboundOutput = shortDecimal ? SHORT_DECIMAL_OUTPUT_FUNCTION : LONG_DECIMAL_OUTPUT_FUNCTION;
    // The output handle takes the decimal type as its leading bound argument.
    MethodHandle outputFunction = unboundOutput.bindTo(type);

    String aggregationName = generateAggregationName(
            NAME,
            type.getTypeSignature(),
            argumentTypes.stream().map(Type::getTypeSignature).collect(toImmutableList()));

    AggregationMetadata aggregationMetadata = new AggregationMetadata(
            aggregationName,
            createInputParameterMetadata(type),
            inputFunction,
            COMBINE_FUNCTION,
            outputFunction,
            ImmutableList.of(new AccumulatorStateDescriptor(
                    stateInterface,
                    stateSerializer,
                    new LongDecimalWithOverflowAndLongStateFactory())),
            type);

    Type serializedStateType = stateSerializer.getSerializedType();
    GenericAccumulatorFactoryBinder binder = AccumulatorCompiler.generateAccumulatorFactoryBinder(aggregationMetadata, loader);
    return new InternalAggregationFunction(
            NAME,
            argumentTypes,
            ImmutableList.of(serializedStateType),
            type,
            true,
            false,
            binder);
}
Aggregations