Use of io.prestosql.spi.function.FunctionHandle in project hetu-core by openLooKeng.
The class CubeOptimizer, method rewriteAggregationNode.
/**
 * Re-aggregates the pre-aggregated cube columns on top of {@code inputPlanNode}.
 * <p>
 * Every aggregation column of {@code cubeRewriteResult} becomes an aggregation over its cube scan symbol
 * (a COUNT stored in the cube is re-aggregated with SUM; any other stored function is reapplied), grouped by
 * the rewritten grouping keys of the original aggregation. When AVG columns are present, a projection is
 * stacked on top of the new aggregation node.
 *
 * @param cubeRewriteResult maps original aggregation outputs to the cube scan symbols that back them
 * @param inputPlanNode plan that produces the cube scan symbols
 * @return the new aggregation node, or a projection over it when AVG columns are present
 */
private PlanNode rewriteAggregationNode(CubeRewriteResult cubeRewriteResult, PlanNode inputPlanNode)
{
    TypeProvider typeProvider = context.getSymbolAllocator().getTypes();

    // Translate the original grouping keys into the symbols of the rewritten (cube) plan.
    List<Symbol> groupings = aggregationNode.getGroupingKeys().stream()
            .map(Symbol::getName)
            .map(columnRewritesMap::get)
            .map(optimizedPlanMappings::get)
            .collect(Collectors.toList());

    // cube scan symbol -> output symbol of the new aggregation; used later to find the SUM/COUNT inputs of AVG
    Map<Symbol, Symbol> scanToAggregationOutput = new HashMap<>();
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> rollups = ImmutableMap.builder();
    for (CubeRewriteResult.AggregatorSource source : cubeRewriteResult.getAggregationColumns()) {
        AggregationNode.Aggregation rollup = cubeRollupAggregation(cubeRewriteResult, source);
        scanToAggregationOutput.put(source.getScanSymbol(), source.getOriginalAggSymbol());
        rollups.put(source.getOriginalAggSymbol(), rollup);
    }

    AggregationNode aggNode = new AggregationNode(
            context.getIdAllocator().getNextId(),
            inputPlanNode,
            rollups.build(),
            singleGroupingSet(groupings),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());

    // No AVG columns: the aggregation node is the whole rewrite.
    if (cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
        return aggNode;
    }

    Map<Symbol, Expression> projections = new HashMap<>();
    if (cubeRewriteResult.getComputeAvgDividingSumByCount()) {
        // Grouping keys pass through unchanged.
        for (Symbol groupingKey : aggNode.getGroupingKeys()) {
            projections.put(groupingKey, toSymbolReference(groupingKey));
        }
        // Aggregation outputs pass through only when they are values of originalAggregationsMap.
        aggNode.getAggregations().keySet().stream()
                .filter(originalAggregationsMap::containsValue)
                .forEach(aggSymbol -> projections.put(aggSymbol, toSymbolReference(aggSymbol)));
        // AVG = CAST(SUM AS avgType) / CAST(COUNT AS avgType)
        for (CubeRewriteResult.AverageAggregatorSource avgSource : cubeRewriteResult.getAvgAggregationColumns()) {
            Symbol sumOutput = scanToAggregationOutput.get(avgSource.getSum());
            Symbol countOutput = scanToAggregationOutput.get(avgSource.getCount());
            String avgTypeName = typeProvider.get(avgSource.getOriginalAggSymbol()).getTypeSignature().toString();
            projections.put(avgSource.getOriginalAggSymbol(), new ArithmeticBinaryExpression(
                    ArithmeticBinaryExpression.Operator.DIVIDE,
                    new Cast(toSymbolReference(sumOutput), avgTypeName),
                    new Cast(toSymbolReference(countOutput), avgTypeName)));
        }
    }
    else {
        // NOTE(review): this branch maps each aggregation output to its cube *scan* symbol and does not carry
        // the grouping keys through, although the projection's source is aggNode (whose outputs are the
        // grouping keys and the original aggregation symbols) — confirm this is intended.
        for (CubeRewriteResult.AggregatorSource source : cubeRewriteResult.getAggregationColumns()) {
            projections.put(source.getOriginalAggSymbol(), toSymbolReference(source.getScanSymbol()));
        }
    }
    return new ProjectNode(context.getIdAllocator().getNextId(), aggNode, toRowExpressionAssignments(projections));
}

