Use of com.facebook.presto.bytecode.CallSiteBinder in the presto project by prestodb.
Class ArrayTransformFunction, method generateTransform.
/**
 * Generates a class with two static methods used to transform the elements of an array block.
 *
 * <ul>
 * <li>{@code createPageBuilder()} returns a new {@code PageBuilder} built from the type
 * parameters of {@code new ArrayType(outputType)}.</li>
 * <li>{@code transform(pageBuilder, block, function)} resets the page builder when it reports
 * being full, then for every position of {@code block} loads the element (null when the
 * position is null or the input type is {@code UNKNOWN}), passes it to
 * {@code function.apply}, and appends the result (or a null) to block builder 0. It finally
 * declares the positions on the page builder and returns the region of the block builder
 * holding the entries just appended.</li>
 * </ul>
 *
 * @param inputType element type of the input block
 * @param outputType element type produced by the function
 * @return the generated class
 */
private static Class<?> generateTransform(Type inputType, Type outputType) {
    CallSiteBinder binder = new CallSiteBinder();
    // Boxed types are used for the locals so that a null element can be represented.
    Class<?> boxedInputType = Primitives.wrap(inputType.getJavaType());
    Class<?> boxedOutputType = Primitives.wrap(outputType.getJavaType());

    ClassDefinition generatedClass = new ClassDefinition(
            a(PUBLIC, FINAL),
            makeClassName("ArrayTransform"),
            type(Object.class));
    generatedClass.declareDefaultConstructor(a(PRIVATE));

    // static PageBuilder createPageBuilder()
    MethodDefinition pageBuilderFactory = generatedClass.declareMethod(
            a(PUBLIC, STATIC),
            "createPageBuilder",
            type(PageBuilder.class));
    pageBuilderFactory.getBody().append(
            newInstance(
                    PageBuilder.class,
                    constantType(binder, new ArrayType(outputType)).invoke("getTypeParameters", List.class))
                    .ret());

    // static Block transform(PageBuilder pageBuilder, Block block, UnaryFunctionInterface function)
    Parameter pageBuilderArg = arg("pageBuilder", PageBuilder.class);
    Parameter blockArg = arg("block", Block.class);
    Parameter functionArg = arg("function", UnaryFunctionInterface.class);
    MethodDefinition transformMethod = generatedClass.declareMethod(
            a(PUBLIC, STATIC),
            "transform",
            type(Block.class),
            ImmutableList.of(pageBuilderArg, blockArg, functionArg));
    BytecodeBlock transformBody = transformMethod.getBody();
    Scope transformScope = transformMethod.getScope();

    // Declaration order is kept as-is because it determines the generated local variable layout.
    Variable count = transformScope.declareVariable(int.class, "positionCount");
    Variable index = transformScope.declareVariable(int.class, "position");
    Variable output = transformScope.declareVariable(BlockBuilder.class, "blockBuilder");
    Variable element = transformScope.declareVariable(boxedInputType, "inputElement");
    Variable transformed = transformScope.declareVariable(boxedOutputType, "outputElement");

    // positionCount = block.getPositionCount()
    transformBody.append(count.set(blockArg.invoke("getPositionCount", int.class)));

    // if (pageBuilder.isFull()) pageBuilder.reset()
    transformBody.append(new IfStatement()
            .condition(pageBuilderArg.invoke("isFull", boolean.class))
            .ifTrue(pageBuilderArg.invoke("reset", void.class)));

    // blockBuilder = pageBuilder.getBlockBuilder(0)
    transformBody.append(output.set(pageBuilderArg.invoke("getBlockBuilder", BlockBuilder.class, constantInt(0))));

    // Reading the current element: an UNKNOWN input type can only ever yield null.
    BytecodeNode readElement;
    if (inputType.equals(UNKNOWN)) {
        readElement = new BytecodeBlock().append(element.set(constantNull(boxedInputType)));
    }
    else {
        readElement = new IfStatement()
                .condition(blockArg.invoke("isNull", boolean.class, index))
                .ifTrue(element.set(constantNull(boxedInputType)))
                .ifFalse(element.set(constantType(binder, inputType).getValue(blockArg, index).cast(boxedInputType)));
    }

    // Writing the transformed element: an UNKNOWN output type is always written as null.
    BytecodeNode writeElement;
    if (outputType.equals(UNKNOWN)) {
        writeElement = new BytecodeBlock().append(output.invoke("appendNull", BlockBuilder.class).pop());
    }
    else {
        writeElement = new IfStatement()
                .condition(equal(transformed, constantNull(boxedOutputType)))
                .ifTrue(output.invoke("appendNull", BlockBuilder.class).pop())
                .ifFalse(constantType(binder, outputType).writeValue(output, transformed.cast(outputType.getJavaType())));
    }

    // Per-position work: load, apply the function, write.
    BytecodeBlock loopBody = new BytecodeBlock()
            .append(readElement)
            .append(transformed.set(
                    functionArg.invoke("apply", Object.class, element.cast(Object.class)).cast(boxedOutputType)))
            .append(writeElement);

    // for (position = 0; position < positionCount; position++) { ... }
    ForLoop positionLoop = new ForLoop()
            .initialize(index.set(constantInt(0)))
            .condition(lessThan(index, count))
            .update(incrementVariable(index, (byte) 1))
            .body(loopBody);
    transformBody.append(positionLoop);

    transformBody.append(pageBuilderArg.invoke("declarePositions", void.class, count));

    // return blockBuilder.getRegion(blockBuilder.getPositionCount() - positionCount, positionCount)
    transformBody.append(output.invoke(
            "getRegion",
            Block.class,
            subtract(output.invoke("getPositionCount", int.class), count),
            count)
            .ret());

    return defineClass(generatedClass, Object.class, binder.getBindings(), ArrayTransformFunction.class.getClassLoader());
}
Use of com.facebook.presto.bytecode.CallSiteBinder in the presto project by prestodb.
Class PageFunctionCompiler, method compileProjectionInternal.
/**
 * Compiles the given projections into a generated {@code Work} class and returns a supplier of
 * {@code PageProjection} instances backed by that class.
 *
 * @param sqlFunctionProperties passed through to the work-class definition
 * @param sessionFunctions passed through to the work-class definition
 * @param projections non-empty list whose elements must each be a {@code CallExpression} or a
 *        {@code SpecialFormExpression}
 * @param isOptimizeCommonSubExpression passed through to the work-class definition
 * @param classNameSuffix passed through to the work-class definition
 * @return a supplier that creates a new {@code GeneratedPageProjection} on every call
 * @throws NullPointerException if {@code projections} is null
 * @throws IllegalArgumentException if {@code projections} is empty or contains an unsupported expression
 * @throws PrestoException with {@code COMPILER_ERROR} if defining the class fails with a
 *         non-Presto exception; Presto exceptions are rethrown unchanged
 */
