Use of io.airlift.bytecode.FieldDefinition in the Trino project (trinodb).
Class AccumulatorCompiler, method generateAccumulatorClass.
/**
 * Generates and defines a public final class implementing {@code accumulatorInterface} for the
 * aggregation described by {@code metadata}, and returns its public constructor that takes a single
 * {@link List} argument.
 *
 * <p>For every accumulator state descriptor the generated class gets three private final fields:
 * {@code stateSerializer_i}, {@code stateFactory_i} and {@code state_i}. The state field is typed
 * {@code GroupedAccumulatorState} for the grouped variant and {@code AccumulatorState} otherwise.
 * It also gets one {@code lambdaProvider_i} {@link Supplier} field per lambda interface.
 *
 * @param boundSignature signature whose name prefixes the generated class name
 * @param accumulatorInterface interface to implement; {@code GroupedAccumulator.class} selects the grouped variant
 * @param metadata supplies the state descriptors, lambda interfaces and input/combine/output functions
 * @param argumentNullable per-argument nullability, forwarded to {@code generateAddInput}
 * @param classLoader loader passed to {@code defineClass}
 * @return the {@code (List)} constructor of the generated class
 * @throws RuntimeException wrapping {@link NoSuchMethodException} if that constructor is absent
 */
private static <T> Constructor<? extends T> generateAccumulatorClass(BoundSignature boundSignature, Class<T> accumulatorInterface, AggregationMetadata metadata, List<Boolean> argumentNullable, DynamicClassLoader classLoader)
{
    boolean grouped = accumulatorInterface == GroupedAccumulator.class;

    String className = boundSignature.getName() + accumulatorInterface.getSimpleName();
    ClassDefinition classDefinition = new ClassDefinition(a(PUBLIC, FINAL), makeClassName(className), type(Object.class), type(accumulatorInterface));
    CallSiteBinder binder = new CallSiteBinder();

    // Declare one serializer / factory / state field triple per state descriptor, in descriptor order.
    Class<?> stateType = grouped ? GroupedAccumulatorState.class : AccumulatorState.class;
    List<StateFieldAndDescriptor> stateTriples = new ArrayList<>();
    int index = 0;
    for (AccumulatorStateDescriptor<?> descriptor : metadata.getAccumulatorStateDescriptors()) {
        FieldDefinition serializerField = classDefinition.declareField(a(PRIVATE, FINAL), "stateSerializer_" + index, AccumulatorStateSerializer.class);
        FieldDefinition factoryField = classDefinition.declareField(a(PRIVATE, FINAL), "stateFactory_" + index, AccumulatorStateFactory.class);
        FieldDefinition stateField = classDefinition.declareField(a(PRIVATE, FINAL), "state_" + index, stateType);
        stateTriples.add(new StateFieldAndDescriptor(descriptor, serializerField, factoryField, stateField));
        index++;
    }
    List<FieldDefinition> stateFields = stateTriples.stream()
            .map(StateFieldAndDescriptor::getStateField)
            .collect(toImmutableList());

    // One Supplier field per lambda argument of the aggregation.
    int lambdaCount = metadata.getLambdaInterfaces().size();
    List<FieldDefinition> lambdaProviderFields = new ArrayList<>(lambdaCount);
    for (int lambda = 0; lambda < lambdaCount; lambda++) {
        lambdaProviderFields.add(classDefinition.declareField(a(PRIVATE, FINAL), "lambdaProvider_" + lambda, Supplier.class));
    }

    // Constructors
    generateConstructor(classDefinition, stateTriples, lambdaProviderFields, binder, grouped);
    generateCopyConstructor(classDefinition, stateTriples, lambdaProviderFields);

    // Methods common to both variants
    generateCopy(classDefinition, Accumulator.class);
    generateAddInput(classDefinition, stateFields, argumentNullable, lambdaProviderFields, metadata.getInputFunction(), binder, grouped);
    generateGetEstimatedSize(classDefinition, stateFields);
    generateAddIntermediateAsCombine(classDefinition, stateTriples, lambdaProviderFields, metadata.getCombineFunction(), binder, grouped);

    // Variant-specific methods: evaluate intermediate, then evaluate final (then prepareFinal when grouped)
    if (grouped) {
        generateGroupedEvaluateIntermediate(classDefinition, stateTriples, true);
        generateGroupedEvaluateFinal(classDefinition, stateFields, metadata.getOutputFunction(), binder);
        generatePrepareFinal(classDefinition);
    }
    else {
        generateEvaluateIntermediate(classDefinition, stateTriples, true);
        generateEvaluateFinal(classDefinition, stateFields, metadata.getOutputFunction(), binder);
    }

    Class<? extends T> generatedClass = defineClass(classDefinition, accumulatorInterface, binder.getBindings(), classLoader);
    try {
        return generatedClass.getConstructor(List.class);
    }
    catch (NoSuchMethodException e) {
        throw new RuntimeException(e);
    }
}
Use of io.airlift.bytecode.FieldDefinition in the Trino project (trinodb).
Class AccumulatorCompiler, method initializeStateFields.
/**
 * Appends to {@code method}'s body, for every state descriptor in order, the following bytecode:
 * <ol>
 *   <li>store the descriptor's serializer (bound as a constant through {@code callSiteBinder}) into
 *       the serializer field;</li>
 *   <li>store the descriptor's factory (bound the same way) into the factory field;</li>
 *   <li>create a state via that factory, cast it to the state field's declared type, and store it
 *       into the state field.</li>
 * </ol>
 * Each store is followed by the node returned from {@code generateRequireNotNull} for that field
 * (presumably a null check; confirm against its definition).
 *
 * @param method the method (typically a constructor) whose body receives the initialization code
 * @param stateFieldAndDescriptors state descriptors paired with their generated fields
 * @param callSiteBinder binder used to load the serializer and factory constants
 * @param grouped when {@code true} states are created with {@code createGroupedState}, otherwise with {@code createSingleState}
 */
