Search in sources :

Example 41 with RelDataType

use of org.apache.calcite.rel.type.RelDataType in project drill by apache.

the class InsertLocalExchangeVisitor method visitExchange.

/**
 * Rebuilds an exchange on top of its already-visited input, optionally surrounding a
 * {@code HashToRandomExchangePrel} with local mux / demux exchanges.
 *
 * <p>Exchanges other than {@code HashToRandomExchangePrel} are only copied onto the visited
 * input. For a hash exchange the resulting chain is, bottom-up:
 * <ol>
 *   <li>when {@code isMuxEnabled}: a project appending a hash-expression column named
 *       {@code HashPrelUtil.HASH_EXPR_NAME}, followed by an {@code UnorderedMuxExchangePrel};</li>
 *   <li>a new {@code HashToRandomExchangePrel} carrying the original distribution fields;</li>
 *   <li>when {@code isDeMuxEnabled}: an {@code UnorderedDeMuxExchangePrel};</li>
 *   <li>when {@code isMuxEnabled}: a project that keeps only the input's original columns,
 *       which drops the hash column again.</li>
 * </ol>
 *
 * @param prel  the exchange being visited
 * @param value unused visitor argument
 * @return the rewritten plan node
 */
@Override
public Prel visitExchange(ExchangePrel prel, Void value) throws RuntimeException {
    final Prel visitedInput = ((Prel) prel.getInput()).accept(this, null);

    // Only hash-to-random exchanges are rewritten; any other exchange is re-parented unchanged.
    if (!(prel instanceof HashToRandomExchangePrel)) {
        return (Prel) prel.copy(prel.getTraitSet(), Collections.singletonList(((RelNode) visitedInput)));
    }

    final HashToRandomExchangePrel hashExchange = (HashToRandomExchangePrel) prel;
    final List<String> inputFieldNames = visitedInput.getRowType().getFieldNames();
    Prel exchangeInput = visitedInput;
    // Plain references to the input columns; only populated (and only needed) when mux is enabled.
    List<RexNode> passThroughRefs = null;

    if (isMuxEnabled) {
        final RexBuilder rexBuilder = prel.getCluster().getRexBuilder();
        final List<RelDataTypeField> inputFields = visitedInput.getRowType().getFieldList();

        // References to the columns the hash exchange distributes on.
        final List<DistributionField> distFields = hashExchange.getFields();
        final List<RexNode> distFieldRefs = Lists.newArrayListWithExpectedSize(distFields.size());
        for (DistributionField distField : distFields) {
            final int fieldId = distField.getFieldId();
            distFieldRefs.add(rexBuilder.makeInputRef(inputFields.get(fieldId).getType(), fieldId));
        }

        // The same input references feed both the widening project (below the exchange)
        // and the narrowing project (above it).
        final List<RexNode> widenedExprs = Lists.newArrayListWithExpectedSize(inputFields.size());
        passThroughRefs = Lists.newArrayListWithExpectedSize(inputFields.size());
        for (RelDataTypeField inputField : inputFields) {
            final RexNode inputRef = rexBuilder.makeInputRef(inputField.getType(), inputField.getIndex());
            widenedExprs.add(inputRef);
            passThroughRefs.add(inputRef);
        }

        // Append the hash expression (seeded with DIST_SEED) as one extra trailing column.
        final List<String> widenedFieldNames = Lists.newArrayList(inputFieldNames);
        widenedFieldNames.add(HashPrelUtil.HASH_EXPR_NAME);
        final HashExpressionCreatorHelper<RexNode> hashHelper = new RexNodeBasedHashExpressionCreatorHelper(rexBuilder);
        final RexNode distSeed = rexBuilder.makeBigintLiteral(BigDecimal.valueOf(HashPrelUtil.DIST_SEED));
        widenedExprs.add(HashPrelUtil.createHashBasedPartitionExpression(distFieldRefs, distSeed, hashHelper));

        final RelDataType widenedRowType = RexUtil.createStructType(prel.getCluster().getTypeFactory(), widenedExprs, widenedFieldNames);
        final ProjectPrel hashColumnProject = new ProjectPrel(visitedInput.getCluster(), visitedInput.getTraitSet(), visitedInput, widenedExprs, widenedRowType);
        exchangeInput = new UnorderedMuxExchangePrel(hashColumnProject.getCluster(), hashColumnProject.getTraitSet(), hashColumnProject);
    }

    final HashToRandomExchangePrel rebuiltExchange = new HashToRandomExchangePrel(prel.getCluster(), prel.getTraitSet(), exchangeInput, hashExchange.getFields());
    Prel result = rebuiltExchange;

    if (isDeMuxEnabled) {
        // Insert a DeMuxExchange to narrow down the number of receivers.
        result = new UnorderedDeMuxExchangePrel(prel.getCluster(), prel.getTraitSet(), rebuiltExchange, rebuiltExchange.getFields());
    }

    if (!isMuxEnabled) {
        return result;
    }

    // Project away the hash column added earlier - it creates issues down the road in HashJoin.
    final RelDataType narrowedRowType = RexUtil.createStructType(result.getCluster().getTypeFactory(), passThroughRefs, inputFieldNames);
    return new ProjectPrel(result.getCluster(), result.getTraitSet(), result, passThroughRefs, narrowedRowType);
}
Also used : ProjectPrel(org.apache.drill.exec.planner.physical.ProjectPrel) HashToRandomExchangePrel(org.apache.drill.exec.planner.physical.HashToRandomExchangePrel) RelDataType(org.apache.calcite.rel.type.RelDataType) ExchangePrel(org.apache.drill.exec.planner.physical.ExchangePrel) UnorderedDeMuxExchangePrel(org.apache.drill.exec.planner.physical.UnorderedDeMuxExchangePrel) UnorderedMuxExchangePrel(org.apache.drill.exec.planner.physical.UnorderedMuxExchangePrel) Prel(org.apache.drill.exec.planner.physical.Prel) HashToRandomExchangePrel(org.apache.drill.exec.planner.physical.HashToRandomExchangePrel) ProjectPrel(org.apache.drill.exec.planner.physical.ProjectPrel) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) UnorderedMuxExchangePrel(org.apache.drill.exec.planner.physical.UnorderedMuxExchangePrel) RexBuilder(org.apache.calcite.rex.RexBuilder) UnorderedDeMuxExchangePrel(org.apache.drill.exec.planner.physical.UnorderedDeMuxExchangePrel) DistributionField(org.apache.drill.exec.planner.physical.DrillDistributionTrait.DistributionField) RexNode(org.apache.calcite.rex.RexNode)

