Search in sources :

Example 41 with FunctionHandle

use of io.prestosql.spi.function.FunctionHandle in project hetu-core by openlookeng.

the class CubeOptimizer method rewriteAggregationNode.

/**
 * Builds the aggregation that runs on top of the cube table scan in place of the original aggregation.
 *
 * <p>One aggregation is created per entry of {@code cubeRewriteResult.getAggregationColumns()}, keyed by
 * the original aggregation symbol and reading from the corresponding cube scan symbol. When the rewrite
 * result carries no AVG columns the aggregation node is returned directly; otherwise it is wrapped in a
 * {@link ProjectNode} whose contents depend on {@code getComputeAvgDividingSumByCount()}.
 *
 * @param cubeRewriteResult description of the cube scan, its aggregation columns and AVG columns
 * @param inputPlanNode source node placed underneath the new aggregation
 * @return the new aggregation node, or a projection on top of it when AVG columns are present
 */
private PlanNode rewriteAggregationNode(CubeRewriteResult cubeRewriteResult, PlanNode inputPlanNode) {
    TypeProvider typeProvider = context.getSymbolAllocator().getTypes();

    // Grouping keys: original key name -> columnRewritesMap -> optimizedPlanMappings
    List<Symbol> groupings = aggregationNode.getGroupingKeys().stream()
            .map(Symbol::getName)
            .map(columnRewritesMap::get)
            .map(optimizedPlanMappings::get)
            .collect(Collectors.toList());

    // Remembers which aggregation output symbol each cube scan symbol feeds (needed for AVG below)
    Map<Symbol, Symbol> cubeScanToAggOutputMap = new HashMap<>();
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
    for (CubeRewriteResult.AggregatorSource source : cubeRewriteResult.getAggregationColumns()) {
        Type outputType = cubeRewriteResult.getSymbolMetadataMap().get(source.getOriginalAggSymbol()).getType();
        TypeSignature outputTypeSignature = outputType.getTypeSignature();
        ColumnHandle cubeColHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(source.getScanSymbol());
        ColumnMetadata cubeColumn = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColHandle);
        AggregationSignature cubeSignature = cubeMetadata.getAggregationSignature(cubeColumn.getName())
                .orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName("", ""), cubeColHandle.getColumnName()));

        // A COUNT recorded in the cube is re-aggregated with SUM; every other function is reused as is
        String aggFunction;
        if (COUNT.getName().equals(cubeSignature.getFunction())) {
            aggFunction = SUM.getName();
        }
        else {
            aggFunction = cubeSignature.getFunction();
        }

        SymbolReference scanReference = toSymbolReference(source.getScanSymbol());
        FunctionHandle aggHandle = metadata.getFunctionAndTypeManager()
                .lookupFunction(aggFunction, TypeSignatureProvider.fromTypeSignatures(outputTypeSignature));
        cubeScanToAggOutputMap.put(source.getScanSymbol(), source.getOriginalAggSymbol());

        CallExpression aggCall = new CallExpression(
                aggFunction,
                aggHandle,
                outputType,
                ImmutableList.of(castToRowExpression(scanReference)));
        aggregationsBuilder.put(
                source.getOriginalAggSymbol(),
                new AggregationNode.Aggregation(
                        aggCall,
                        ImmutableList.of(castToRowExpression(scanReference)),
                        false,
                        Optional.empty(),
                        Optional.empty(),
                        Optional.empty()));
    }

    AggregationNode aggNode = new AggregationNode(
            context.getIdAllocator().getNextId(),
            inputPlanNode,
            aggregationsBuilder.build(),
            singleGroupingSet(groupings),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());

    // Without AVG columns the aggregation is the final node
    if (cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
        return aggNode;
    }

    if (cubeRewriteResult.getComputeAvgDividingSumByCount()) {
        // AVG is derived as SUM / COUNT on top of the aggregation
        Map<Symbol, Expression> projections = new HashMap<>();
        for (Symbol groupingKey : aggNode.getGroupingKeys()) {
            projections.put(groupingKey, toSymbolReference(groupingKey));
        }
        // Pass through only the aggregation outputs that are values of originalAggregationsMap
        aggNode.getAggregations().keySet().stream()
                .filter(originalAggregationsMap::containsValue)
                .forEach(aggSymbol -> projections.put(aggSymbol, toSymbolReference(aggSymbol)));

        for (CubeRewriteResult.AverageAggregatorSource avgSource : cubeRewriteResult.getAvgAggregationColumns()) {
            Symbol sumSymbol = cubeScanToAggOutputMap.get(avgSource.getSum());
            Symbol countSymbol = cubeScanToAggOutputMap.get(avgSource.getCount());
            Type avgResultType = typeProvider.get(avgSource.getOriginalAggSymbol());
            // Both operands are cast to the AVG result type before dividing
            Cast numerator = new Cast(toSymbolReference(sumSymbol), avgResultType.getTypeSignature().toString());
            Cast denominator = new Cast(toSymbolReference(countSymbol), avgResultType.getTypeSignature().toString());
            projections.put(
                    avgSource.getOriginalAggSymbol(),
                    new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, numerator, denominator));
        }
        return new ProjectNode(
                context.getIdAllocator().getNextId(),
                aggNode,
                new Assignments(projections.entrySet().stream()
                        .collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
    }

    // AVG columns exist but are not computed by division: project each original
    // aggregation symbol from a reference to its cube scan symbol
    Map<Symbol, Expression> aggregateAssignments = new HashMap<>();
    for (CubeRewriteResult.AggregatorSource source : cubeRewriteResult.getAggregationColumns()) {
        aggregateAssignments.put(source.getOriginalAggSymbol(), toSymbolReference(source.getScanSymbol()));
    }
    return new ProjectNode(
            context.getIdAllocator().getNextId(),
            aggNode,
            new Assignments(aggregateAssignments.entrySet().stream()
                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
}
Also used : ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) Cast(io.prestosql.sql.tree.Cast) LongSupplier(java.util.function.LongSupplier) ConstantExpression(io.prestosql.spi.relation.ConstantExpression) PrestoWarning(io.prestosql.spi.PrestoWarning) TypeProvider(io.prestosql.sql.planner.TypeProvider) SqlParser(io.prestosql.sql.parser.SqlParser) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) CubeFilter(io.hetu.core.spi.cube.CubeFilter) ENGLISH(java.util.Locale.ENGLISH) Identifier(io.prestosql.sql.tree.Identifier) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) CubeMetaStore(io.hetu.core.spi.cube.io.CubeMetaStore) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) String.format(java.lang.String.format) FunctionHandle(io.prestosql.spi.function.FunctionHandle) Objects(java.util.Objects) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) 
ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) List(java.util.List) ExpressionUtils(io.prestosql.sql.ExpressionUtils) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) NOT_SUPPORTED(io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED) TypeSignature(io.prestosql.spi.type.TypeSignature) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) Iterables(com.google.common.collect.Iterables) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) Literal(io.prestosql.sql.tree.Literal) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) Lists(com.google.common.collect.Lists) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) ExpressionDomainTranslator(io.prestosql.sql.planner.ExpressionDomainTranslator) BooleanLiteral(io.prestosql.sql.tree.BooleanLiteral) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) RowExpressionRewriter(io.prestosql.expressions.RowExpressionRewriter) RowExpressionTreeRewriter(io.prestosql.expressions.RowExpressionTreeRewriter) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) JoinNode(io.prestosql.spi.plan.JoinNode) ParsingOptions(io.prestosql.sql.parser.ParsingOptions) Symbol(io.prestosql.spi.plan.Symbol) AssignmentUtils(io.prestosql.sql.planner.plan.AssignmentUtils) 
EXPIRED_CUBE(io.prestosql.spi.connector.StandardWarningCode.EXPIRED_CUBE) Assignments(io.prestosql.spi.plan.Assignments) Rule(io.prestosql.sql.planner.iterative.Rule) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) TupleDomain(io.prestosql.spi.predicate.TupleDomain) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) RowExpression(io.prestosql.spi.relation.RowExpression) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) SymbolReference(io.prestosql.sql.tree.SymbolReference) Assignments(io.prestosql.spi.plan.Assignments) TypeSignature(io.prestosql.spi.type.TypeSignature) PlanNode(io.prestosql.spi.plan.PlanNode) FunctionHandle(io.prestosql.spi.function.FunctionHandle) CallExpression(io.prestosql.spi.relation.CallExpression) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) TypeProvider(io.prestosql.sql.planner.TypeProvider) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) AggregationNode(io.prestosql.spi.plan.AggregationNode) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) ImmutableMap(com.google.common.collect.ImmutableMap) Type(io.prestosql.spi.type.Type) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) ConstantExpression(io.prestosql.spi.relation.ConstantExpression) CallExpression(io.prestosql.spi.relation.CallExpression) 
OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) ProjectNode(io.prestosql.spi.plan.ProjectNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashMap(java.util.HashMap)

Example 42 with FunctionHandle

use of io.prestosql.spi.function.FunctionHandle in project hetu-core by openlookeng.

the class ArrayToArrayCast method specialize.

/**
 * Specializes the array-to-array cast for the bound element types {@code F} (source) and {@code T} (target).
 *
 * <p>Looks up the element-level CAST from {@code F} to {@code T}, generates a class exposing a
 * {@code castArray(ConnectorSession, Block)} method around it, and returns an implementation bound to
 * that method.
 *
 * @param boundVariables type variable bindings; must provide {@code F} and {@code T}
 * @param arity number of arguments of the cast; must be 1
 * @param functionAndTypeManager used to resolve the element cast and its implementation
 * @return a non-nullable scalar implementation taking a single value argument that returns null on null input
 * @throws IllegalArgumentException if {@code arity} is not 1
 */
@Override
public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) {
    checkArgument(arity == 1, "Expected arity to be 1");
    Type fromType = boundVariables.getTypeVariable("F");
    Type toType = boundVariables.getTypeVariable("T");
    FunctionHandle functionHandle = functionAndTypeManager.lookupCast(CastType.CAST, fromType.getTypeSignature(), toType.getTypeSignature());
    BuiltInScalarFunctionImplementation function = functionAndTypeManager.getBuiltInScalarFunctionImplementation(functionHandle);
    Class<?> castOperatorClass = generateArrayCast(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(functionHandle), function);
    MethodHandle methodHandle = methodHandle(castOperatorClass, "castArray", ConnectorSession.class, Block.class);
    // The cast has exactly one SQL argument (the Block; ConnectorSession is not an argument),
    // so exactly one argument property is declared to match the arity checked above.
    return new BuiltInScalarFunctionImplementation(
            false,
            ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)),
            methodHandle);
}
Also used : Type(io.prestosql.spi.type.Type) CastType(io.prestosql.metadata.CastType) BuiltInScalarFunctionImplementation(io.prestosql.spi.function.BuiltInScalarFunctionImplementation) FunctionHandle(io.prestosql.spi.function.FunctionHandle) MethodHandle(java.lang.invoke.MethodHandle)

