Search in sources:

Example 6 with GenericLiteral

Use of io.prestosql.sql.tree.GenericLiteral in the hetu-core project by openLooKeng.

From class TransformUnCorrelatedInPredicateSubQuerySelfJoinToAggregate, method transformProjectNode.

/**
 * Tries to rewrite an IN-subquery projection that sits on top of a filtered self join into a
 * grouped aggregation over a single table.
 *
 * <p>The expected shape is {@code Project -> [CTEScan ->] Filter -> Join}. When the join is not a
 * self join (per {@code isSelfJoin}), the method recurses into any {@link ProjectNode} directly
 * under the join's left/right side and, if either side was rewritten, returns the original
 * project/filter rebuilt over the updated join.
 *
 * <p>When the join is a self join, it is replaced by
 * {@code Aggregation(count(DISTINCT ...)) -> Filter(count > 1) -> Project}, and, in the CTE case,
 * the CTE scan subtree is rebuilt over the new filter.
 *
 * @param context rule context, used for lookups, id allocation and symbol allocation
 * @param projectNode the projection producing the IN-subquery value
 * @return the rewritten projection, or {@link Optional#empty()} when the pattern does not apply
 */
private Optional<ProjectNode> transformProjectNode(Context context, ProjectNode projectNode) {
    // IN predicate requires only one projection
    if (projectNode.getOutputSymbols().size() > 1) {
        return Optional.empty();
    }
    PlanNode source = context.getLookup().resolve(projectNode.getSource());
    if (source instanceof CTEScanNode) {
        // Look below the CTE scan; presumably getChildFilterNode locates the FilterNode under the
        // CTE's source — TODO confirm. Whatever it returns is validated by the check below.
        source = getChildFilterNode(context, context.getLookup().resolve(((CTEScanNode) source).getSource()));
    }
    // Only Filter-over-Join shapes are handled.
    if (!(source instanceof FilterNode && context.getLookup().resolve(((FilterNode) source).getSource()) instanceof JoinNode)) {
        return Optional.empty();
    }
    FilterNode filter = (FilterNode) source;
    Expression predicate = OriginalExpressionUtils.castToExpression(filter.getPredicate());
    // Collect every symbol referenced by the filter predicate; used later to pick aggregation inputs.
    List<SymbolReference> allPredicateSymbols = new ArrayList<>();
    getAllSymbols(predicate, allPredicateSymbols);
    // Safe cast: the check above guarantees the filter's source resolves to a JoinNode.
    JoinNode joinNode = (JoinNode) context.getLookup().resolve(((FilterNode) source).getSource());
    if (!isSelfJoin(projectNode, predicate, joinNode, context.getLookup())) {
        // Check next level for Self Join
        PlanNode left = context.getLookup().resolve(joinNode.getLeft());
        boolean changed = false;
        if (left instanceof ProjectNode) {
            Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) left);
            if (transformResult.isPresent()) {
                // Rebuild the join with the rewritten left child; every other property is copied.
                joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), transformResult.get(), joinNode.getRight(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
                changed = true;
            }
        }
        PlanNode right = context.getLookup().resolve(joinNode.getRight());
        if (right instanceof ProjectNode) {
            Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) right);
            if (transformResult.isPresent()) {
                // Rebuild the join with the rewritten right child (left may already be rewritten).
                joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), joinNode.getLeft(), transformResult.get(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
                changed = true;
            }
        }
        if (changed) {
            // Re-wrap the updated join in the original filter and projection (same ids).
            FilterNode transformedFilter = new FilterNode(filter.getId(), joinNode, filter.getPredicate());
            ProjectNode transformedProject = new ProjectNode(projectNode.getId(), transformedFilter, projectNode.getAssignments());
            return Optional.of(transformedProject);
        }
        return Optional.empty();
    }
    // Choose the table to use based on projected output.
    // NOTE(review): these casts assume isSelfJoin(...) only accepts joins whose both sides resolve to
    // TableScanNode — TODO confirm; otherwise a ClassCastException is thrown here.
    TableScanNode leftTable = (TableScanNode) context.getLookup().resolve(joinNode.getLeft());
    TableScanNode rightTable = (TableScanNode) context.getLookup().resolve(joinNode.getRight());
    TableScanNode tableToUse;
    List<RowExpression> aggregationSymbols;
    AggregationNode.GroupingSetDescriptor groupingSetDescriptor;
    // Only populated in the CTE branch; consumed when the CTE subtree is rebuilt at the end.
    Assignments projectionsForCTE = null;
    Assignments projectionsFromFilter = null;
    // Use non-projected column for aggregation
    if (context.getLookup().resolve(projectNode.getSource()) instanceof CTEScanNode) {
        CTEScanNode cteScanNode = (CTEScanNode) context.getLookup().resolve(projectNode.getSource());
        // NOTE(review): assumes the CTE scan's source resolves to a ProjectNode — TODO confirm.
        ProjectNode childProjectOfCte = (ProjectNode) context.getLookup().resolve(cteScanNode.getSource());
        // All table-scan symbols from both join sides (right side first, then left).
        List<Symbol> completeOutputSymbols = new ArrayList<>();
        rightTable.getOutputSymbols().forEach((s) -> completeOutputSymbols.add(s));
        leftTable.getOutputSymbols().forEach((s) -> completeOutputSymbols.add(s));
        // Map the projected symbol through the CTE's child projection back to the table-scan
        // symbol(s) with the same name.
        List<Symbol> outputSymbols = new ArrayList<>();
        for (int i = 0; i < completeOutputSymbols.size(); i++) {
            Symbol outputSymbol = completeOutputSymbols.get(i);
            for (Symbol symbol : projectNode.getOutputSymbols()) {
                if (childProjectOfCte.getAssignments().getMap().containsKey(symbol)) {
                    // NOTE(review): assumes the CTE assignment is a plain SymbolReference — TODO confirm.
                    if (((SymbolReference) OriginalExpressionUtils.castToExpression(childProjectOfCte.getAssignments().getMap().get(symbol))).getName().equals(outputSymbol.getName())) {
                        outputSymbols.add(outputSymbol);
                    }
                }
            }
        }
        Map<Symbol, RowExpression> projectionsForCTEMap = new HashMap<>();
        Map<Symbol, RowExpression> projectionsFromFilterMap = new HashMap<>();
        // NOTE(review): raw Map.Entry type, and the two if-conditions below are identical and could be
        // merged into one. getOnlyElement(outputSymbols) throws unless exactly one symbol matched above.
        for (Map.Entry entry : childProjectOfCte.getAssignments().getMap().entrySet()) {
            if (entry.getKey().equals(getOnlyElement(projectNode.getOutputSymbols()))) {
                projectionsForCTEMap.put((Symbol) entry.getKey(), (RowExpression) entry.getValue());
            }
            if (entry.getKey().equals(getOnlyElement(projectNode.getOutputSymbols()))) {
                projectionsFromFilterMap.put(getOnlyElement(outputSymbols), (RowExpression) entry.getValue());
            }
        }
        projectionsForCTE = new Assignments(projectionsForCTEMap);
        projectionsFromFilter = new Assignments(projectionsFromFilterMap);
        // Aggregate over the join side that owns the projected column, using that side's other
        // predicate symbols as the count(DISTINCT ...) arguments.
        tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(outputSymbols)) ? leftTable : rightTable;
        aggregationSymbols = allPredicateSymbols.stream().filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s))).filter(s -> !outputSymbols.contains(SymbolUtils.from(s))).map(OriginalExpressionUtils::castToRowExpression).collect(Collectors.toList());
        // Create aggregation
        // Single grouping set keyed on the table-scan symbol(s) that back the projected column.
        groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(outputSymbols), 1, ImmutableSet.of());
    } else {
        // Aggregate over the join side that owns the projected column, using that side's other
        // predicate symbols as the count(DISTINCT ...) arguments.
        tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(projectNode.getOutputSymbols())) ? leftTable : rightTable;
        aggregationSymbols = allPredicateSymbols.stream().filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s))).filter(s -> !projectNode.getOutputSymbols().contains(SymbolUtils.from(s))).map(OriginalExpressionUtils::castToRowExpression).collect(Collectors.toList());
        // Create aggregation
        // Single grouping set keyed on the projected symbol.
        groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(projectNode.getOutputSymbols()), 1, ImmutableSet.of());
    }
    // count(...) over the non-projected predicate symbols; the `true` argument is the DISTINCT flag.
    AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(new CallExpression("count", functionResolution.countFunction(), BIGINT, aggregationSymbols), aggregationSymbols, // mark DISTINCT since NOT_EQUALS predicate
    true, Optional.empty(), Optional.empty(), Optional.empty());
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
    Symbol countSymbol = context.getSymbolAllocator().newSymbol(aggregation.getFunctionCall().getDisplayName(), BIGINT);
    aggregationsBuilder.put(countSymbol, aggregation);
    // SINGLE-step HASH aggregation directly over the chosen table scan.
    AggregationNode aggregationNode = new AggregationNode(context.getIdAllocator().getNextId(), tableToUse, aggregationsBuilder.build(), groupingSetDescriptor, ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
    // Keep only groups whose distinct count is greater than 1 (i.e. drop groups with count <= 1)
    // to match the NOT_EQUALS clause in the original query.
    FilterNode filterNode = new FilterNode(context.getIdAllocator().getNextId(), aggregationNode, OriginalExpressionUtils.castToRowExpression(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, SymbolUtils.toSymbolReference(countSymbol), new GenericLiteral("BIGINT", "1"))));
    // Project the aggregated+filtered rows.
    ProjectNode transformedSubquery = new ProjectNode(projectNode.getId(), filterNode, projectNode.getAssignments());
    // CTE case (same condition as the branch above): rebuild the CTE subtree over the new filter
    // instead of projecting the filter directly.
    if (context.getLookup().resolve(projectNode.getSource()) instanceof CTEScanNode) {
        CTEScanNode cteScanNode = (CTEScanNode) context.getLookup().resolve(projectNode.getSource());
        PlanNode projectNodeForCTE = context.getLookup().resolve(cteScanNode.getSource());
        // NOTE(review): re-resolves an already-resolved node; only its id is used on the next line.
        PlanNode projectNodeFromFilter = context.getLookup().resolve(projectNodeForCTE);
        projectNodeFromFilter = new ProjectNode(projectNodeFromFilter.getId(), filterNode, projectionsFromFilter);
        projectNodeForCTE = new ProjectNode(projectNodeForCTE.getId(), projectNodeFromFilter, projectionsForCTE);
        cteScanNode = (CTEScanNode) cteScanNode.replaceChildren(ImmutableList.of(projectNodeForCTE));
        // NOTE(review): mutates the node returned by replaceChildren — confirm it is a fresh copy.
        cteScanNode.setOutputSymbols(projectNode.getOutputSymbols());
        transformedSubquery = new ProjectNode(projectNode.getId(), cteScanNode, projectNode.getAssignments());
    }
    return Optional.of(transformedSubquery);
}
Also used : Apply.subQuery(io.prestosql.sql.planner.plan.Patterns.Apply.subQuery) Lookup(io.prestosql.sql.planner.iterative.Lookup) Patterns.applyNode(io.prestosql.sql.planner.plan.Patterns.applyNode) HashMap(java.util.HashMap) Pattern(io.prestosql.matching.Pattern) StandardFunctionResolution(io.prestosql.spi.function.StandardFunctionResolution) CTEScanNode(io.prestosql.spi.plan.CTEScanNode) ArrayList(java.util.ArrayList) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Capture.newCapture(io.prestosql.matching.Capture.newCapture) ImmutableList(com.google.common.collect.ImmutableList) FilterNode(io.prestosql.spi.plan.FilterNode) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) BIGINT(io.prestosql.spi.type.BigintType.BIGINT) Patterns.project(io.prestosql.sql.planner.plan.Patterns.project) JoinNode(io.prestosql.spi.plan.JoinNode) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils(io.prestosql.sql.planner.SymbolUtils) LogicalBinaryExpression(io.prestosql.sql.tree.LogicalBinaryExpression) ApplyNode(io.prestosql.sql.planner.plan.ApplyNode) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Assignments(io.prestosql.spi.plan.Assignments) Rule(io.prestosql.sql.planner.iterative.Rule) Pattern.empty(io.prestosql.matching.Pattern.empty) TableScanNode(io.prestosql.spi.plan.TableScanNode) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) Iterables.getOnlyElement(com.google.common.collect.Iterables.getOnlyElement) PlanNode(io.prestosql.spi.plan.PlanNode) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) Captures(io.prestosql.matching.Captures) List(java.util.List) FunctionResolution(io.prestosql.sql.relational.FunctionResolution) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) 
SymbolReference(io.prestosql.sql.tree.SymbolReference) TRANSFORM_SELF_JOIN_TO_GROUPBY(io.prestosql.SystemSessionProperties.TRANSFORM_SELF_JOIN_TO_GROUPBY) Apply.correlation(io.prestosql.sql.planner.plan.Patterns.Apply.correlation) Capture(io.prestosql.matching.Capture) RowExpression(io.prestosql.spi.relation.RowExpression) Optional(java.util.Optional) Expression(io.prestosql.sql.tree.Expression) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) InPredicate(io.prestosql.sql.tree.InPredicate) HashMap(java.util.HashMap) SymbolReference(io.prestosql.sql.tree.SymbolReference) Symbol(io.prestosql.spi.plan.Symbol) FilterNode(io.prestosql.spi.plan.FilterNode) ArrayList(java.util.ArrayList) Assignments(io.prestosql.spi.plan.Assignments) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) PlanNode(io.prestosql.spi.plan.PlanNode) CTEScanNode(io.prestosql.spi.plan.CTEScanNode) CallExpression(io.prestosql.spi.relation.CallExpression) JoinNode(io.prestosql.spi.plan.JoinNode) RowExpression(io.prestosql.spi.relation.RowExpression) AggregationNode(io.prestosql.spi.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) TableScanNode(io.prestosql.spi.plan.TableScanNode) CallExpression(io.prestosql.spi.relation.CallExpression) LogicalBinaryExpression(io.prestosql.sql.tree.LogicalBinaryExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) ProjectNode(io.prestosql.spi.plan.ProjectNode) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) HashMap(java.util.HashMap) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 7 with GenericLiteral