Example 42 with RelDataType

use of org.apache.calcite.rel.type.RelDataType in project drill by apache.

the class StarColumnConverter method visitProject.

/**
 * Visits a project: records whether star-column prefixing is required, visits the input, and
 * rebuilds the project so that plain input references take their field names from the
 * (possibly renamed) visited input.
 *
 * @param prel  the project being visited
 * @param value unused visitor argument
 * @return the visited input when the rebuilt project is trivial, otherwise the rebuilt project
 */
@Override
public Prel visitProject(ProjectPrel prel, Void value) throws RuntimeException {
    final List<RexNode> projectExprs = prel.getProjects();

    // Prefix renaming is required when the project has a star column plus at least one other
    // output field. The flag is only ever switched on here, and only if it is not set yet.
    if (!prefixedForStar
            && StarColumnHelper.containsStarColumnInProject(prel.getInput().getRowType(), projectExprs)
            && prel.getRowType().getFieldNames().size() > 1) {
        prefixedForStar = true;
    }

    // The input is visited only after the flag above has been decided.
    final RelNode visitedInput = ((Prel) prel.getInput(0)).accept(this, null);

    // A project expression that is a bare RexInputRef must keep the input's field name, because
    // a renaming project may have been inserted below; any other expression keeps its own name.
    final List<String> candidateNames = Lists.newArrayList();
    for (Pair<String, RexNode> nameAndExpr : Pair.zip(prel.getRowType().getFieldNames(), projectExprs)) {
        final RexNode expr = nameAndExpr.right;
        final String outputName = (expr instanceof RexInputRef)
                ? visitedInput.getRowType().getFieldNames().get(((RexInputRef) expr).getIndex())
                : nameAndExpr.left;
        candidateNames.add(outputName);
    }

    // Duplicate field names are not allowed in a row type.
    final List<String> uniqueNames = makeUniqueNames(candidateNames);
    final RelDataType rebuiltRowType = RexUtil.createStructType(prel.getCluster().getTypeFactory(), projectExprs, uniqueNames);
    final ProjectPrel rebuiltProject = (ProjectPrel) prel.copy(prel.getTraitSet(), visitedInput, projectExprs, rebuiltRowType);

    if (ProjectRemoveRule.isTrivial(rebuiltProject)) {
        return (Prel) visitedInput;
    }
    return rebuiltProject;
}
Also used : ProjectPrel(org.apache.drill.exec.planner.physical.ProjectPrel) RelNode(org.apache.calcite.rel.RelNode) RexInputRef(org.apache.calcite.rex.RexInputRef) RelDataType(org.apache.calcite.rel.type.RelDataType) ScanPrel(org.apache.drill.exec.planner.physical.ScanPrel) Prel(org.apache.drill.exec.planner.physical.Prel) ProjectAllowDupPrel(org.apache.drill.exec.planner.physical.ProjectAllowDupPrel) ProjectPrel(org.apache.drill.exec.planner.physical.ProjectPrel) ScreenPrel(org.apache.drill.exec.planner.physical.ScreenPrel) WriterPrel(org.apache.drill.exec.planner.physical.WriterPrel) RexNode(org.apache.calcite.rex.RexNode)