Example 43 with FunctionHandle

use of io.prestosql.spi.function.FunctionHandle in project hetu-core by openlookeng.

the class InCodeGenerator method buildInCase.

/**
 * Generates the bytecode for one case of an IN expression: {@code value} is compared for equality
 * against each node in {@code testValues}; a match jumps to {@code matchLabel}, and exhausting all
 * test values jumps to {@code noMatchLabel}.
 *
 * <p>The comparisons are built as a chain of nested {@link IfStatement}s. Each loop iteration wraps the
 * chain built so far as the false-branch of a new test, so the last element iterated ends up outermost.
 *
 * @param generatorContext supplies the {@code wasNull} variable, function resolution and call generation
 * @param scope used to allocate the temporary {@code caseWasNull} variable when null checking is enabled
 * @param type type of the compared values; used to resolve the EQUAL operator
 * @param matchLabel jump target when a test value equals {@code value}
 * @param noMatchLabel jump target when no test value matched
 * @param value variable holding the value being looked up
 * @param testValues bytecode nodes producing the candidate values
 * @param checkForNulls whether null results of the equality calls must be tracked
 * @param isIndeterminateFunction implementation invoked on {@code value} when null checking is enabled
 *        and {@code testValues} is empty
 * @return the block containing the generated comparison chain
 */