private static void initializeStateFields(MethodDefinition method, List<StateFieldAndDescriptor> stateFieldAndDescriptors, CallSiteBinder callSiteBinder, boolean grouped)
{
    BytecodeBlock block = method.getBody();
    Variable self = method.getThis();
    String createMethodName = grouped ? "createGroupedState" : "createSingleState";

    for (StateFieldAndDescriptor entry : stateFieldAndDescriptors) {
        AccumulatorStateDescriptor<?> descriptor = entry.getAccumulatorStateDescriptor();
        FieldDefinition serializerField = entry.getStateSerializerField();
        FieldDefinition factoryField = entry.getStateFactoryField();
        FieldDefinition stateField = entry.getStateField();

        // serializer
        block.append(self.setField(serializerField, loadConstant(callSiteBinder, descriptor.getSerializer(), AccumulatorStateSerializer.class)))
                .append(generateRequireNotNull(self, serializerField));

        // factory
        block.append(self.setField(factoryField, loadConstant(callSiteBinder, descriptor.getFactory(), AccumulatorStateFactory.class)))
                .append(generateRequireNotNull(self, factoryField));

        // state instance, created through the factory stored just above
        BytecodeExpression newState = self.getField(factoryField).invoke(createMethodName, AccumulatorState.class);
        block.append(self.setField(stateField, newState.cast(stateField.getType())))
                .append(generateRequireNotNull(self, stateField));
    }
}
Use of io.airlift.bytecode.FieldDefinition in the Trino project (trinodb).
Class AccumulatorCompiler, method generateGetEstimatedSize.
/**
 * Declares {@code public long getEstimatedSize()} on the generated class.
 *
 * <p>The generated body initializes a {@code long} local named {@code estimatedSize} to 0. It then
 * adds {@code getEstimatedSize()} of every state field, in list order, and returns the total.
 *
 * @param definition class being generated
 * @param stateFields state fields whose estimated sizes are summed
 */