Example 43 with RelDataType

use of org.apache.calcite.rel.type.RelDataType in project drill by apache.

the class DrillReduceAggregatesRule method reduceStddev.

/**
 * Rewrites a standard-deviation / variance aggregate call in terms of SUM and COUNT.
 *
 * <pre>
 * stddev_pop(x)  ==&gt;  power((sum(x * x) - sum(x) * sum(x) / count(x)) / count(x), .5)
 * stddev_samp(x) ==&gt;  power((sum(x * x) - sum(x) * sum(x) / count(x)) / nullif(count(x) - 1, 0), .5)
 * </pre>
 *
 * <p>The {@code nullif} above is emitted as {@code CASE count(x) = 1 THEN NULL ELSE count(x) - 1}.
 * As side effects, the argument expression in {@code inputExprs} is replaced by a
 * {@code CastHighOp} call on it, the squared argument is registered through
 * {@code lookupOrAdd}, and three aggregate calls (sum of squares, sum, count - in that order)
 * are registered through {@code RexBuilder.addAggCall}.
 *
 * @param oldAggRel      aggregate being rewritten
 * @param oldCall        the single-argument aggregate call to reduce
 * @param biased         {@code true} to divide by count(x); {@code false} to divide by count(x) - 1
 * @param sqrt           {@code true} to raise the quotient to the power 0.5
 * @param newCalls       passed through to {@code addAggCall}
 * @param aggCallMapping passed through to {@code addAggCall}
 * @param inputExprs     input expressions; modified as described above
 * @return the replacement expression
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean inferTypes = settings.isTypeInferenceEnabled();
    final int groupCount = oldAggRel.getGroupCount();
    final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

    assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
    final int argOrdinal = oldCall.getArgList().get(0);
    final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);

    // Wrap the argument in CastHighOp and make the wrapped form the input expression itself.
    final RexNode castArg = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
    inputExprs.set(argOrdinal, castArg);

    // x * x, registered among the input expressions.
    final RexNode squaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, castArg, castArg);
    final int squaredArgOrdinal = lookupOrAdd(inputExprs, squaredArg);

    final RelDataType nullableArgType = typeFactory.createTypeWithNullability(argType, true);

    // sum(x * x)
    final AggregateCall sumOfSquaresCall = AggregateCall.create(new SqlSumAggFunction(nullableArgType), oldCall.isDistinct(), ImmutableIntList.of(squaredArgOrdinal), -1, nullableArgType, null);
    final RexNode sumOfSquares = rexBuilder.addAggCall(sumOfSquaresCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));

    // sum(x), then sum(x) * sum(x)
    final AggregateCall sumCall = AggregateCall.create(new SqlSumAggFunction(nullableArgType), oldCall.isDistinct(), ImmutableIntList.of(argOrdinal), -1, nullableArgType, null);
    final RexNode sum = rexBuilder.addAggCall(sumCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
    final RexNode squareOfSum = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sum, sum);

    // count(x)
    final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countType = countFunction.getReturnType(typeFactory);
    final AggregateCall countCall = AggregateCall.create(countFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, countType, null);
    final RexNode count = rexBuilder.addAggCall(countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));

    // sum(x * x) - sum(x) * sum(x) / count(x)
    final RexNode squareOfSumOverCount = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, squareOfSum, count);
    final RexNode numerator = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumOfSquares, squareOfSumOverCount);

    final RexNode denominator;
    if (!biased) {
        // Sample form: count(x) - 1, or NULL when count(x) is 1.
        final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
        final RexNode nullOfCountType = rexBuilder.makeNullLiteral(count.getType().getSqlTypeName());
        final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, count, one);
        final RexNode countIsOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, count, one);
        denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countIsOne, nullOfCountType, countMinusOne);
    } else {
        // Population form: plain count(x).
        denominator = count;
    }

    // With type inference enabled, the division operator is built from the original call's type.
    final SqlOperator divideOperator;
    if (inferTypes) {
        divideOperator = new DrillSqlOperator("divide", 2, true, oldCall.getType(), false);
    } else {
        divideOperator = SqlStdOperatorTable.DIVIDE;
    }
    final RexNode quotient = rexBuilder.makeCall(divideOperator, numerator, denominator);

    RexNode reduced = quotient;
    if (sqrt) {
        final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
        reduced = rexBuilder.makeCall(SqlStdOperatorTable.POWER, quotient, half);
    }

    if (inferTypes) {
        return reduced;
    }
    /*
     * Currently calcite's strategy to infer the return type of aggregate functions
     * is wrong because it uses the first known argument to determine output type. For
     * instance if we are performing stddev on an integer column then it interprets the
     * output type to be integer which is incorrect as it should be double. So based on
     * this if we add cast after rewriting the aggregate we add an additional cast which
     * would cause wrong results. So we simply add a cast to ANY.
     */
    return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), reduced);
}
Also used : RexLiteral(org.apache.calcite.rex.RexLiteral) PlannerSettings(org.apache.drill.exec.planner.physical.PlannerSettings) DrillSqlOperator(org.apache.drill.exec.planner.sql.DrillSqlOperator) SqlOperator(org.apache.calcite.sql.SqlOperator) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) BigDecimal(java.math.BigDecimal) DrillSqlOperator(org.apache.drill.exec.planner.sql.DrillSqlOperator) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) SqlSumAggFunction(org.apache.calcite.sql.fun.SqlSumAggFunction) RexBuilder(org.apache.calcite.rex.RexBuilder) RexNode(org.apache.calcite.rex.RexNode)

