Search in sources :

Example 26 with ClassDefinition

use of io.airlift.bytecode.ClassDefinition in project trino by trinodb.

the class StateCompiler method generateSingleStateClass.

/**
 * Generates and loads the non-grouped ("Single"-prefixed) implementation class for a state interface.
 *
 * <p>The generated class is public and final, extends {@code Object}, implements {@code clazz},
 * exposes a {@code getEstimatedSize()} method that returns the static instance-size field, has a
 * public no-argument constructor, and gets one generated member per entry returned by
 * {@code enumerateFields(clazz, fieldTypes)}.
 *
 * @param clazz the state interface the generated class implements
 * @param fieldTypes passed through to {@code enumerateFields} together with {@code clazz}
 * @param classLoader the loader handed to {@code defineClass}
 * @return the class produced by {@code defineClass} for the generated definition
 */
private static <T> Class<? extends T> generateSingleStateClass(Class<T> clazz, Map<String, Type> fieldTypes, DynamicClassLoader classLoader) {
    String generatedName = "Single" + clazz.getSimpleName();
    ClassDefinition singleStateClass = new ClassDefinition(
            a(PUBLIC, FINAL),
            makeClassName(generatedName),
            type(Object.class),
            type(clazz));
    FieldDefinition instanceSizeField = generateInstanceSize(singleStateClass);

    // getEstimatedSize() simply loads the static instance-size field and returns it
    MethodDefinition estimatedSizeGetter = singleStateClass.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class));
    estimatedSizeGetter.getBody().getStaticField(instanceSizeField).retLong();

    // Public constructor; its body begins with the call to Object's constructor
    MethodDefinition ctor = singleStateClass.declareConstructor(a(PUBLIC));
    ctor.getBody()
            .append(ctor.getThis())
            .invokeConstructor(Object.class);

    // One generated field per state field. generateField also receives the constructor
    // (presumably to emit initialization code into it - confirm in generateField).
    List<StateField> stateFields = enumerateFields(clazz, fieldTypes);
    List<FieldDefinition> declaredFields = new ArrayList<>();
    for (int index = 0; index < stateFields.size(); index++) {
        declaredFields.add(generateField(singleStateClass, ctor, stateFields.get(index)));
    }

    // The constructor body is only terminated after every field has been processed
    ctor.getBody().ret();

    generateCopy(singleStateClass, stateFields, declaredFields);
    return defineClass(singleStateClass, clazz, classLoader);
}
Also used : MethodDefinition(io.airlift.bytecode.MethodDefinition) FieldDefinition(io.airlift.bytecode.FieldDefinition) ArrayList(java.util.ArrayList) ClassDefinition(io.airlift.bytecode.ClassDefinition)

Example 27 with ClassDefinition

use of io.airlift.bytecode.ClassDefinition in project trino by trinodb.

the class StateCompiler method generateGroupedStateClass.

/**
 * Generates and loads the grouped ("Grouped"-prefixed) implementation class for a state interface.
 *
 * <p>The generated class is public and final, extends {@code AbstractGroupedAccumulatorState},
 * implements {@code clazz}, and contains a public no-argument constructor, an
 * {@code ensureCapacity(long size)} method, one generated member per state field, and a
 * {@code getEstimatedSize()} method that adds {@code sizeOf()} of every generated field to the
 * static instance size.
 *
 * @param clazz the state interface the generated class implements
 * @param fieldTypes passed through to {@code enumerateFields} together with {@code clazz}
 * @param classLoader the loader handed to {@code defineClass}
 * @return the class produced by {@code defineClass} for the generated definition
 */