private static BytecodeBlock buildInCase(BytecodeGeneratorContext generatorContext, Scope scope, Type type, LabelNode matchLabel, LabelNode noMatchLabel, Variable value, Collection<BytecodeNode> testValues, boolean checkForNulls, BuiltInScalarFunctionImplementation isIndeterminateFunction) {
    // caseWasNull is set to true the first time a null in `testValues` is encountered
    Variable caseWasNull = null;
    if (checkForNulls) {
        caseWasNull = scope.createTempVariable(boolean.class);
    }
    BytecodeBlock caseBlock = new BytecodeBlock();
    if (checkForNulls) {
        // Start each case with "no null seen yet"
        caseBlock.putVariable(caseWasNull, false);
    }
    // The else block is the tail of the chain: reached when every comparison failed
    LabelNode elseLabel = new LabelNode("else");
    BytecodeBlock elseBlock = new BytecodeBlock().visitLabel(elseLabel);
    Variable wasNull = generatorContext.wasNull();
    if (checkForNulls) {
        // With no test values, nothing below can ever set caseWasNull, so the null-ness of the
        // result is taken from an explicit indeterminate check on the looked-up value instead.
        if (testValues.isEmpty()) {
            elseBlock.append(new BytecodeBlock().append(generatorContext.generateCall(INDETERMINATE.name(), isIndeterminateFunction, ImmutableList.of(value))).putVariable(wasNull));
        } else {
            // Otherwise the result is null exactly when some comparison produced null
            elseBlock.append(wasNull.set(caseWasNull));
        }
    }
    elseBlock.gotoLabel(noMatchLabel);
    FunctionHandle equalsHandle = generatorContext.getFunctionManager().resolveOperatorFunctionHandle(EQUAL, fromTypes(type, type));
    BuiltInScalarFunctionImplementation equalsFunction = generatorContext.getFunctionManager().getBuiltInScalarFunctionImplementation(equalsHandle);
    // elseNode / elseLabel always refer to the chain built so far, i.e. where a failed test continues
    BytecodeNode elseNode = elseBlock;
    for (BytecodeNode testNode : testValues) {
        LabelNode testLabel = new LabelNode("test");
        IfStatement test = new IfStatement();
        BytecodeNode equalsCall = generatorContext.generateCall(EQUAL.name(), equalsFunction, ImmutableList.of(value, testNode));
        test.condition().visitLabel(testLabel).append(equalsCall);
        if (checkForNulls) {
            IfStatement wasNullCheck = new IfStatement("if wasNull, set caseWasNull to true, clear wasNull, pop boolean, and goto next test value");
            wasNullCheck.condition(wasNull);
            wasNullCheck.ifTrue(new BytecodeBlock().append(caseWasNull.set(constantTrue())).append(wasNull.set(constantFalse())).pop(boolean.class).gotoLabel(elseLabel));
            test.condition().append(wasNullCheck);
        }
        test.ifTrue().gotoLabel(matchLabel);
        test.ifFalse(elseNode);
        // This test becomes the continuation target for the next (outer) test
        elseNode = test;
        elseLabel = testLabel;
    }
    caseBlock.append(elseNode);
    return caseBlock;
}
Also used : LabelNode(io.airlift.bytecode.instruction.LabelNode) IfStatement(io.airlift.bytecode.control.IfStatement) Variable(io.airlift.bytecode.Variable) BuiltInScalarFunctionImplementation(io.prestosql.spi.function.BuiltInScalarFunctionImplementation) BytecodeBlock(io.airlift.bytecode.BytecodeBlock) BytecodeNode(io.airlift.bytecode.BytecodeNode) FunctionHandle(io.prestosql.spi.function.FunctionHandle)