private static void generateGetEstimatedSize(ClassDefinition definition, List<FieldDefinition> stateFields)
{
    MethodDefinition getEstimatedSize = definition.declareMethod(a(PUBLIC), "getEstimatedSize", type(long.class));
    Variable total = getEstimatedSize.getScope().declareVariable(long.class, "estimatedSize");
    Variable self = getEstimatedSize.getThis();
    BytecodeBlock body = getEstimatedSize.getBody();

    body.append(total.set(constantLong(0L)));
    stateFields.forEach(field -> body.append(
            total.set(BytecodeExpressions.add(total, self.getField(field).invoke("getEstimatedSize", long.class)))));
    body.append(total.ret());
}
Use of io.airlift.bytecode.FieldDefinition in the Trino project (trinodb).
Class AccumulatorCompiler, method generateEvaluateFinal.
/**
 * Declares {@code public void evaluateFinal(BlockBuilder out)} on the generated class.
 *
 * <p>The generated body pushes every state field (in list order) followed by {@code out}. It then
 * invokes {@code outputFunction}, bound through {@code callSiteBinder} under the name
 * {@code "output"}, and returns.
 *
 * @param definition class being generated
 * @param stateFields state fields passed, in order, as the leading output-function arguments
 * @param outputFunction function that writes the final result to {@code out}
 * @param callSiteBinder binder that receives {@code outputFunction}
 */
private static void generateEvaluateFinal(ClassDefinition definition, List<FieldDefinition> stateFields, MethodHandle outputFunction, CallSiteBinder callSiteBinder)
{
    Parameter outParameter = arg("out", BlockBuilder.class);
    MethodDefinition evaluateFinal = definition.declareMethod(a(PUBLIC), "evaluateFinal", type(void.class), outParameter);
    Variable self = evaluateFinal.getThis();
    BytecodeBlock body = evaluateFinal.getBody();

    body.comment("output(state_0, state_1, ..., out)");
    for (FieldDefinition stateField : stateFields) {
        body.append(self.getField(stateField));
    }
    body.append(outParameter)
            .append(invoke(callSiteBinder.bind(outputFunction), "output"))
            .ret();
}
Use of io.airlift.bytecode.FieldDefinition in the Trino project (trinodb).
Class AccumulatorCompiler, method generateAddIntermediateAsCombine.
/**
 * Declares the add-intermediate method (via {@code declareAddIntermediate}) on the generated class.
 *
 * <p>If {@code combineFunction} is empty, the generated method just throws
 * {@code UnsupportedOperationException("Aggregation is not decomposable")}.
 *
 * <p>Otherwise the generated code first creates one scratch state per accumulator state with
 * {@code createSingleState()}. Then, for each position visited by
 * {@code generateBlockNonNullPositionForLoop}, it:
 * <ol>
 *   <li>deserializes each state's intermediate value at that position into its scratch state;</li>
 *   <li>calls {@code combine(state_0, ..., scratchState_0, ..., lambda_0, ...)}.</li>
 * </ol>
 *
 * <p>With one state, the {@code block} variable is read directly. With several states, {@code block}
 * is split into per-state column blocks through {@code ColumnarRow}.
 *
 * <p>In the grouped variant the generated code also:
 * <ul>
 *   <li>calls {@code generateEnsureCapacity} before the loop;</li>
 *   <li>sets each state's group id from {@code groupIdsBlock} before pushing it;</li>
 *   <li>skips positions whose group id is null.</li>
 * </ul>
 */