private static <T> Class<? extends T> generateGroupedStateClass(Class<T> clazz, Map<String, Type> fieldTypes, DynamicClassLoader classLoader) {
    String generatedName = "Grouped" + clazz.getSimpleName();
    ClassDefinition groupedStateClass = new ClassDefinition(
            a(PUBLIC, FINAL),
            makeClassName(generatedName),
            type(AbstractGroupedAccumulatorState.class),
            type(clazz));
    FieldDefinition instanceSizeField = generateInstanceSize(groupedStateClass);
    List<StateField> stateFields = enumerateFields(clazz, fieldTypes);

    // Public constructor; its body begins with the superclass constructor call
    MethodDefinition ctor = groupedStateClass.declareConstructor(a(PUBLIC));
    ctor.getBody()
            .append(ctor.getThis())
            .invokeConstructor(AbstractGroupedAccumulatorState.class);

    // ensureCapacity(long size); its body is filled in while the fields are generated
    MethodDefinition capacityMethod = groupedStateClass.declareMethod(a(PUBLIC), "ensureCapacity", type(void.class), arg("size", long.class));

    // Each field is generated with access to both the constructor and ensureCapacity
    List<FieldDefinition> declaredFields = new ArrayList<>();
    for (int index = 0; index < stateFields.size(); index++) {
        declaredFields.add(generateGroupedField(groupedStateClass, ctor, capacityMethod, stateFields.get(index)));
    }

    // Both method bodies are terminated only after all fields have been processed
    ctor.getBody().ret();
    capacityMethod.getBody().ret();

    // getEstimatedSize(): start from the static instance size, then add sizeOf() of every field
    MethodDefinition sizeMethod = groupedStateClass.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class));
    BytecodeBlock sizeBody = sizeMethod.getBody();
    Variable total = sizeMethod.getScope().declareVariable(long.class, "size");
    Variable self = sizeMethod.getThis();
    sizeBody.append(total.set(getStatic(instanceSizeField)));
    for (int index = 0; index < declaredFields.size(); index++) {
        FieldDefinition declaredField = declaredFields.get(index);
        sizeBody.append(total.set(add(
                total,
                self.getField(declaredField).invoke("sizeOf", long.class))));
    }
    sizeBody.append(total.ret());

    return defineClass(groupedStateClass, clazz, classLoader);
}
Also used : Variable(io.airlift.bytecode.Variable) MethodDefinition(io.airlift.bytecode.MethodDefinition) FieldDefinition(io.airlift.bytecode.FieldDefinition) ArrayList(java.util.ArrayList) BytecodeBlock(io.airlift.bytecode.BytecodeBlock) ClassDefinition(io.airlift.bytecode.ClassDefinition)

Example 28 with ClassDefinition

use of io.airlift.bytecode.ClassDefinition in project trino by trinodb.

the class AccumulatorCompiler method generateAccumulatorClass.

/**
 * Generates an accumulator implementation class and returns its {@code (List)} constructor.
 *
 * <p>The generated class is public and final, extends {@code Object} and implements
 * {@code accumulatorInterface}. When {@code accumulatorInterface} is {@code GroupedAccumulator},
 * the grouped variants of the generated members are used and a {@code prepareFinal} member is
 * added as well.
 *
 * @param boundSignature supplies the name used as prefix of the generated class name
 * @param accumulatorInterface the interface the generated class implements
 * @param metadata supplies state descriptors, lambda interfaces and the input, combine and output functions
 * @param argumentNullable forwarded unchanged to {@code generateAddInput}
 * @param classLoader the loader handed to {@code defineClass}
 * @return the public constructor of the generated class that takes a single {@code List}
 * @throws RuntimeException wrapping {@code NoSuchMethodException} if that constructor does not exist
 */