Example 44 with FunctionHandle

use of io.prestosql.spi.function.FunctionHandle in project hetu-core by openlookeng.

the class ValuePrinter method castToVarcharOrFail.

/**
 * Renders {@code value} as text by applying the CAST from {@code type} to VARCHAR.
 *
 * @param type type of {@code value}, used to look up the cast
 * @param value value to render; may be null
 * @return the literal string {@code "NULL"} when {@code value} is null, otherwise the UTF-8 text
 *         produced by the cast
 * @throws OperatorNotFoundException declared for callers that handle a missing cast from {@code type} to VARCHAR
 */
public String castToVarcharOrFail(Type type, Object value) throws OperatorNotFoundException {
    if (value != null) {
        FunctionHandle castHandle = metadata.getFunctionAndTypeManager()
                .lookupCast(CAST, type.getTypeSignature(), VARCHAR.getTypeSignature());
        InterpretedFunctionInvoker invoker = new InterpretedFunctionInvoker(metadata.getFunctionAndTypeManager());
        Object castResult = invoker.invoke(castHandle, session.toConnectorSession(), value);
        return ((Slice) castResult).toStringUtf8();
    }
    return "NULL";
}
Also used : InterpretedFunctionInvoker(io.prestosql.sql.InterpretedFunctionInvoker) Slice(io.airlift.slice.Slice) FunctionHandle(io.prestosql.spi.function.FunctionHandle)