Example 44 with RelDataType

use of org.apache.calcite.rel.type.RelDataType in project drill by apache.

the class DrillReduceAggregatesRule method reduceSum.

/**
 * Rewrites {@code SUM(x)} in terms of a sum-that-is-zero-on-empty aggregate (SUM0) and COUNT.
 *
 * <p>When the original call's type is not nullable the result is just the SUM0 reference;
 * otherwise it is {@code CASE COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}.
 *
 * @param oldAggRel      aggregate being rewritten
 * @param oldCall        the SUM call to reduce
 * @param newCalls       passed through to {@code addAggCall}
 * @param aggCallMapping passed through to {@code addAggCall}
 * @return the replacement expression
 */
private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
    final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean inferTypes = settings.isTypeInferenceEnabled();
    final int groupCount = oldAggRel.getGroupCount();
    final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
    final int argOrdinal = oldCall.getArgList().get(0);
    final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);

    // With type inference the original call's type is reused and the function is wrapped;
    // without it the sum type is derived from the argument type.
    final RelDataType sumType;
    final SqlAggFunction sumZeroFunction;
    if (!inferTypes) {
        sumType = typeFactory.createTypeWithNullability(argType, argType.isNullable());
        sumZeroFunction = new SqlSumEmptyIsZeroAggFunction();
    } else {
        sumType = oldCall.getType();
        sumZeroFunction = new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
    }
    final AggregateCall sumZeroCall = AggregateCall.create(sumZeroFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, sumType, null);

    final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countType = countFunction.getReturnType(typeFactory);
    final AggregateCall countCall = AggregateCall.create(countFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, countType, null);

    // NOTE: these references are with respect to the output of newAggRel.
    final RexNode sumZeroRef = rexBuilder.addAggCall(sumZeroCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
    if (!oldCall.getType().isNullable()) {
        // The original SUM(x) is declared non-nullable, so no NULL result has to be produced:
        // translate straight to SUM0(x) without registering the COUNT call.
        return sumZeroRef;
    }

    final RexNode countRef = rexBuilder.addAggCall(countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
    final RexNode zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
    final RexNode countIsZero = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countRef, zero);
    return rexBuilder.makeCall(SqlStdOperatorTable.CASE, countIsZero, rexBuilder.constantNull(), sumZeroRef);
}
Also used : SqlSumEmptyIsZeroAggFunction(org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) PlannerSettings(org.apache.drill.exec.planner.physical.PlannerSettings) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) DrillCalciteSqlAggFunctionWrapper(org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper) RexNode(org.apache.calcite.rex.RexNode)