private static <T> Constructor<? extends T> generateAccumulatorClass(BoundSignature boundSignature, Class<T> accumulatorInterface, AggregationMetadata metadata, List<Boolean> argumentNullable, DynamicClassLoader classLoader) {
    boolean grouped = accumulatorInterface == GroupedAccumulator.class;
    String baseName = boundSignature.getName() + accumulatorInterface.getSimpleName();
    ClassDefinition definition = new ClassDefinition(
            a(PUBLIC, FINAL),
            makeClassName(baseName),
            type(Object.class),
            type(accumulatorInterface));
    CallSiteBinder callSiteBinder = new CallSiteBinder();

    // Per state descriptor: a serializer field, a factory field and the state field itself,
    // declared in that order and suffixed with the descriptor's index.
    Class<?> stateFieldType = grouped ? GroupedAccumulatorState.class : AccumulatorState.class;
    List<AccumulatorStateDescriptor<?>> descriptors = metadata.getAccumulatorStateDescriptors();
    List<StateFieldAndDescriptor> stateFieldAndDescriptors = new ArrayList<>();
    int descriptorIndex = 0;
    for (AccumulatorStateDescriptor<?> descriptor : descriptors) {
        FieldDefinition serializerField = definition.declareField(a(PRIVATE, FINAL), "stateSerializer_" + descriptorIndex, AccumulatorStateSerializer.class);
        FieldDefinition factoryField = definition.declareField(a(PRIVATE, FINAL), "stateFactory_" + descriptorIndex, AccumulatorStateFactory.class);
        FieldDefinition stateField = definition.declareField(a(PRIVATE, FINAL), "state_" + descriptorIndex, stateFieldType);
        stateFieldAndDescriptors.add(new StateFieldAndDescriptor(descriptor, serializerField, factoryField, stateField));
        descriptorIndex++;
    }
    List<FieldDefinition> stateFields = stateFieldAndDescriptors.stream()
            .map(StateFieldAndDescriptor::getStateField)
            .collect(toImmutableList());

    // One Supplier-typed field per lambda interface
    int lambdaCount = metadata.getLambdaInterfaces().size();
    List<FieldDefinition> lambdaProviderFields = new ArrayList<>(lambdaCount);
    for (int lambdaIndex = 0; lambdaIndex < lambdaCount; lambdaIndex++) {
        lambdaProviderFields.add(definition.declareField(a(PRIVATE, FINAL), "lambdaProvider_" + lambdaIndex, Supplier.class));
    }

    // Constructors
    generateConstructor(definition, stateFieldAndDescriptors, lambdaProviderFields, callSiteBinder, grouped);
    generateCopyConstructor(definition, stateFieldAndDescriptors, lambdaProviderFields);

    // Methods shared by both flavours
    generateCopy(definition, Accumulator.class);
    generateAddInput(definition, stateFields, argumentNullable, lambdaProviderFields, metadata.getInputFunction(), callSiteBinder, grouped);
    generateGetEstimatedSize(definition, stateFields);
    generateAddIntermediateAsCombine(definition, stateFieldAndDescriptors, lambdaProviderFields, metadata.getCombineFunction(), callSiteBinder, grouped);

    // Flavour-specific methods: evaluateIntermediate, then evaluateFinal, then (grouped only) prepareFinal
    if (grouped) {
        generateGroupedEvaluateIntermediate(definition, stateFieldAndDescriptors, true);
        generateGroupedEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder);
        generatePrepareFinal(definition);
    } else {
        generateEvaluateIntermediate(definition, stateFieldAndDescriptors, true);
        generateEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder);
    }

    Class<? extends T> accumulatorClass = defineClass(definition, accumulatorInterface, callSiteBinder.getBindings(), classLoader);
    try {
        return accumulatorClass.getConstructor(List.class);
    } catch (NoSuchMethodException e) {
        throw new RuntimeException(e);
    }
}
Also used : GroupedAccumulatorState(io.trino.spi.function.GroupedAccumulatorState) AccumulatorState(io.trino.spi.function.AccumulatorState) AccumulatorStateDescriptor(io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor) FieldDefinition(io.airlift.bytecode.FieldDefinition) ArrayList(java.util.ArrayList) ClassDefinition(io.airlift.bytecode.ClassDefinition) CallSiteBinder(io.trino.sql.gen.CallSiteBinder) Supplier(java.util.function.Supplier) GroupedAccumulatorState(io.trino.spi.function.GroupedAccumulatorState)

Example 29 with ClassDefinition

use of io.airlift.bytecode.ClassDefinition in project trino by trinodb.

the class AccumulatorCompiler method generateWindowAccumulatorClass.

/**
 * Generates a {@code WindowAccumulator} implementation class and returns its {@code (List)} constructor.
 *
 * <p>The generated class is public and final, extends {@code Object} and implements
 * {@code WindowAccumulator}. It is defined in a fresh {@code DynamicClassLoader} whose parent is
 * the class loader of {@code AccumulatorCompiler}. A {@code removeInput} member is generated only
 * when the metadata provides a remove-input function.
 *
 * @param boundSignature supplies the name used as prefix of the generated class name
 * @param metadata supplies state descriptors, lambda interfaces and the input, remove-input and output functions
 * @param functionNullability supplies the argument nullability list; its trailing entries, one
 *        per lambda interface, are dropped before use
 * @return the public constructor of the generated class that takes a single {@code List}
 * @throws RuntimeException wrapping {@code NoSuchMethodException} if that constructor does not exist
 */