Use of io.prestosql.sql.tree.GenericLiteral in the hetu-core project by openLooKeng.

From class GroupingOperationRewriter, method rewriteGroupingOperation.

/**
 * Rewrites a {@code GROUPING(...)} operation into an expression that can be evaluated per row.
 *
 * <p>With a single grouping set the result is the constant {@code 0}. Otherwise the result is an
 * array literal holding the precomputed grouping value for every grouping set, subscripted by
 * {@code groupIdSymbol + 1} because the array is indexed starting at 1.
 *
 * @param expression the GROUPING operation to rewrite
 * @param groupingSets the grouping sets of the enclosing aggregation
 * @param columnReferenceFields resolved field for each column reference; must contain every grouping column
 * @param groupIdSymbol the group id symbol; must be present when there is more than one grouping set
 * @return the rewritten expression
 * @throws NullPointerException if {@code groupIdSymbol} is null
 * @throws IllegalStateException if the group id symbol is missing or a grouping column is not in
 *         {@code columnReferenceFields}
 */
public static Expression rewriteGroupingOperation(GroupingOperation expression, List<Set<Integer>> groupingSets, Map<NodeRef<Expression>, FieldId> columnReferenceFields, Optional<Symbol> groupIdSymbol) {
    requireNonNull(groupIdSymbol, "groupIdSymbol is null");
    // See SQL:2011:4.16.2 and SQL:2011:6.9.10.
    if (groupingSets.size() == 1) {
        return new LongLiteral("0");
    }
    checkState(groupIdSymbol.isPresent(), "groupId symbol is missing");

    // Validate every grouping column before any lookup, so an unresolved column (including the first
    // one, which supplies the relation id below) fails with a descriptive IllegalStateException
    // rather than a NullPointerException.
    for (Expression groupingColumn : expression.getGroupingColumns()) {
        checkState(columnReferenceFields.containsKey(NodeRef.of(groupingColumn)), "the grouping column is not in the columnReferencesField map");
    }

    RelationId relationId = columnReferenceFields.get(NodeRef.of(expression.getGroupingColumns().get(0))).getRelationId();
    List<Integer> columns = expression.getGroupingColumns().stream()
            .map(NodeRef::of)
            .map(columnReferenceFields::get)
            .map(fieldId -> translateFieldToInteger(fieldId, relationId))
            .collect(toImmutableList());
    List<Expression> groupingResults = groupingSets.stream()
            .map(groupingSet -> String.valueOf(calculateGrouping(groupingSet, columns)))
            .map(LongLiteral::new)
            .collect(toImmutableList());
    // It is necessary to add a 1 to the groupId because the underlying array is indexed starting at 1
    return new SubscriptExpression(new ArrayConstructor(groupingResults), new ArithmeticBinaryExpression(ADD, toSymbolReference(groupIdSymbol.get()), new GenericLiteral("BIGINT", "1")));
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) GroupingOperation(io.prestosql.sql.tree.GroupingOperation) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Set(java.util.Set) RelationId(io.prestosql.sql.analyzer.RelationId) NodeRef(io.prestosql.sql.tree.NodeRef) Preconditions.checkState(com.google.common.base.Preconditions.checkState) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) List(java.util.List) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) LongLiteral(io.prestosql.sql.tree.LongLiteral) ADD(io.prestosql.sql.tree.ArithmeticBinaryExpression.Operator.ADD) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ArrayConstructor(io.prestosql.sql.tree.ArrayConstructor) Optional(java.util.Optional) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) FieldId(io.prestosql.sql.analyzer.FieldId) Expression(io.prestosql.sql.tree.Expression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) LongLiteral(io.prestosql.sql.tree.LongLiteral) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) NodeRef(io.prestosql.sql.tree.NodeRef) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) Expression(io.prestosql.sql.tree.Expression) RelationId(io.prestosql.sql.analyzer.RelationId) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) ArrayConstructor(io.prestosql.sql.tree.ArrayConstructor)