private Supplier<PageProjection> compileProjectionInternal(SqlFunctionProperties sqlFunctionProperties, Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions, List<RowExpression> projections, boolean isOptimizeCommonSubExpression, Optional<String> classNameSuffix) {
    requireNonNull(projections, "projections is null");
    checkArgument(!projections.isEmpty());
    for (RowExpression projection : projections) {
        checkArgument(projection instanceof CallExpression || projection instanceof SpecialFormExpression);
    }

    PageFieldsToInputParametersRewriter.Result result = rewritePageFieldsToInputParameters(projections);
    List<RowExpression> rewrittenExpression = result.getRewrittenExpressions();

    // Generate the Work class definition, then load it.
    CallSiteBinder callSiteBinder = new CallSiteBinder();
    ClassDefinition workDefinition = definePageProjectWorkClass(
            sqlFunctionProperties,
            sessionFunctions,
            rewrittenExpression,
            callSiteBinder,
            isOptimizeCommonSubExpression,
            classNameSuffix);

    Class<? extends Work> pageProjectionWorkClass;
    try {
        pageProjectionWorkClass = defineClass(workDefinition, Work.class, callSiteBinder.getBindings(), getClass().getClassLoader());
    }
    catch (Exception e) {
        // Presto exceptions already carry an error code; everything else becomes a compiler error.
        if (e instanceof PrestoException) {
            throw (PrestoException) e;
        }
        throw new PrestoException(COMPILER_ERROR, e);
    }

    return () -> {
        boolean deterministic = rewrittenExpression.stream().allMatch(determinismEvaluator::isDeterministic);
        return new GeneratedPageProjection(
                rewrittenExpression,
                deterministic,
                result.getInputChannels(),
                constructorMethodHandle(pageProjectionWorkClass, List.class, SqlFunctionProperties.class, Page.class, SelectedPositions.class));
    };
}
Use of com.facebook.presto.bytecode.CallSiteBinder in the presto project by prestodb.
Class PageFunctionCompiler, method defineFilterClass.
/**
 * Builds the definition of a generated class that implements {@code PageFilter} for the given
 * filter expression.
 *
 * <p>The generated class gets: methods for the lambdas found in the filter, optional common
 * sub-expression (CSE) fields and methods, the filter method, a {@code selectedPositions}
 * field of type {@code boolean[]}, the page-filter method, {@code isDeterministic()},
 * {@code getInputChannels()}, {@code toString()}, and a public constructor.
 *
 * @param sqlFunctionProperties passed to lambda generation and the expression compiler
 * @param sessionFunctions passed to lambda generation and the expression compiler
 * @param filter the filter expression; rewritten in terms of CSE variables when CSE applies
 * @param inputChannels value returned by the generated {@code getInputChannels()}
 * @param callSiteBinder binder used for constants referenced by the generated code
 * @param isOptimizeCommonSubExpression whether to look for common sub-expressions
 * @param classNameSuffix passed to {@code generateFilterClassName}
 * @return the class definition (not yet loaded)
 */