public static Constructor<? extends WindowAccumulator> generateWindowAccumulatorClass(BoundSignature boundSignature, AggregationMetadata metadata, FunctionNullability functionNullability) {
    DynamicClassLoader classLoader = new DynamicClassLoader(AccumulatorCompiler.class.getClassLoader());

    // Keep only the leading nullability flags; the last lambdaCount entries are cut off
    int lambdaCount = metadata.getLambdaInterfaces().size();
    List<Boolean> allArgumentNullable = functionNullability.getArgumentNullable();
    List<Boolean> argumentNullable = allArgumentNullable.subList(0, allArgumentNullable.size() - lambdaCount);

    String baseName = boundSignature.getName() + WindowAccumulator.class.getSimpleName();
    ClassDefinition definition = new ClassDefinition(
            a(PUBLIC, FINAL),
            makeClassName(baseName),
            type(Object.class),
            type(WindowAccumulator.class));
    CallSiteBinder callSiteBinder = new CallSiteBinder();

    // Per state descriptor: a serializer field, a factory field and the state field itself,
    // declared in that order and suffixed with the descriptor's index.
    List<AccumulatorStateDescriptor<?>> descriptors = metadata.getAccumulatorStateDescriptors();
    List<StateFieldAndDescriptor> stateFieldAndDescriptors = new ArrayList<>();
    int descriptorIndex = 0;
    for (AccumulatorStateDescriptor<?> descriptor : descriptors) {
        FieldDefinition serializerField = definition.declareField(a(PRIVATE, FINAL), "stateSerializer_" + descriptorIndex, AccumulatorStateSerializer.class);
        FieldDefinition factoryField = definition.declareField(a(PRIVATE, FINAL), "stateFactory_" + descriptorIndex, AccumulatorStateFactory.class);
        FieldDefinition stateField = definition.declareField(a(PRIVATE, FINAL), "state_" + descriptorIndex, AccumulatorState.class);
        stateFieldAndDescriptors.add(new StateFieldAndDescriptor(descriptor, serializerField, factoryField, stateField));
        descriptorIndex++;
    }
    List<FieldDefinition> stateFields = stateFieldAndDescriptors.stream()
            .map(StateFieldAndDescriptor::getStateField)
            .collect(toImmutableList());

    // One Supplier-typed field per lambda interface
    List<FieldDefinition> lambdaProviderFields = new ArrayList<>(lambdaCount);
    for (int lambdaIndex = 0; lambdaIndex < lambdaCount; lambdaIndex++) {
        lambdaProviderFields.add(definition.declareField(a(PRIVATE, FINAL), "lambdaProvider_" + lambdaIndex, Supplier.class));
    }

    // Constructors
    generateWindowAccumulatorConstructor(definition, stateFieldAndDescriptors, lambdaProviderFields, callSiteBinder);
    generateCopyConstructor(definition, stateFieldAndDescriptors, lambdaProviderFields);

    // Methods
    generateCopy(definition, WindowAccumulator.class);
    generateAddOrRemoveInputWindowIndex(definition, stateFields, argumentNullable, lambdaProviderFields, metadata.getInputFunction(), "addInput", callSiteBinder);
    // removeInput is optional and only emitted when a remove-input function is present
    metadata.getRemoveInputFunction().ifPresent(removeInputFunction ->
            generateAddOrRemoveInputWindowIndex(definition, stateFields, argumentNullable, lambdaProviderFields, removeInputFunction, "removeInput", callSiteBinder));
    generateEvaluateFinal(definition, stateFields, metadata.getOutputFunction(), callSiteBinder);
    generateGetEstimatedSize(definition, stateFields);

    Class<? extends WindowAccumulator> windowAccumulatorClass = defineClass(definition, WindowAccumulator.class, callSiteBinder.getBindings(), classLoader);
    try {
        return windowAccumulatorClass.getConstructor(List.class);
    } catch (NoSuchMethodException e) {
        throw new RuntimeException(e);
    }
}
Also used : DynamicClassLoader(io.airlift.bytecode.DynamicClassLoader) AccumulatorStateDescriptor(io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor) FieldDefinition(io.airlift.bytecode.FieldDefinition) ArrayList(java.util.ArrayList) ClassDefinition(io.airlift.bytecode.ClassDefinition) CallSiteBinder(io.trino.sql.gen.CallSiteBinder) Supplier(java.util.function.Supplier)

Example 30 with ClassDefinition

use of io.airlift.bytecode.ClassDefinition in project trino by trinodb.

the class MapFilterFunction method generateFilter.

/**
 * Generates a class with a static {@code filter(Object state, Block block, BinaryFunctionInterface function)}
 * method specialized for the given map type, and returns a method handle to that method.
 *
 * <p>The generated method walks {@code block} two positions at a time, reading a key at
 * {@code position} and a value at {@code position + 1}, calls {@code function.apply(key, value)},
 * and appends the pair to a new map entry only when the result is non-null and true. It returns
 * the object read from the map block builder at its last position.
 *
 * @param mapType the map type whose key and value types drive the generated code
 * @return a method handle for the generated static {@code filter} method
 */