Example 8 with GenericLiteral

Use of io.prestosql.sql.tree.GenericLiteral in the hetu-core project by openLooKeng.

From class LiteralEncoder, method toExpression.

/**
 * Encodes a native Java value of the given type as an AST {@link Expression}.
 *
 * <p>Values that are already {@link Expression}s are returned unchanged. A {@code null} value becomes
 * a {@code NULL} literal (cast to {@code type} unless the type is UNKNOWN). Well-known types get a
 * dedicated literal form; anything else falls back to a call of the type's literal function, with
 * {@link Slice} payloads (including serialized {@link Block}s) passed through {@code from_base64}.
 *
 * @param inputObject the value to encode; may be {@code null}
 * @param type the type of the value; must not be {@code null}
 * @return an expression that evaluates to {@code inputObject}
 * @throws NullPointerException if {@code type} is null
 * @throws IllegalArgumentException if the value's class does not match the type's Java type
 * @throws ArithmeticException if a DATE value does not fit in an int
 */
public Expression toExpression(Object inputObject, Type type) {
    requireNonNull(type, "type is null");
    Object object = inputObject;
    if (object instanceof Expression) {
        return (Expression) object;
    }
    if (object == null) {
        if (type.equals(UNKNOWN)) {
            return new NullLiteral();
        }
        return new Cast(new NullLiteral(), type.getTypeSignature().toString(), false, true);
    }
    // AbstractIntType internally uses long as javaType, but an Integer is also accepted for such
    // types. Branches below that read the numeric value therefore go through Number, not Long.
    Class<?> wrap = Primitives.wrap(type.getJavaType());
    checkArgument(wrap.isInstance(object) || (type instanceof AbstractIntType && wrap == Long.class && Integer.class.isInstance(object)), "object.getClass (%s) and type.getJavaType (%s) do not agree", object.getClass(), type.getJavaType());
    if (type.equals(TINYINT)) {
        return new GenericLiteral("TINYINT", object.toString());
    }
    if (type.equals(SMALLINT)) {
        return new GenericLiteral("SMALLINT", object.toString());
    }
    if (type.equals(INTEGER)) {
        return new LongLiteral(object.toString());
    }
    if (type.equals(BIGINT)) {
        // Values inside the int range are emitted as an explicit BIGINT literal; presumably so a bare
        // LongLiteral is not re-typed as INTEGER on re-analysis — TODO confirm.
        LongLiteral expression = new LongLiteral(object.toString());
        if (expression.getValue() >= Integer.MIN_VALUE && expression.getValue() <= Integer.MAX_VALUE) {
            return new GenericLiteral("BIGINT", object.toString());
        }
        return new LongLiteral(object.toString());
    }
    if (type.equals(DOUBLE)) {
        Double value = (Double) object;
        // When changing this, don't forget about similar code for REAL below
        if (value.isNaN()) {
            return new FunctionCallBuilder(metadata).setName(QualifiedName.of("nan")).build();
        }
        if (value.equals(Double.NEGATIVE_INFINITY)) {
            return ArithmeticUnaryExpression.negative(new FunctionCallBuilder(metadata).setName(QualifiedName.of("infinity")).build());
        }
        if (value.equals(Double.POSITIVE_INFINITY)) {
            return new FunctionCallBuilder(metadata).setName(QualifiedName.of("infinity")).build();
        }
        return new DoubleLiteral(object.toString());
    }
    if (type.equals(REAL)) {
        // REAL stores float bits in the low 32 bits; accept both Long and Integer (see check above).
        Float value = intBitsToFloat(((Number) object).intValue());
        // WARNING for ORC predicate code as above (for double)
        if (value.isNaN()) {
            return new Cast(new FunctionCallBuilder(metadata).setName(QualifiedName.of("nan")).build(), StandardTypes.REAL);
        }
        if (value.equals(Float.NEGATIVE_INFINITY)) {
            return ArithmeticUnaryExpression.negative(new Cast(new FunctionCallBuilder(metadata).setName(QualifiedName.of("infinity")).build(), StandardTypes.REAL));
        }
        if (value.equals(Float.POSITIVE_INFINITY)) {
            return new Cast(new FunctionCallBuilder(metadata).setName(QualifiedName.of("infinity")).build(), StandardTypes.REAL);
        }
        return new GenericLiteral("REAL", value.toString());
    }
    if (type instanceof DecimalType) {
        String string;
        if (isShortDecimal(type)) {
            string = Decimals.toString((long) object, ((DecimalType) type).getScale());
        } else {
            string = Decimals.toString((Slice) object, ((DecimalType) type).getScale());
        }
        return new Cast(new DecimalLiteral(string), type.getDisplayName());
    }
    if (type instanceof VarcharType) {
        VarcharType varcharType = (VarcharType) type;
        Slice value = (Slice) object;
        StringLiteral stringLiteral = new StringLiteral(value.toStringUtf8());
        // A bounded varchar whose length exactly matches the value needs no cast.
        if (!varcharType.isUnbounded() && varcharType.getBoundedLength() == SliceUtf8.countCodePoints(value)) {
            return stringLiteral;
        }
        return new Cast(stringLiteral, type.getDisplayName(), false, true);
    }
    if (type instanceof CharType) {
        StringLiteral stringLiteral = new StringLiteral(((Slice) object).toStringUtf8());
        return new Cast(stringLiteral, type.getDisplayName(), false, true);
    }
    if (type.equals(BOOLEAN)) {
        return new BooleanLiteral(object.toString());
    }
    if (type.equals(DATE)) {
        // Accept both Long and Integer (see check above); toIntExact rejects out-of-range values.
        return new GenericLiteral("DATE", new SqlDate(toIntExact(((Number) object).longValue())).toString());
    }
    if (type.equals(TIMESTAMP)) {
        return new GenericLiteral("TIMESTAMP", new SqlTimestamp((Long) object).toString());
    }
    if (object instanceof Block) {
        // Serialize the block so it can be embedded below as a Slice payload.
        SliceOutput output = new DynamicSliceOutput(toIntExact(((Block) object).getSizeInBytes()));
        BlockSerdeUtil.writeBlock(metadata.getFunctionAndTypeManager().getBlockEncodingSerde(), output, (Block) object);
        object = output.slice();
    }
    Signature signature = LiteralFunction.getLiteralFunctionSignature(type);
    Type argumentType = typeForLiteralFunctionArgument(type);
    Expression argument;
    // Serialized Blocks (above) and any other Slice value not handled by the varchar/char branches
    // take the from_base64 path.
    if (object instanceof Slice) {
        // HACK: we need to serialize VARBINARY in a format that can be embedded in an expression to be
        // able to encode it in the plan that gets sent to workers.
        // We do this by transforming the in-memory varbinary into a call to from_base64(<base64-encoded value>)
        Slice encoded = VarbinaryFunctions.toBase64((Slice) object);
        argument = new FunctionCallBuilder(metadata).setName(QualifiedName.of("from_base64")).addArgument(VARCHAR, new StringLiteral(encoded.toStringUtf8())).build();
    } else {
        argument = toExpression(object, argumentType);
    }
    return new FunctionCallBuilder(metadata).setName(QualifiedName.of(signature.getNameSuffix())).addArgument(argumentType, argument).build();
}
Also used : Cast(io.prestosql.sql.tree.Cast) SliceOutput(io.airlift.slice.SliceOutput) DynamicSliceOutput(io.airlift.slice.DynamicSliceOutput) VarcharType(io.prestosql.spi.type.VarcharType) BooleanLiteral(io.prestosql.sql.tree.BooleanLiteral) SqlTimestamp(io.prestosql.spi.type.SqlTimestamp) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) DecimalLiteral(io.prestosql.sql.tree.DecimalLiteral) DynamicSliceOutput(io.airlift.slice.DynamicSliceOutput) LongLiteral(io.prestosql.sql.tree.LongLiteral) AbstractIntType(io.prestosql.spi.type.AbstractIntType) Float.intBitsToFloat(java.lang.Float.intBitsToFloat) DecimalType(io.prestosql.spi.type.DecimalType) Type(io.prestosql.spi.type.Type) CharType(io.prestosql.spi.type.CharType) AbstractIntType(io.prestosql.spi.type.AbstractIntType) VarcharType(io.prestosql.spi.type.VarcharType) StringLiteral(io.prestosql.sql.tree.StringLiteral) ArithmeticUnaryExpression(io.prestosql.sql.tree.ArithmeticUnaryExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) Slice(io.airlift.slice.Slice) SqlDate(io.prestosql.spi.type.SqlDate) TypeSignature(io.prestosql.spi.type.TypeSignature) Signature(io.prestosql.spi.function.Signature) DecimalType(io.prestosql.spi.type.DecimalType) Block(io.prestosql.spi.block.Block) DoubleLiteral(io.prestosql.sql.tree.DoubleLiteral) CharType(io.prestosql.spi.type.CharType) NullLiteral(io.prestosql.sql.tree.NullLiteral)