/**
 * Builds the aggregation that re-aggregates one pre-aggregated cube column.
 *
 * @param cubeRewriteResult supplies the symbol metadata and the cube table scan
 * @param source pairs the original aggregation symbol with the cube scan symbol that stores it
 * @return an aggregation over the scan symbol whose result type is that of the original aggregation
 * @throws ColumnNotFoundException if the cube metadata has no aggregation signature for the scanned column
 */
private AggregationNode.Aggregation cubeRollupAggregation(CubeRewriteResult cubeRewriteResult, CubeRewriteResult.AggregatorSource source)
{
    Type outputType = cubeRewriteResult.getSymbolMetadataMap().get(source.getOriginalAggSymbol()).getType();
    TypeSignature argumentSignature = outputType.getTypeSignature();
    ColumnHandle cubeColumnHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(source.getScanSymbol());
    ColumnMetadata cubeColumn = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColumnHandle);
    AggregationSignature storedSignature = cubeMetadata.getAggregationSignature(cubeColumn.getName())
            .orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName("", ""), cubeColumnHandle.getColumnName()));
    // A COUNT stored in the cube is re-aggregated with SUM; other functions are applied again as-is.
    String functionName = COUNT.getName().equals(storedSignature.getFunction()) ? SUM.getName() : storedSignature.getFunction();
    SymbolReference scanReference = toSymbolReference(source.getScanSymbol());
    FunctionHandle functionHandle = metadata.getFunctionAndTypeManager()
            .lookupFunction(functionName, TypeSignatureProvider.fromTypeSignatures(argumentSignature));
    return new AggregationNode.Aggregation(
            new CallExpression(functionName, functionHandle, outputType, ImmutableList.of(castToRowExpression(scanReference))),
            ImmutableList.of(castToRowExpression(scanReference)),
            false,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());
}

/**
 * Converts symbol-to-expression projections into {@link Assignments}, wrapping each expression with
 * {@code castToRowExpression}.
 *
 * @param projections output symbol to the expression that computes it
 * @return the equivalent assignments
 */
private Assignments toRowExpressionAssignments(Map<Symbol, Expression> projections)
{
    return new Assignments(projections.entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue()))));
}
Use of io.prestosql.spi.function.FunctionHandle in project hetu-core by openLooKeng.
The class ArrayToArrayCast, method specialize.
/**
 * Specializes the array-to-array cast for the bound type variables {@code F} and {@code T}.
 * <p>
 * Resolves the CAST from {@code F} to {@code T} and generates a class whose {@code castArray} method applies it.
 * NOTE(review): F and T are presumably the element types of the source and target arrays — confirm against the
 * function signature.
 *
 * @param boundVariables bindings for the type variables {@code F} and {@code T}
 * @param arity number of arguments; must be 1
 * @param functionAndTypeManager used to resolve the cast and its implementation
 * @return an implementation backed by {@code castArray(ConnectorSession, Block)}
 * @throws IllegalArgumentException if {@code arity} is not 1
 */