private static void generateAddIntermediateAsCombine(ClassDefinition definition, List<StateFieldAndDescriptor> stateFieldAndDescriptors, List<FieldDefinition> lambdaProviderFields, Optional<MethodHandle> combineFunction, CallSiteBinder callSiteBinder, boolean grouped)
{
    MethodDefinition method = declareAddIntermediate(definition, grouped);
    if (combineFunction.isEmpty()) {
        method.getBody()
                .append(newInstance(UnsupportedOperationException.class, constantString("Aggregation is not decomposable")))
                .throwObject();
        return;
    }

    Scope scope = method.getScope();
    BytecodeBlock body = method.getBody();
    Variable self = method.getThis();
    int stateCount = stateFieldAndDescriptors.size();

    // Scratch states receive the deserialized intermediate value of each accumulator state.
    List<Variable> scratchStates = new ArrayList<>(stateCount);
    for (int i = 0; i < stateCount; i++) {
        scratchStates.add(scope.declareVariable(AccumulatorState.class, "scratchState_" + i));
    }

    List<Variable> stateBlocks;
    if (stateCount == 1) {
        stateBlocks = ImmutableList.of(scope.getVariable("block"));
    }
    else {
        // ColumnarRow exposes one column block per state, which
        // 1. handles single and multiple states uniformly, and
        // 2. avoids constructing a SingleRowBlock for each group
        Variable columnarRow = scope.declareVariable(ColumnarRow.class, "columnarRow");
        body.append(columnarRow.set(invokeStatic(ColumnarRow.class, "toColumnarRow", ColumnarRow.class, scope.getVariable("block"))));
        stateBlocks = new ArrayList<>(stateCount);
        for (int i = 0; i < stateCount; i++) {
            Variable columnBlock = scope.declareVariable(Block.class, "columnBlock_" + i);
            body.append(columnBlock.set(columnarRow.invoke("getField", Block.class, constantInt(i))));
            stateBlocks.add(columnBlock);
        }
    }

    Variable position = scope.declareVariable(int.class, "position");

    // scratchState_i = stateFactory_i.createSingleState()
    for (int i = 0; i < stateCount; i++) {
        Variable scratchState = scratchStates.get(i);
        FieldDefinition factoryField = stateFieldAndDescriptors.get(i).getStateFactoryField();
        body.comment(format("scratchState_%s = stateFactory[%s].createSingleState();", i, i))
                .append(self.getField(factoryField))
                .invokeInterface(AccumulatorStateFactory.class, "createSingleState", AccumulatorState.class)
                .checkCast(scratchState.getType())
                .putVariable(scratchState);
    }

    List<FieldDefinition> stateFields = stateFieldAndDescriptors.stream()
            .map(StateFieldAndDescriptor::getStateField)
            .collect(toImmutableList());
    if (grouped) {
        generateEnsureCapacity(scope, stateFields, body);
    }

    // Looked up only for the grouped variant; unused (null) otherwise.
    Variable groupIdsBlock = grouped ? scope.getVariable("groupIdsBlock") : null;

    BytecodeBlock loopBody = new BytecodeBlock();
    loopBody.comment("combine(state_0, state_1, ... scratchState_0, scratchState_1, ... lambda_0, lambda_1, ...)");

    // Combine arguments are pushed in order: states, then deserialized scratch states, then lambdas.
    for (FieldDefinition stateField : stateFields) {
        if (grouped) {
            // Point the state at this position's group before pushing it (setGroupId leaves nothing on the stack).
            loopBody.append(self.getField(stateField).invoke("setGroupId", void.class, groupIdsBlock.invoke("getGroupId", long.class, position)));
        }
        loopBody.append(self.getField(stateField));
    }
    for (int i = 0; i < stateCount; i++) {
        Variable scratchState = scratchStates.get(i);
        FieldDefinition serializerField = stateFieldAndDescriptors.get(i).getStateSerializerField();
        loopBody.append(self.getField(serializerField).invoke("deserialize", void.class, stateBlocks.get(i), position, scratchState.cast(AccumulatorState.class)))
                .append(scratchState);
    }
    for (FieldDefinition lambdaProviderField : lambdaProviderFields) {
        loopBody.append(self.getField(lambdaProviderField).invoke("get", Object.class));
    }
    loopBody.append(invoke(callSiteBinder.bind(combineFunction.orElseThrow()), "combine"));

    if (grouped) {
        // skip rows with null group id
        IfStatement skipNullGroup = new IfStatement("if (!groupIdsBlock.isNull(position))")
                .condition(not(groupIdsBlock.invoke("isNull", boolean.class, position)))
                .ifTrue(loopBody);
        loopBody = new BytecodeBlock().append(skipNullGroup);
    }

    body.append(generateBlockNonNullPositionForLoop(scope, position, loopBody))
            .ret();
}
Aggregations