Example 9 with GenericLiteral

Use of io.prestosql.sql.tree.GenericLiteral in the hetu-core project by openLooKeng.

From class HashGenerationOptimizer, method getHashExpression.

/**
 * Builds an expression that folds the hash codes of the given symbols into a single BIGINT hash.
 *
 * <p>The fold starts from {@code INITIAL_HASH_VALUE}. For each symbol it applies
 * {@code combine_hash(acc, COALESCE(HASH_CODE(symbol), NULL_HASH_CODE))}, so null values contribute
 * {@code NULL_HASH_CODE}.
 *
 * @param metadata metadata used to resolve the function calls
 * @param planSymbolAllocator source of each symbol's type
 * @param symbols the symbols to hash, in combination order
 * @return the hash expression, or {@link Optional#empty()} when {@code symbols} is empty
 */
public static Optional<Expression> getHashExpression(Metadata metadata, PlanSymbolAllocator planSymbolAllocator, List<Symbol> symbols) {
    if (symbols.isEmpty()) {
        return Optional.empty();
    }

    Expression combined = new GenericLiteral(StandardTypes.BIGINT, String.valueOf(INITIAL_HASH_VALUE));
    for (Symbol current : symbols) {
        Expression rawHash = new FunctionCallBuilder(metadata)
                .setName(QualifiedName.of(HASH_CODE))
                .addArgument(planSymbolAllocator.getTypes().get(current), new SymbolReference(current.getName()))
                .build();
        Expression nullSafeHash = new CoalesceExpression(rawHash, new LongLiteral(String.valueOf(NULL_HASH_CODE)));
        combined = new FunctionCallBuilder(metadata)
                .setName(QualifiedName.of("combine_hash"))
                .addArgument(BIGINT, combined)
                .addArgument(BIGINT, nullSafeHash)
                .build();
    }
    return Optional.of(combined);
}
Also used : CallExpression(io.prestosql.spi.relation.CallExpression) CoalesceExpression(io.prestosql.sql.tree.CoalesceExpression) OriginalExpressionUtils.isExpression(io.prestosql.sql.relational.OriginalExpressionUtils.isExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) LongLiteral(io.prestosql.sql.tree.LongLiteral) Symbol(io.prestosql.spi.plan.Symbol) SymbolReference(io.prestosql.sql.tree.SymbolReference) CoalesceExpression(io.prestosql.sql.tree.CoalesceExpression) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) FunctionCallBuilder(io.prestosql.sql.planner.FunctionCallBuilder)