private static MethodHandle generateFilter(MapType mapType) {
    CallSiteBinder binder = new CallSiteBinder();
    Type keyType = mapType.getKeyType();
    Type valueType = mapType.getValueType();
    // Variables are declared with the wrapper classes so they can hold null
    Class<?> keyJavaType = Primitives.wrap(keyType.getJavaType());
    Class<?> valueJavaType = Primitives.wrap(valueType.getJavaType());
    ClassDefinition definition = new ClassDefinition(a(PUBLIC, FINAL), makeClassName("MapFilter"), type(Object.class));
    definition.declareDefaultConstructor(a(PRIVATE));
    // Parameters of the generated static method: filter(Object state, Block block, BinaryFunctionInterface function)
    Parameter state = arg("state", Object.class);
    Parameter block = arg("block", Block.class);
    Parameter function = arg("function", BinaryFunctionInterface.class);
    MethodDefinition method = definition.declareMethod(a(PUBLIC, STATIC), "filter", type(Block.class), ImmutableList.of(state, block, function));
    BytecodeBlock body = method.getBody();
    Scope scope = method.getScope();
    // Local variables of the generated method
    Variable positionCount = scope.declareVariable(int.class, "positionCount");
    Variable position = scope.declareVariable(int.class, "position");
    Variable pageBuilder = scope.declareVariable(PageBuilder.class, "pageBuilder");
    Variable mapBlockBuilder = scope.declareVariable(BlockBuilder.class, "mapBlockBuilder");
    Variable singleMapBlockWriter = scope.declareVariable(BlockBuilder.class, "singleMapBlockWriter");
    Variable keyElement = scope.declareVariable(keyJavaType, "keyElement");
    Variable valueElement = scope.declareVariable(valueJavaType, "valueElement");
    // Boxed Boolean because the function result is checked against null before it is used
    Variable keep = scope.declareVariable(Boolean.class, "keep");
    // invoke block.getPositionCount()
    body.append(positionCount.set(block.invoke("getPositionCount", int.class)));
    // prepare the single map block builder: the state argument is cast to a PageBuilder,
    // which is reset when full; block builder 0 is used and a new block entry is begun
    body.append(pageBuilder.set(state.cast(PageBuilder.class)));
    body.append(new IfStatement().condition(pageBuilder.invoke("isFull", boolean.class)).ifTrue(pageBuilder.invoke("reset", void.class)));
    body.append(mapBlockBuilder.set(pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, constantInt(0))));
    body.append(singleMapBlockWriter.set(mapBlockBuilder.invoke("beginBlockEntry", BlockBuilder.class)));
    // Key loading: read the key at `position`, or use a null constant when the key type is UNKNOWN
    SqlTypeBytecodeExpression keySqlType = constantType(binder, keyType);
    BytecodeNode loadKeyElement;
    if (!keyType.equals(UNKNOWN)) {
        // key element must be non-null
        loadKeyElement = new BytecodeBlock().append(keyElement.set(keySqlType.getValue(block, position).cast(keyJavaType)));
    } else {
        loadKeyElement = new BytecodeBlock().append(keyElement.set(constantNull(keyJavaType)));
    }
    // Value loading: the value sits at `position + 1`; it becomes null when that position is null
    // or when the value type is UNKNOWN
    SqlTypeBytecodeExpression valueSqlType = constantType(binder, valueType);
    BytecodeNode loadValueElement;
    if (!valueType.equals(UNKNOWN)) {
        loadValueElement = new IfStatement().condition(block.invoke("isNull", boolean.class, add(position, constantInt(1)))).ifTrue(valueElement.set(constantNull(valueJavaType))).ifFalse(valueElement.set(valueSqlType.getValue(block, add(position, constantInt(1))).cast(valueJavaType)));
    } else {
        loadValueElement = new BytecodeBlock().append(valueElement.set(constantNull(valueJavaType)));
    }
    // Main loop: position runs from 0 while below positionCount in steps of 2 (one key/value pair per iteration).
    // Each iteration loads key and value, stores function.apply(key, value) cast to Boolean in `keep`,
    // and, if keep is non-null and true, appends the key (at position) and the value (at position + 1)
    // from the input block to singleMapBlockWriter.
    body.append(new ForLoop().initialize(position.set(constantInt(0))).condition(lessThan(position, positionCount)).update(incrementVariable(position, (byte) 2)).body(new BytecodeBlock().append(loadKeyElement).append(loadValueElement).append(keep.set(function.invoke("apply", Object.class, keyElement.cast(Object.class), valueElement.cast(Object.class)).cast(Boolean.class))).append(new IfStatement("if (keep != null && keep) ...").condition(and(notEqual(keep, constantNull(Boolean.class)), keep.cast(boolean.class))).ifTrue(new BytecodeBlock().append(keySqlType.invoke("appendTo", void.class, block, position, singleMapBlockWriter)).append(valueSqlType.invoke("appendTo", void.class, block, add(position, constantInt(1)), singleMapBlockWriter))))));
    // Close the entry (discarding the returned builder) and declare one position on the page builder
    body.append(mapBlockBuilder.invoke("closeEntry", BlockBuilder.class).pop());
    body.append(pageBuilder.invoke("declarePosition", void.class));
    // Return the object at the last position (getPositionCount() - 1) of the map block builder
    body.append(constantType(binder, mapType).invoke("getObject", Object.class, mapBlockBuilder.cast(Block.class), subtract(mapBlockBuilder.invoke("getPositionCount", int.class), constantInt(1))).ret());
    // Define the class in MapFilterFunction's class loader with the collected call-site bindings
    Class<?> generatedClass = defineClass(definition, Object.class, binder.getBindings(), MapFilterFunction.class.getClassLoader());
    return methodHandle(generatedClass, "filter", Object.class, Block.class, BinaryFunctionInterface.class);
}
Also used : Signature.typeVariable(io.trino.metadata.Signature.typeVariable) Variable(io.airlift.bytecode.Variable) VariableInstruction.incrementVariable(io.airlift.bytecode.instruction.VariableInstruction.incrementVariable) ForLoop(io.airlift.bytecode.control.ForLoop) BytecodeBlock(io.airlift.bytecode.BytecodeBlock) ClassDefinition(io.airlift.bytecode.ClassDefinition) IfStatement(io.airlift.bytecode.control.IfStatement) SqlTypeBytecodeExpression.constantType(io.trino.sql.gen.SqlTypeBytecodeExpression.constantType) TypeSignature.mapType(io.trino.spi.type.TypeSignature.mapType) Type(io.trino.spi.type.Type) TypeSignature.functionType(io.trino.spi.type.TypeSignature.functionType) MapType(io.trino.spi.type.MapType) Scope(io.airlift.bytecode.Scope) MethodDefinition(io.airlift.bytecode.MethodDefinition) CallSiteBinder(io.trino.sql.gen.CallSiteBinder) Parameter(io.airlift.bytecode.Parameter) BytecodeBlock(io.airlift.bytecode.BytecodeBlock) Block(io.trino.spi.block.Block) BytecodeNode(io.airlift.bytecode.BytecodeNode) SqlTypeBytecodeExpression(io.trino.sql.gen.SqlTypeBytecodeExpression)