Example 45 with RelDataType

use of org.apache.calcite.rel.type.RelDataType in project drill by apache.

the class DrillReduceAggregatesRule method reduceAgg.

/**
 * Dispatches an aggregate call to the matching reduction: SUM goes to {@code reduceSum},
 * AVG to {@code reduceAvg}, and the standard-deviation / variance subtypes to
 * {@code reduceStddev}. Any other aggregate function is preserved by registering the original
 * call through {@code RexBuilder.addAggCall}.
 *
 * @param oldAggRel      aggregate being rewritten
 * @param oldCall        the aggregate call to reduce
 * @param newCalls       passed through to the reduction / {@code addAggCall}
 * @param aggCallMapping passed through to the reduction / {@code addAggCall}
 * @param inputExprs     input expressions; passed on to {@code reduceStddev}, and used to look
 *                       up argument types for preserved calls
 * @return the expression that replaces {@code oldCall}
 */
private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    final SqlAggFunction aggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation());

    if (aggFunction instanceof SqlSumAggFunction) {
        // replace original SUM(x) with
        // case COUNT(x) when 0 then null else SUM0(x) end
        return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
    }

    if (!(aggFunction instanceof SqlAvgAggFunction)) {
        // anything else:  preserve original call
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final int groupCount = oldAggRel.getGroupCount();
        final List<Integer> argOrdinals = oldCall.getArgList();
        assert argOrdinals.size() <= inputExprs.size();
        final List<RelDataType> argTypes = new ArrayList<>();
        for (int argOrdinal : argOrdinals) {
            argTypes.add(inputExprs.get(argOrdinal).getType());
        }
        return rexBuilder.addAggCall(oldCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);
    }

    final SqlAvgAggFunction.Subtype subtype = ((SqlAvgAggFunction) aggFunction).getSubtype();
    switch (subtype) {
        case AVG:
            // replace original AVG(x) with SUM(x) / COUNT(x)
            return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
        case STDDEV_POP:
            // biased estimator (divide by COUNT(x)), square root applied
            return reduceStddev(oldAggRel, oldCall, true, true, newCalls, aggCallMapping, inputExprs);
        case STDDEV_SAMP:
            // unbiased estimator (divide by COUNT(x) - 1), square root applied
            return reduceStddev(oldAggRel, oldCall, false, true, newCalls, aggCallMapping, inputExprs);
        case VAR_POP:
            // biased estimator (divide by COUNT(x)), no square root
            return reduceStddev(oldAggRel, oldCall, true, false, newCalls, aggCallMapping, inputExprs);
        case VAR_SAMP:
            // unbiased estimator (divide by COUNT(x) - 1), no square root
            return reduceStddev(oldAggRel, oldCall, false, false, newCalls, aggCallMapping, inputExprs);
        default:
            throw Util.unexpected(subtype);
    }
}
Also used : SqlAvgAggFunction(org.apache.calcite.sql.fun.SqlAvgAggFunction) SqlSumAggFunction(org.apache.calcite.sql.fun.SqlSumAggFunction) ArrayList(java.util.ArrayList) RexBuilder(org.apache.calcite.rex.RexBuilder) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction)

Aggregations

RelDataType (org.apache.calcite.rel.type.RelDataType)88 RexNode (org.apache.calcite.rex.RexNode)48 RexBuilder (org.apache.calcite.rex.RexBuilder)28 RelNode (org.apache.calcite.rel.RelNode)27 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)25 ArrayList (java.util.ArrayList)21 RexInputRef (org.apache.calcite.rex.RexInputRef)16 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)14 AggregateCall (org.apache.calcite.rel.core.AggregateCall)13 ImmutableList (com.google.common.collect.ImmutableList)9 BigDecimal (java.math.BigDecimal)8 CalciteSemanticException (org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException)8 SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)7 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)7 RelOptCluster (org.apache.calcite.plan.RelOptCluster)6 RelBuilder (org.apache.calcite.tools.RelBuilder)6 Prel (org.apache.drill.exec.planner.physical.Prel)6 ProjectPrel (org.apache.drill.exec.planner.physical.ProjectPrel)6 Builder (com.google.common.collect.ImmutableList.Builder)5 LinkedList (java.util.LinkedList)5