Example 10 with GenericLiteral

Use of io.prestosql.sql.tree.GenericLiteral in the hetu-core project by openLooKeng.

From class CubeConsole, method isSupportedExpression.

/**
 * Gets whether the expression is supported for processing the create cube query
 *
 * @param createCube createCube
 * @return void
 */
private boolean isSupportedExpression(CreateCube createCube, QueryRunner queryRunner, ClientOptions.OutputFormat outputFormat, Runnable schemaChanged, boolean usePager, boolean showProgress, Terminal terminal, PrintStream out, PrintStream errorChannel) {
    boolean supportedExpression = false;
    boolean success = true;
    Optional<Expression> expression = createCube.getWhere();
    if (expression.isPresent()) {
        ImmutableSet.Builder<Identifier> identifierBuilder = new ImmutableSet.Builder<>();
        new DefaultExpressionTraversalVisitor<Void, ImmutableSet.Builder<Identifier>>() {

            @Override
            protected Void visitIdentifier(Identifier node, ImmutableSet.Builder<Identifier> builder) {
                builder.add(node);
                return null;
            }
        }.process(expression.get(), identifierBuilder);
        int sizeIdentifiers = identifierBuilder.build().asList().size();
        if (sizeIdentifiers == SUPPORTED_INDENTIFIER_SIZE) {
            String whereClause = createCube.getWhere().get().toString();
            QualifiedName sourceTableName = createCube.getSourceTableName();
            if (expression.get() instanceof BetweenPredicate) {
                BetweenPredicate betweenPredicate = (BetweenPredicate) (expression.get());
                String columnName = betweenPredicate.getValue().toString();
                String columnDataTypeQuery;
                String catalogName;
                String tableName = sourceTableName.getSuffix();
                checkArgument(tableName.matches("[\\p{Alnum}_]+"), "Invalid table name");
                if (hasInvalidSymbol(columnName)) {
                    return false;
                }
                if (sourceTableName.getPrefix().isPresent() && sourceTableName.getPrefix().get().getPrefix().isPresent()) {
                    catalogName = sourceTableName.getPrefix().get().getPrefix().get().toString();
                    checkArgument(catalogName.matches("[\\p{Alnum}_]+"), "Invalid catalog name");
                    columnDataTypeQuery = String.format(SELECT_DATA_TYPE_STRING, catalogName, tableName, columnName);
                } else if (queryRunner.getSession().getCatalog() != null) {
                    catalogName = queryRunner.getSession().getCatalog();
                    checkArgument(catalogName.matches("[\\p{Alnum}_]+"), "Invalid catalog name");
                    columnDataTypeQuery = String.format(SELECT_DATA_TYPE_STRING, catalogName, tableName, columnName);
                } else {
                    return false;
                }
                if (!processCubeInitialQuery(queryRunner, columnDataTypeQuery, outputFormat, schemaChanged, usePager, showProgress, terminal, out, errorChannel)) {
                    return false;
                }
                String resInitCubeQuery;
                resInitCubeQuery = getResultInitCubeQuery();
                if (resInitCubeQuery != null) {
                    cubeColumnDataType = resInitCubeQuery;
                }
                if (cubeColumnDataType.contains(DATATYPE_DECIMAL)) {
                    cubeColumnDataType = DATATYPE_DECIMAL;
                }
                if (cubeColumnDataType.contains(DATATYPE_VARCHAR)) {
                    cubeColumnDataType = DATATYPE_VARCHAR;
                }
                if (!isSupportedDatatype(cubeColumnDataType)) {
                    return false;
                }
                if (betweenPredicate.getMin() instanceof LongLiteral || betweenPredicate.getMin() instanceof LongLiteral || betweenPredicate.getMin() instanceof TimestampLiteral || betweenPredicate.getMin() instanceof TimestampLiteral || betweenPredicate.getMin() instanceof GenericLiteral || betweenPredicate.getMin() instanceof StringLiteral || betweenPredicate.getMin() instanceof DoubleLiteral) {
                    if (betweenPredicate.getMax() instanceof LongLiteral || betweenPredicate.getMax() instanceof LongLiteral || betweenPredicate.getMax() instanceof TimestampLiteral || betweenPredicate.getMax() instanceof TimestampLiteral || betweenPredicate.getMax() instanceof GenericLiteral || betweenPredicate.getMax() instanceof StringLiteral || betweenPredicate.getMax() instanceof DoubleLiteral) {
                        // initial query to get the total number of distinct column values in the table
                        String countDistinctQuery = String.format(SELECT_COUNT_DISTINCT_FROM_STRING, columnName, sourceTableName.toString(), whereClause);
                        if (!processCubeInitialQuery(queryRunner, countDistinctQuery, outputFormat, schemaChanged, usePager, showProgress, terminal, out, errorChannel)) {
                            return false;
                        }
                        Long valueCountDistinctQuery = INITIAL_QUERY_RESULT_VALUE;
                        resInitCubeQuery = getResultInitCubeQuery();
                        if (resInitCubeQuery != null) {
                            valueCountDistinctQuery = Long.parseLong(resInitCubeQuery);
                        }
                        if (valueCountDistinctQuery < MAX_BUFFERED_ROWS && valueCountDistinctQuery * rowBufferTempMultiplier < Integer.MAX_VALUE) {
                            supportedExpression = true;
                            rowBufferListSize = (int) ((valueCountDistinctQuery).intValue() * rowBufferTempMultiplier);
                        }
                    }
                }
            }
            if (expression.get() instanceof ComparisonExpression) {
                ComparisonExpression comparisonExpression = (ComparisonExpression) (createCube.getWhere().get());
                ComparisonExpression.Operator operator = comparisonExpression.getOperator();
                Expression left = comparisonExpression.getLeft();
                Expression right = comparisonExpression.getRight();
                if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
                    comparisonExpression = new ComparisonExpression(operator.flip(), right, left);
                }
                if (left instanceof Literal && !(right instanceof Literal)) {
                    comparisonExpression = new ComparisonExpression(operator.flip(), right, left);
                }
                if (comparisonExpression.getRight() instanceof LongLiteral) {
                    supportedExpression = true;
                }
                String catalogName;
                String tableName = sourceTableName.getSuffix();
                String columnName = comparisonExpression.getLeft().toString();
                String columnDataTypeQuery;
                checkArgument(tableName.matches("[\\p{Alnum}_]+"), "Invalid table name");
                if (hasInvalidSymbol(columnName)) {
                    return false;
                }
                if (sourceTableName.getPrefix().isPresent() && sourceTableName.getPrefix().get().getPrefix().isPresent()) {
                    catalogName = sourceTableName.getPrefix().get().getPrefix().get().toString();
                    checkArgument(catalogName.matches("[\\p{Alnum}_]+"), "Invalid catalog name");
                    columnDataTypeQuery = String.format(SELECT_DATA_TYPE_STRING, catalogName, tableName, columnName);
                } else if (queryRunner.getSession().getCatalog() != null) {
                    catalogName = queryRunner.getSession().getCatalog();
                    checkArgument(catalogName.matches("[\\p{Alnum}_]+"), "Invalid catalog name");
                    columnDataTypeQuery = String.format(SELECT_DATA_TYPE_STRING, catalogName, tableName, columnName);
                } else {
                    return false;
                }
                if (!processCubeInitialQuery(queryRunner, columnDataTypeQuery, outputFormat, schemaChanged, usePager, showProgress, terminal, out, errorChannel)) {
                    return false;
                }
                String resInitCubeQuery;
                resInitCubeQuery = getResultInitCubeQuery();
                if (resInitCubeQuery != null) {
                    cubeColumnDataType = resInitCubeQuery.toLowerCase(Locale.ENGLISH);
                }
                if (cubeColumnDataType.contains(DATATYPE_DECIMAL)) {
                    cubeColumnDataType = DATATYPE_DECIMAL;
                }
                if (cubeColumnDataType.contains(DATATYPE_VARCHAR)) {
                    cubeColumnDataType = DATATYPE_VARCHAR;
                }
                if (!isSupportedDatatype(cubeColumnDataType)) {
                    return false;
                }
                if (comparisonExpression.getRight() instanceof GenericLiteral || comparisonExpression.getRight() instanceof StringLiteral || comparisonExpression.getRight() instanceof DoubleLiteral || comparisonExpression.getRight() instanceof LongLiteral || comparisonExpression.getRight() instanceof TimestampLiteral) {
                    // initial query to get the total number of distinct column values in the table
                    String countDistinctQuery = String.format(SELECT_COUNT_DISTINCT_FROM_STRING, columnName, sourceTableName.toString(), whereClause);
                    if (!processCubeInitialQuery(queryRunner, countDistinctQuery, outputFormat, schemaChanged, usePager, showProgress, terminal, out, errorChannel)) {
                        return false;
                    }
                    Long valueCountDistinctQuery = INITIAL_QUERY_RESULT_VALUE;
                    resInitCubeQuery = getResultInitCubeQuery();
                    if (resInitCubeQuery != null) {
                        valueCountDistinctQuery = Long.parseLong(resInitCubeQuery);
                    }
                    if (valueCountDistinctQuery < MAX_BUFFERED_ROWS) {
                        supportedExpression = true;
                    }
                }
            }
        }
    }
    return supportedExpression;
}
Also used: BetweenPredicate(io.prestosql.sql.tree.BetweenPredicate) TimestampLiteral(io.prestosql.sql.tree.TimestampLiteral) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) QualifiedName(io.prestosql.sql.tree.QualifiedName) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) Identifier(io.prestosql.sql.tree.Identifier) ImmutableSet(com.google.common.collect.ImmutableSet) StringLiteral(io.prestosql.sql.tree.StringLiteral) Expression(io.prestosql.sql.tree.Expression) Literal(io.prestosql.sql.tree.Literal) DoubleLiteral(io.prestosql.sql.tree.DoubleLiteral)