Aggregations

ClassDefinition (io.airlift.bytecode.ClassDefinition)57 MethodDefinition (io.airlift.bytecode.MethodDefinition)32 BytecodeBlock (io.airlift.bytecode.BytecodeBlock)31 Variable (io.airlift.bytecode.Variable)30 Parameter (io.airlift.bytecode.Parameter)29 Scope (io.airlift.bytecode.Scope)21 ImmutableList (com.google.common.collect.ImmutableList)16 FieldDefinition (io.airlift.bytecode.FieldDefinition)16 IfStatement (io.airlift.bytecode.control.IfStatement)16 CallSiteBinder (io.prestosql.sql.gen.CallSiteBinder)12 DynamicClassLoader (io.airlift.bytecode.DynamicClassLoader)11 CallSiteBinder (io.trino.sql.gen.CallSiteBinder)11 BytecodeNode (io.airlift.bytecode.BytecodeNode)10 List (java.util.List)10 ForLoop (io.airlift.bytecode.control.ForLoop)9 BytecodeExpression (io.airlift.bytecode.expression.BytecodeExpression)9 Block (io.prestosql.spi.block.Block)9 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)8 VariableInstruction.incrementVariable (io.airlift.bytecode.instruction.VariableInstruction.incrementVariable)8 BlockBuilder (io.prestosql.spi.block.BlockBuilder)8