Example 45 with FunctionHandle

use of io.prestosql.spi.function.FunctionHandle in project hetu-core by openlookeng.

the class TestTypeValidator method testInvalidWindowFunctionCall.

// Verifies that type validation rejects a window function whose call returns BIGINT
// while the symbol it is assigned to was allocated as DOUBLE.
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
public void testInvalidWindowFunctionCall() {
    // Output symbol is allocated as DOUBLE ...
    Symbol sumOutput = planSymbolAllocator.newSymbol("sum", DOUBLE);
    FunctionHandle sumHandle = FUNCTION_MANAGER.lookupFunction("sum", fromTypes(DOUBLE));

    WindowNode.Frame unboundedRangeFrame = new WindowNode.Frame(
            WindowFrameType.RANGE,
            FrameBoundType.UNBOUNDED_PRECEDING,
            Optional.empty(),
            FrameBoundType.UNBOUNDED_FOLLOWING,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());

    // ... but the call is built with a BIGINT return type over a BIGINT column
    WindowNode.Function bigintSum = new WindowNode.Function(
            call("sum", sumHandle, BIGINT, ImmutableList.of(VariableReferenceSymbolConverter.toVariableReference(columnA, BIGINT))),
            ImmutableList.of(VariableReferenceSymbolConverter.toVariableReference(columnA, BIGINT)),
            unboundedRangeFrame);

    WindowNode.Specification emptySpecification = new WindowNode.Specification(ImmutableList.of(), Optional.empty());
    PlanNode windowNode = new WindowNode(
            newId(),
            baseTableScan,
            emptySpecification,
            ImmutableMap.of(sumOutput, bigintSum),
            Optional.empty(),
            ImmutableSet.of(),
            0);

    assertTypesValid(windowNode);
}
Also used : WindowNode(io.prestosql.spi.plan.WindowNode) PlanNode(io.prestosql.spi.plan.PlanNode) Symbol(io.prestosql.spi.plan.Symbol) FunctionHandle(io.prestosql.spi.function.FunctionHandle) Test(org.testng.annotations.Test)

Aggregations

FunctionHandle (io.prestosql.spi.function.FunctionHandle)49 RowExpression (io.prestosql.spi.relation.RowExpression)24 CallExpression (io.prestosql.spi.relation.CallExpression)21 Type (io.prestosql.spi.type.Type)21 BuiltInFunctionHandle (io.prestosql.spi.function.BuiltInFunctionHandle)13 Symbol (io.prestosql.spi.plan.Symbol)13 ConstantExpression (io.prestosql.spi.relation.ConstantExpression)13 ImmutableList (com.google.common.collect.ImmutableList)11 Test (org.testng.annotations.Test)11 PrestoException (io.prestosql.spi.PrestoException)10 BuiltInScalarFunctionImplementation (io.prestosql.spi.function.BuiltInScalarFunctionImplementation)10 Signature (io.prestosql.spi.function.Signature)10 VariableReferenceExpression (io.prestosql.spi.relation.VariableReferenceExpression)10 PlanNode (io.prestosql.spi.plan.PlanNode)9 Map (java.util.Map)9 List (java.util.List)8 ArrayList (java.util.ArrayList)7 Optional (java.util.Optional)7 BytecodeBlock (io.airlift.bytecode.BytecodeBlock)6 Variable (io.airlift.bytecode.Variable)6