Aggregations

GenericLiteral (io.prestosql.sql.tree.GenericLiteral)10 Expression (io.prestosql.sql.tree.Expression)6 Symbol (io.prestosql.spi.plan.Symbol)5 ComparisonExpression (io.prestosql.sql.tree.ComparisonExpression)5 StringLiteral (io.prestosql.sql.tree.StringLiteral)5 DoubleLiteral (io.prestosql.sql.tree.DoubleLiteral)4 LongLiteral (io.prestosql.sql.tree.LongLiteral)4 FilterNode (io.prestosql.spi.plan.FilterNode)3 PlanNode (io.prestosql.spi.plan.PlanNode)3 ProjectNode (io.prestosql.spi.plan.ProjectNode)3 Preconditions.checkState (com.google.common.base.Preconditions.checkState)2 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)2 ImmutableMap (com.google.common.collect.ImmutableMap)2 ImmutableSet (com.google.common.collect.ImmutableSet)2 Slice (io.airlift.slice.Slice)2 Session (io.prestosql.Session)2 Metadata (io.prestosql.metadata.Metadata)2 Signature (io.prestosql.spi.function.Signature)2 RowExpression (io.prestosql.spi.relation.RowExpression)2 SymbolReference (io.prestosql.sql.tree.SymbolReference)2