private ClassDefinition defineFilterClass(SqlFunctionProperties sqlFunctionProperties, Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions, RowExpression filter, InputChannels inputChannels, CallSiteBinder callSiteBinder, boolean isOptimizeCommonSubExpression, Optional<String> classNameSuffix) {
    ClassDefinition filterClass = new ClassDefinition(
            a(PUBLIC, FINAL),
            generateFilterClassName(classNameSuffix),
            type(Object.class),
            type(PageFilter.class));
    CachedInstanceBinder instanceBinder = new CachedInstanceBinder(filterClass, callSiteBinder);
    Map<LambdaDefinitionExpression, CompiledLambda> lambdas = generateMethodsForLambda(
            filterClass,
            callSiteBinder,
            instanceBinder,
            filter,
            metadata,
            sqlFunctionProperties,
            sessionFunctions);

    // Start without any CSE fields; both are replaced below when common sub-expressions are found.
    Map<VariableReferenceExpression, CommonSubExpressionFields> cseFields = ImmutableMap.of();
    RowExpressionCompiler compiler = new RowExpressionCompiler(
            filterClass,
            callSiteBinder,
            instanceBinder,
            new FieldAndVariableReferenceCompiler(callSiteBinder, cseFields),
            metadata,
            sqlFunctionProperties,
            sessionFunctions,
            lambdas);

    if (isOptimizeCommonSubExpression) {
        Map<Integer, Map<RowExpression, VariableReferenceExpression>> cseByLevel = collectCSEByLevel(filter);
        if (!cseByLevel.isEmpty()) {
            cseFields = declareCommonSubExpressionFields(filterClass, cseByLevel);
            // The compiler must be rebuilt so that variable references resolve to the new CSE fields.
            compiler = new RowExpressionCompiler(
                    filterClass,
                    callSiteBinder,
                    instanceBinder,
                    new FieldAndVariableReferenceCompiler(callSiteBinder, cseFields),
                    metadata,
                    sqlFunctionProperties,
                    sessionFunctions,
                    lambdas);
            generateCommonSubExpressionMethods(filterClass, compiler, cseByLevel, cseFields);

            // Flatten all levels into one expression -> variable map and rewrite the filter with it.
            Map<RowExpression, VariableReferenceExpression> commonSubExpressions = cseByLevel.values().stream()
                    .flatMap(level -> level.entrySet().stream())
                    .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
            filter = rewriteExpressionWithCSE(filter, commonSubExpressions);

            if (log.isDebugEnabled()) {
                log.debug("Extracted %d common sub-expressions", commonSubExpressions.size());
                commonSubExpressions.forEach((expression, variable) -> log.debug("\t%s = %s", variable, expression));
                log.debug("Rewrote filter: %s", filter);
            }
        }
    }

    generateFilterMethod(filterClass, compiler, filter, cseFields);

    FieldDefinition selectedPositions = filterClass.declareField(a(PRIVATE), "selectedPositions", boolean[].class);
    generatePageFilterMethod(filterClass, selectedPositions);

    // boolean isDeterministic(): constant computed now from the (possibly rewritten) filter
    MethodDefinition isDeterministicMethod = filterClass.declareMethod(a(PUBLIC), "isDeterministic", type(boolean.class));
    isDeterministicMethod.getBody()
            .append(constantBoolean(determinismEvaluator.isDeterministic(filter)))
            .retBoolean();

    // InputChannels getInputChannels(): returns the bound inputChannels instance
    MethodDefinition getInputChannelsMethod = filterClass.declareMethod(a(PUBLIC), "getInputChannels", type(InputChannels.class));
    getInputChannelsMethod.getBody()
            .append(invoke(callSiteBinder.bind(inputChannels, InputChannels.class), "getInputChannels"))
            .retObject();

    // String toString(): returns a string precomputed from the class name and the filter
    String toStringResult = toStringHelper(filterClass.getType().getJavaClassName())
            .add("filter", filter)
            .toString();
    MethodDefinition toStringMethod = filterClass.declareMethod(a(PUBLIC), "toString", type(String.class));
    toStringMethod.getBody()
            .append(invoke(callSiteBinder.bind(toStringResult, String.class), "toString"))
            .retObject();

    // Constructor: super(), empty selectedPositions array, CSE fields, cached instances.
    MethodDefinition constructor = filterClass.declareConstructor(a(PUBLIC));
    BytecodeBlock constructorBody = constructor.getBody();
    Variable self = constructor.getThis();
    constructorBody.comment("super();")
            .append(self)
            .invokeConstructor(Object.class)
            .append(self.setField(selectedPositions, newArray(type(boolean[].class), 0)));
    initializeCommonSubExpressionFields(cseFields.values(), self, constructorBody);
    instanceBinder.generateInitializations(self, constructorBody);
    constructorBody.ret();

    return filterClass;
}
Use of com.facebook.presto.bytecode.CallSiteBinder in the presto project by prestodb.
Class VarArgsToArrayAdapterGenerator, method generateVarArgsToArrayAdapter.
/**
 * Generates an adapter that accepts {@code argsLength} separate arguments of type
 * {@code javaType}, packs them into an array, and passes that array to {@code function}.
 *
 * <p>The adapter is a static method named {@code varArgsToArray} on a freshly generated class;
 * it is returned wrapped in a {@code VarArgMethodHandle}.
 *
 * @param returnType return type of the adapter; must equal the return type of {@code function}
 * @param javaType type of each individual argument
 * @param argsLength number of individual arguments the adapter takes; must not be negative
 * @param function method handle taking exactly one parameter, the array form of {@code javaType}
 * @return a handle to the generated adapter method
 * @throws NullPointerException if {@code returnType}, {@code javaType} or {@code function} is null
 * @throws IllegalArgumentException if {@code argsLength} is negative or the type of
 *         {@code function} does not match {@code returnType} / the array type
 * @throws PrestoException with {@code NOT_SUPPORTED} (via {@code checkCondition}) if the
 *         arguments would need more than 254 parameter slots
 */