@Override
public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
{
    checkArgument(arity == 1, "Expected arity to be 1");
    Type fromType = boundVariables.getTypeVariable("F");
    Type toType = boundVariables.getTypeVariable("T");

    FunctionHandle functionHandle = functionAndTypeManager.lookupCast(CastType.CAST, fromType.getTypeSignature(), toType.getTypeSignature());
    BuiltInScalarFunctionImplementation function = functionAndTypeManager.getBuiltInScalarFunctionImplementation(functionHandle);
    Class<?> castOperatorClass = generateArrayCast(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(functionHandle), function);
    MethodHandle methodHandle = methodHandle(castOperatorClass, "castArray", ConnectorSession.class, Block.class);
    // Exactly one argument property: the function has arity 1 and castArray takes a single value argument
    // (the Block). The leading ConnectorSession is presumably supplied by the framework rather than counted as
    // an argument. The previous code declared two properties here.
    return new BuiltInScalarFunctionImplementation(
            false,
            ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)),
            methodHandle);
}
Use of io.prestosql.spi.function.FunctionHandle in project hetu-core by openLooKeng.
The class InCodeGenerator, method buildInCase.
/**
 * Emits the bytecode for one IN-list case. It compares {@code value} with each element of {@code testValues}
 * using the EQUAL operator resolved for {@code (type, type)}. It jumps to {@code matchLabel} on the first
 * comparison that is true, and to {@code noMatchLabel} when none is.
 * <p>
 * Each test is built with the previously built node as its false branch, so the test built last runs first.
 * As a result, the elements of {@code testValues} are tried in reverse iteration order.
 *
 * @param generatorContext supplies the {@code wasNull} variable, the function manager and call generation
 * @param scope scope used to allocate the temporary {@code caseWasNull} flag
 * @param type type of {@code value} and of the test values; selects the EQUAL operator
 * @param matchLabel jump target when a test value equals {@code value}
 * @param noMatchLabel jump target after every test value was tried without a match
 * @param value the variable being searched for
 * @param testValues bytecode for each candidate value, passed as the second EQUAL argument
 * @param checkForNulls when true, null comparison results are tracked, and {@code wasNull} is assigned before
 *        jumping to {@code noMatchLabel}
 * @param isIndeterminateFunction implementation of INDETERMINATE; only called when {@code checkForNulls} is true
 *        and {@code testValues} is empty
 * @return the generated block
 */
private static BytecodeBlock buildInCase(BytecodeGeneratorContext generatorContext, Scope scope, Type type, LabelNode matchLabel, LabelNode noMatchLabel, Variable value, Collection<BytecodeNode> testValues, boolean checkForNulls, BuiltInScalarFunctionImplementation isIndeterminateFunction) {
// caseWasNull is set to true the first time an EQUAL comparison against a test value leaves wasNull set.
Variable caseWasNull = null;
if (checkForNulls) {
caseWasNull = scope.createTempVariable(boolean.class);
}
BytecodeBlock caseBlock = new BytecodeBlock();
if (checkForNulls) {
caseBlock.putVariable(caseWasNull, false);
}
// No-match path. Its label is also the jump target for a null comparison in the first test built.
LabelNode elseLabel = new LabelNode("else");
BytecodeBlock elseBlock = new BytecodeBlock().visitLabel(elseLabel);
Variable wasNull = generatorContext.wasNull();
if (checkForNulls) {
// With no test values, caseWasNull can never become true, so it cannot signal a NULL result here.
// Instead, call INDETERMINATE on the value itself and store the call's result in wasNull.
if (testValues.isEmpty()) {
elseBlock.append(new BytecodeBlock().append(generatorContext.generateCall(INDETERMINATE.name(), isIndeterminateFunction, ImmutableList.of(value))).putVariable(wasNull));
} else {
// No match: propagate "some comparison was null" into wasNull.
elseBlock.append(wasNull.set(caseWasNull));
}
}
elseBlock.gotoLabel(noMatchLabel);
FunctionHandle equalsHandle = generatorContext.getFunctionManager().resolveOperatorFunctionHandle(EQUAL, fromTypes(type, type));
BuiltInScalarFunctionImplementation equalsFunction = generatorContext.getFunctionManager().getBuiltInScalarFunctionImplementation(equalsHandle);
BytecodeNode elseNode = elseBlock;
for (BytecodeNode testNode : testValues) {
LabelNode testLabel = new LabelNode("test");
IfStatement test = new IfStatement();
BytecodeNode equalsCall = generatorContext.generateCall(EQUAL.name(), equalsFunction, ImmutableList.of(value, testNode));
test.condition().visitLabel(testLabel).append(equalsCall);
if (checkForNulls) {
// A null comparison is not a match. Record it in caseWasNull, clear wasNull, drop the boolean result,
// and continue at elseLabel: the start of the previously built node, i.e. the next test to run
// (or the no-match block).
IfStatement wasNullCheck = new IfStatement("if wasNull, set caseWasNull to true, clear wasNull, pop boolean, and goto next test value");
wasNullCheck.condition(wasNull);
wasNullCheck.ifTrue(new BytecodeBlock().append(caseWasNull.set(constantTrue())).append(wasNull.set(constantFalse())).pop(boolean.class).gotoLabel(elseLabel));
test.condition().append(wasNullCheck);
}
test.ifTrue().gotoLabel(matchLabel);
test.ifFalse(elseNode);
// This test becomes the fallback (and null-comparison target) of the next test built.
elseNode = test;
elseLabel = testLabel;
}
caseBlock.append(elseNode);
return caseBlock;
}
Use of io.prestosql.spi.function.FunctionHandle in project hetu-core by openLooKeng.
The class ValuePrinter, method castToVarcharOrFail.
/**
 * Renders {@code value} as text by invoking the CAST from {@code type} to VARCHAR.
 *
 * @param type SQL type of {@code value}
 * @param value value to render; {@code null} produces the string {@code "NULL"}
 * @return the cast result decoded as UTF-8
 * @throws OperatorNotFoundException presumably raised by the cast lookup when {@code type} has no cast to
 *         VARCHAR — confirm
 */