public static VarArgMethodHandle generateVarArgsToArrayAdapter(Class<?> returnType, Class<?> javaType, int argsLength, MethodHandle function) {
    requireNonNull(returnType, "returnType is null");
    requireNonNull(javaType, "javaType is null");
    requireNonNull(function, "function is null");
    // Reject a negative count up front, before any class is generated and defined.
    checkArgument(argsLength >= 0, "argsLength is negative: %s", argsLength);
    // The 254 limit is on argument slots, not on the argument count: long and double
    // arguments occupy two slots each, so only half as many of them fit.
    int slotsPerArgument = (javaType == long.class || javaType == double.class) ? 2 : 1;
    checkCondition((long) argsLength * slotsPerArgument <= 254, NOT_SUPPORTED, "Too many arguments for vararg function");

    MethodType methodType = function.type();
    Class<?> javaArrayType = toArrayClass(javaType);
    checkArgument(methodType.returnType() == returnType, "returnType does not match");
    checkArgument(methodType.parameterList().equals(ImmutableList.of(javaArrayType)), "parameter types do not match");

    CallSiteBinder callSiteBinder = new CallSiteBinder();
    ClassDefinition classDefinition = new ClassDefinition(a(PUBLIC, FINAL), makeClassName("VarArgsToListAdapter"), type(Object.class));
    classDefinition.declareDefaultConstructor(a(PRIVATE));

    // Adapter signature: static returnType varArgsToArray(javaType input_0, ..., javaType input_{n-1})
    ImmutableList.Builder<Parameter> parameterListBuilder = ImmutableList.builder();
    for (int i = 0; i < argsLength; i++) {
        parameterListBuilder.add(arg("input_" + i, javaType));
    }
    ImmutableList<Parameter> parameterList = parameterListBuilder.build();

    MethodDefinition methodDefinition = classDefinition.declareMethod(a(PUBLIC, STATIC), "varArgsToArray", type(returnType), parameterList);
    BytecodeBlock body = methodDefinition.getBody();
    // return function.invokeExact(new javaType[] {input_0, ..., input_{n-1}})
    body.append(loadConstant(callSiteBinder, function, MethodHandle.class)
            .invoke("invokeExact", returnType, BytecodeExpressions.newArray(ParameterizedType.type(javaArrayType), parameterList))
            .ret());

    // Define the class and look up the adapter method by its parameter types.
    Class<?> generatedClass = defineClass(classDefinition, Object.class, callSiteBinder.getBindings(), VarArgsToArrayAdapterGenerator.class.getClassLoader());
    Class<?>[] parameterTypes = nCopies(argsLength, javaType).toArray(new Class<?>[argsLength]);
    return new VarArgMethodHandle(Reflection.methodHandle(generatedClass, "varArgsToArray", parameterTypes));
}
Use of com.facebook.presto.bytecode.CallSiteBinder in the presto project by prestodb.
Class LambdaBytecodeGenerator, method generateMethodsForLambda.
/**
 * Pre-generates a method on {@code containerClassDefinition} for every lambda found in
 * {@code expressions} that is not already in {@code existingCompiledLambdas}.
 *
 * <p>Generated methods are named {@code methodNamePrefix + "lambda_" + n}, where numbering
 * starts at {@code existingCompiledLambdas.size()}. Each lambda is generated with a snapshot
 * of the lambdas compiled before it in this call.
 *
 * @return map from each newly compiled lambda expression to its compiled form
 */
private static Map<LambdaDefinitionExpression, CompiledLambda> generateMethodsForLambda(ClassDefinition containerClassDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, List<RowExpression> expressions, Metadata metadata, SqlFunctionProperties sqlFunctionProperties, Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions, String methodNamePrefix, Set<LambdaDefinitionExpression> existingCompiledLambdas) {
    // Distinct lambdas that still need compiling.
    Set<LambdaDefinitionExpression> pendingLambdas = expressions.stream()
            .flatMap(expression -> LambdaExpressionExtractor.extractLambdaExpressions(expression).stream())
            .filter(candidate -> !existingCompiledLambdas.contains(candidate))
            .collect(toImmutableSet());

    ImmutableMap.Builder<LambdaDefinitionExpression, CompiledLambda> compiled = ImmutableMap.builder();
    int nextIndex = existingCompiledLambdas.size();
    for (LambdaDefinitionExpression pending : pendingLambdas) {
        String methodName = methodNamePrefix + "lambda_" + nextIndex;
        nextIndex++;
        // compiled.build() hands over the lambdas compiled so far in this call.
        compiled.put(pending, LambdaBytecodeGenerator.preGenerateLambdaExpression(
                pending,
                methodName,
                containerClassDefinition,
                compiled.build(),
                callSiteBinder,
                cachedInstanceBinder,
                metadata,
                sqlFunctionProperties,
                sessionFunctions));
    }
    return compiled.build();
}
Aggregations