public String castToVarcharOrFail(Type type, Object value) throws OperatorNotFoundException
{
    if (value != null) {
        FunctionHandle toVarchar = metadata.getFunctionAndTypeManager().lookupCast(CAST, type.getTypeSignature(), VARCHAR.getTypeSignature());
        InterpretedFunctionInvoker invoker = new InterpretedFunctionInvoker(metadata.getFunctionAndTypeManager());
        Slice varcharValue = (Slice) invoker.invoke(toVarchar, session.toConnectorSession(), value);
        return varcharValue.toStringUtf8();
    }
    return "NULL";
}
Use of io.prestosql.spi.function.FunctionHandle in project hetu-core by openLooKeng.
The class TestTypeValidator, method testInvalidWindowFunctionCall.
/**
 * A window function whose output symbol is declared DOUBLE but whose call expression is typed BIGINT
 * must be rejected by the type validator.
 */
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
public void testInvalidWindowFunctionCall()
{
    Symbol doubleOutput = planSymbolAllocator.newSymbol("sum", DOUBLE);
    FunctionHandle sumHandle = FUNCTION_MANAGER.lookupFunction("sum", fromTypes(DOUBLE));

    // RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
    WindowNode.Frame unboundedFrame = new WindowNode.Frame(
            WindowFrameType.RANGE,
            FrameBoundType.UNBOUNDED_PRECEDING,
            Optional.empty(),
            FrameBoundType.UNBOUNDED_FOLLOWING,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());

    // The call is deliberately typed BIGINT, mismatching the DOUBLE output symbol above.
    WindowNode.Function mistypedSum = new WindowNode.Function(
            call("sum", sumHandle, BIGINT, ImmutableList.of(VariableReferenceSymbolConverter.toVariableReference(columnA, BIGINT))),
            ImmutableList.of(VariableReferenceSymbolConverter.toVariableReference(columnA, BIGINT)),
            unboundedFrame);

    WindowNode.Specification emptySpecification = new WindowNode.Specification(ImmutableList.of(), Optional.empty());

    assertTypesValid(new WindowNode(
            newId(),
            baseTableScan,
            emptySpecification,
            ImmutableMap.of(doubleOutput, mistypedSum),
            Optional.empty(),
            ImmutableSet.of(),
            0));
}
Aggregations