use of org.apache.calcite.rex.RexLiteral in project hive by apache.
the class HiveIntersectRewriteRule method onMatch.
// ~ Methods ----------------------------------------------------------------
/**
 * Rewrites a {@code HiveIntersect} as aggregates over a union of its inputs.
 *
 * <p>Plan built here, bottom-up:
 * <ol>
 *   <li>per input: project all columns plus a constant 1, then group by all
 *       columns with {@code count(1) as c};</li>
 *   <li>a union of those per-input aggregates;</li>
 *   <li>one aggregate over the union, grouped by every column except {@code c},
 *       computing {@code count(c)} and, for INTERSECT ALL, also {@code min(c)};</li>
 *   <li>a filter keeping groups whose {@code count(c)} equals the number of inputs;</li>
 *   <li>a final projection (distinct) or projection + UDTF + projection (all).</li>
 * </ol>
 *
 * @param call rule call whose first operand is the {@code HiveIntersect} to rewrite
 * @throws RuntimeException wrapping any semantic exception raised while building the plan
 */
public void onMatch(RelOptRuleCall call) {
  final HiveIntersect intersect = call.rel(0);
  final RelOptCluster cluster = intersect.getCluster();
  final RexBuilder rexBuilder = cluster.getRexBuilder();
  final int branchCount = intersect.getInputs().size();
  // Calcite type converted from Hive's long type info; shared by every count/min below.
  final RelDataType longType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());

  // 1st level GB: one GB (col0, col1, ..., count(1) as c) per branch.
  final Builder<RelNode> branchAggregates = new ImmutableList.Builder<RelNode>();
  for (RelNode branch : intersect.getInputs()) {
    final int branchWidth = branch.getRowType().getFieldList().size();
    final List<RexNode> projections = Lists.newArrayList();
    final List<Integer> groupKeys = Lists.newArrayList();
    for (int col = 0; col < branchWidth; col++) {
      projections.add(rexBuilder.makeInputRef(branch, col));
      groupKeys.add(col);
    }
    // The extra constant column '1' lands at position branchWidth and feeds count(1).
    projections.add(rexBuilder.makeBigintLiteral(BigDecimal.ONE));

    // A fresh project is needed below the GB to carry that extra column.
    final RelNode projected;
    try {
      projected = HiveProject.create(branch, projections, null);
    } catch (CalciteSemanticException e) {
      LOG.debug(e.toString());
      throw new RuntimeException(e);
    }

    final List<AggregateCall> branchAggCalls = Lists.newArrayList();
    branchAggCalls.add(HiveCalciteUtil.createSingleArgAggCall("count", cluster, TypeInfoFactory.longTypeInfo, branchWidth, longType));
    // The group set covers every original column of the branch.
    branchAggregates.add(new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), projected,
        ImmutableBitSet.of(groupKeys), null, branchAggCalls));
  }

  // Union on top of all per-branch aggregates.
  final HiveRelNode union = new HiveUnion(cluster, TraitsUtil.getDefaultTraitSet(cluster), branchAggregates.build());

  // 2nd level GB: a single GB (col0, col1, ..., count(c)) over the union.
  // c is the last column of the union; every column before it is a group key.
  final int cPos = union.getRowType().getFieldList().size() - 1;
  final List<Integer> outerGroupKeys = Lists.newArrayList();
  for (int col = 0; col < cPos; col++) {
    outerGroupKeys.add(col);
  }
  final List<AggregateCall> outerAggCalls = Lists.newArrayList();
  outerAggCalls.add(HiveCalciteUtil.createSingleArgAggCall("count", cluster, TypeInfoFactory.longTypeInfo, cPos, longType));
  if (intersect.all) {
    outerAggCalls.add(HiveCalciteUtil.createSingleArgAggCall("min", cluster, TypeInfoFactory.longTypeInfo, cPos, longType));
  }
  final HiveRelNode outerAggregate = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), union,
      ImmutableBitSet.of(outerGroupKeys), null, outerAggCalls);

  // Filter count(c) = #branches. In the outer aggregate's output the count
  // sits right after the group keys, i.e. again at position cPos.
  final List<RexNode> equalsOperands = new ArrayList<RexNode>();
  equalsOperands.add(rexBuilder.makeInputRef(outerAggregate, cPos));
  equalsOperands.add(rexBuilder.makeBigintLiteral(BigDecimal.valueOf(branchCount)));
  final ImmutableList<RelDataType> equalsArgTypes = ImmutableList.of(longType, longType);
  final RexNode countEqualsBranches;
  try {
    countEqualsBranches = rexBuilder.makeCall(
        SqlFunctionConverter.getCalciteFn("=", equalsArgTypes, longType, true, false), equalsOperands);
  } catch (CalciteSemanticException e) {
    LOG.debug(e.toString());
    throw new RuntimeException(e);
  }
  final RelNode filtered = new HiveFilter(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), outerAggregate,
      countEqualsBranches);

  if (!intersect.all) {
    // INTERSECT DISTINCT: the filter emits R3 (all attributes) + count(c) as cnt;
    // project the trailing cnt column away and we are done.
    final Set<Integer> droppedColumns = new HashSet<>();
    droppedColumns.add(filtered.getRowType().getFieldList().size() - 1);
    try {
      call.transformTo(HiveCalciteUtil.createProjectWithoutColumn(filtered, droppedColumns));
    } catch (CalciteSemanticException e) {
      LOG.debug(e.toString());
      throw new RuntimeException(e);
    }
    return;
  }

  // INTERSECT ALL: the filter emits R3 + count(c) as cnt + min(c) as m.
  // Feed the UDTF a project shaped as: min(c) as m + R3 (cnt is left out).
  final List<RelDataTypeField> filteredFields = filtered.getRowType().getFieldList();
  final int filteredWidth = filteredFields.size();
  final List<RexNode> udtfInputRefs = new ArrayList<>();
  final RelDataTypeField minField = filteredFields.get(filteredWidth - 1);
  udtfInputRefs.add(new RexInputRef(minField.getIndex(), minField.getType()));
  for (int i = 0; i < filteredWidth - 2; i++) {
    final RelDataTypeField field = filteredFields.get(i);
    udtfInputRefs.add(new RexInputRef(field.getIndex(), field.getType()));
  }
  try {
    final RelNode udtfInput = HiveProject.create(filtered, udtfInputRefs, null);
    // NOTE(review): presumably the UDTF replicates each row m times - confirm
    // against HiveCalciteUtil.createUDTFForSetOp.
    final HiveTableFunctionScan udtf = HiveCalciteUtil.createUDTFForSetOp(cluster, udtfInput);
    // Finally project out the 1st column (m).
    final Set<Integer> droppedColumns = new HashSet<>();
    droppedColumns.add(0);
    call.transformTo(HiveCalciteUtil.createProjectWithoutColumn(udtf, droppedColumns));
  } catch (SemanticException e) {
    LOG.debug(e.toString());
    throw new RuntimeException(e);
  }
}
use of org.apache.calcite.rex.RexLiteral in project hive by apache.
the class RelFieldTrimmer method trimFields.
/**
 * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
 * {@link org.apache.calcite.rel.logical.LogicalValues}.
 *
 * <p>Produces a new {@code LogicalValues} whose tuples contain only the
 * requested columns, together with the mapping from old to new field positions.
 *
 * @param values      the values node to trim
 * @param fieldsUsed  ordinals of the fields the consumer needs; when empty, the
 *                    last field is kept instead
 * @param extraFields not read by this variant
 * @return the original node with an identity mapping when every field is used,
 *         otherwise the trimmed node with the corresponding mapping
 */
public TrimResult trimFields(LogicalValues values, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
  final RelDataType inputRowType = values.getRowType();
  final int width = inputRowType.getFieldCount();

  // An empty request is widened to the last field (which is unlikely to be a
  // system field), so the trimmed relation always keeps at least one column.
  final ImmutableBitSet kept = fieldsUsed.isEmpty()
      ? ImmutableBitSet.range(width - 1, width)
      : fieldsUsed;

  // Every field is wanted: hand the node back untouched.
  if (kept.equals(ImmutableBitSet.range(width))) {
    final Mapping identity = Mappings.createIdentity(width);
    return result(values, identity);
  }

  final Mapping mapping = createMapping(kept, width);
  final RelDataType trimmedRowType = RelOptUtil.permute(values.getCluster().getTypeFactory(), inputRowType, mapping);

  // Copy each tuple, retaining only the kept ordinals in ascending order.
  final ImmutableList.Builder<ImmutableList<RexLiteral>> trimmedRows = ImmutableList.builder();
  for (ImmutableList<RexLiteral> row : values.getTuples()) {
    final ImmutableList.Builder<RexLiteral> trimmedRow = ImmutableList.builder();
    for (int ordinal : kept) {
      trimmedRow.add(row.get(ordinal));
    }
    trimmedRows.add(trimmedRow.build());
  }

  final LogicalValues trimmedValues = LogicalValues.create(values.getCluster(), trimmedRowType, trimmedRows.build());
  return result(trimmedValues, mapping);
}
use of org.apache.calcite.rex.RexLiteral in project hive by apache.
the class HiveExceptRewriteRule method makeFilterExprForExceptDistinct.
/**
 * Builds the predicate {@code a > 0 AND 2 * a = b} over the last two columns of
 * {@code input}, where {@code a} is the column at {@code columnSize - 2} and
 * {@code b} the column at {@code columnSize - 1}.
 *
 * @param input      relational node the column references point into
 * @param columnSize number of columns of {@code input}; the two trailing ones are used
 * @param cluster    supplies the type factory for the long type
 * @param rexBuilder builder used for references, literals and calls
 * @return the conjunction {@code a > 0 AND 2 * a = b}
 * @throws CalciteSemanticException if an operator cannot be resolved
 */
private RexNode makeFilterExprForExceptDistinct(HiveRelNode input, int columnSize, RelOptCluster cluster, RexBuilder rexBuilder) throws CalciteSemanticException {
  // Every operator below is looked up with (long, long) argument types and a long return type.
  final RelDataType longType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
  final ImmutableList<RelDataType> longPair = ImmutableList.of(longType, longType);

  final RexInputRef a = rexBuilder.makeInputRef(input, columnSize - 2);
  final RexInputRef b = rexBuilder.makeInputRef(input, columnSize - 1);

  // a>0
  final RexLiteral zero = rexBuilder.makeBigintLiteral(BigDecimal.ZERO);
  final RexNode aPositive = rexBuilder.makeCall(
      SqlFunctionConverter.getCalciteFn(">", longPair, longType, false, false),
      ImmutableList.<RexNode>of(a, zero));

  // 2*a
  final RexLiteral two = rexBuilder.makeBigintLiteral(BigDecimal.valueOf(2));
  final RexNode doubledA = rexBuilder.makeCall(
      SqlFunctionConverter.getCalciteFn("*", longPair, longType, false, false),
      ImmutableList.<RexNode>of(a, two));

  // 2a=b
  final RexNode doubledAEqualsB = rexBuilder.makeCall(
      SqlFunctionConverter.getCalciteFn("=", longPair, longType, false, false),
      ImmutableList.<RexNode>of(doubledA, b));

  // a>0 && 2a=b
  return rexBuilder.makeCall(
      SqlFunctionConverter.getCalciteFn("and", longPair, longType, false, false),
      ImmutableList.<RexNode>of(aPositive, doubledAEqualsB));
}
use of org.apache.calcite.rex.RexLiteral in project hive by apache.
the class HiveAggregateReduceFunctionsRule method reduceStddev.
/**
 * Rewrites a single-argument standard-deviation style aggregate call as
 * arithmetic over {@code sum(x * x)}, {@code sum(x)} and {@code count(x)}.
 *
 * @param oldAggRel      aggregate being rewritten; supplies the cluster, group
 *                       count, input and {@code indicator} flag
 * @param oldCall        the call to reduce; must have exactly one argument
 * @param biased         if true the divisor is {@code count(x)}; otherwise it is
 *                       {@code count(x) - 1}, or NULL when {@code count(x) = 1}
 * @param sqrt           if true the quotient is raised to the power 0.5; otherwise
 *                       the quotient itself is returned
 * @param newCalls       passed to {@code rexBuilder.addAggCall} for each generated
 *                       aggregate (presumably collects them - confirm)
 * @param aggCallMapping passed to {@code rexBuilder.addAggCall} alongside
 *                       {@code newCalls}
 * @param inputExprs     input expressions; read at the argument ordinal and handed
 *                       to {@code lookupOrAdd} (presumably extended when the
 *                       expression is not yet present - confirm)
 * @return the reduced expression, cast to the type of {@code oldCall}
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
// stddev_pop(x) ==>
//   power(
//     (sum(x * x) - sum(x) * sum(x) / count(x))
//     / count(x),
//     .5)
//
// stddev_samp(x) ==>
//   power(
//     (sum(x * x) - sum(x) * sum(x) / count(x))
//     / nullif(count(x) - 1, 0),
//     .5)
//
// The nullif(...) above is built below as
//   CASE WHEN count(x) = 1 THEN NULL ELSE count(x) - 1 END.
final int nGroups = oldAggRel.getGroupCount();
final RelOptCluster cluster = oldAggRel.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
final int argOrdinal = oldCall.getArgList().get(0);
final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
// Work in the nullable form of the call's own result type: x is passed through
// ensureType with that type before being squared and summed.
final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), true);
final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), false);
final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
final RelDataType sumReturnType = getSumReturnType(rexBuilder.getTypeFactory(), argRef.getType());
// x * x, registered as an input expression so that it can be summed.
final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
final RelDataType sumSquaredReturnType = getSumReturnType(rexBuilder.getTypeFactory(), argSquared.getType());
// sum(x * x)
final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory, new HiveSqlSumAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(sumSquaredReturnType), InferTypes.explicit(Collections.singletonList(argSquared.getType())), // SqlStdOperatorTable.SUM,
oldCall.getAggregation().getOperandTypeChecker()), argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgSquaredAggCall.getType()));
// sum(x)
final AggregateCall sumArgAggCall = AggregateCall.create(new HiveSqlSumAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(sumReturnType), InferTypes.explicit(Collections.singletonList(argOrdinalType)), // SqlStdOperatorTable.SUM,
oldCall.getAggregation().getOperandTypeChecker()), oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argRefOrdinal), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgAggCall.getType()));
final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
// sum(x) * sum(x) -- note the naming: sumArgSquared is sum(x * x), whereas
// sumSquaredArg is the square of sum(x).
final RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
// count(x), typed as nullable BIGINT.
RelDataType countRetType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true);
final AggregateCall countArgAggCall = AggregateCall.create(new HiveSqlCountAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(countRetType), oldCall.getAggregation().getOperandTypeInference(), // SqlStdOperatorTable.COUNT,
oldCall.getAggregation().getOperandTypeChecker()), oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), countRetType, null);
final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argOrdinalType));
// sum(x) * sum(x) / count(x)
final RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
// sum(x * x) - sum(x) * sum(x) / count(x)
final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);
final RexNode denominator;
if (biased) {
// Population form: divide by count(x).
denominator = countArg;
} else {
// Sample form: divide by count(x) - 1; a NULL divisor is used when
// count(x) = 1 so that no division by zero is generated.
final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
final RexNode nul = rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
}
final RexNode div = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
RexNode result = div;
if (sqrt) {
// Square root expressed as power(div, 0.5).
final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
}
// Cast back to the type the original call declared.
return rexBuilder.makeCast(oldCall.getType(), result);
}
use of org.apache.calcite.rex.RexLiteral in project hive by apache.
the class HiveFunctionHelper method createConstantObjectInspector.
/**
 * Builds a constant object inspector for the given expression.
 *
 * <p>Handled shapes: a literal (or a cast over a literal that folds to one),
 * and calls to the ROW, ARRAY and MAP value constructors whose operands can
 * themselves be turned into constant object inspectors.
 *
 * @param expr expression to convert
 * @return the constant object inspector, or {@code null} if it could not be generated
 */
private ConstantObjectInspector createConstantObjectInspector(RexNode expr) {
  if (RexUtil.isLiteral(expr, true)) {
    // Literal or cast on literal: a cast has to fold down to a literal first.
    final RexNode literalNode = expr.isA(SqlKind.LITERAL) ? expr : foldExpression(expr);
    if (!literalNode.isA(SqlKind.LITERAL)) {
      // Constant could not be generated
      return null;
    }
    final ExprNodeConstantDesc constantDesc = ExprNodeConverter.toExprNodeConstantDesc((RexLiteral) literalNode);
    final PrimitiveTypeInfo primitiveType = (PrimitiveTypeInfo) constantDesc.getTypeInfo();
    final Object javaValue = constantDesc.getValue();
    Object writable = null;
    if (javaValue != null) {
      writable = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(primitiveType)
          .getPrimitiveWritableObject(javaValue);
    }
    return PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(primitiveType, writable);
  }

  if (!(expr instanceof RexCall)) {
    // Constant could not be generated
    return null;
  }
  final RexCall call = (RexCall) expr;

  if (call.getOperator() == SqlStdOperatorTable.ROW) {
    // Struct: one field per operand, named after the row type's fields.
    final List<String> names = new ArrayList<>();
    final List<ObjectInspector> inspectors = new ArrayList<>();
    final List<Object> constants = new ArrayList<>();
    final List<RexNode> operands = call.getOperands();
    for (int pos = 0; pos < operands.size(); pos++) {
      final ConstantObjectInspector fieldInspector = createConstantObjectInspector(operands.get(pos));
      if (fieldInspector == null) {
        // Constant could not be generated
        return null;
      }
      names.add(expr.getType().getFieldList().get(pos).getName());
      inspectors.add(fieldInspector);
      constants.add(fieldInspector.getWritableConstantValue());
    }
    return ObjectInspectorFactory.getStandardConstantStructObjectInspector(names, inspectors, constants);
  }

  if (call.getOperator() == SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR) {
    // List: the element inspector comes from the list type, the values from the operands.
    final ListTypeInfo listType = (ListTypeInfo) TypeConverter.convert(expr.getType());
    final TypeInfo elementType = listType.getListElementTypeInfo();
    final List<Object> constants = new ArrayList<>();
    for (RexNode operand : call.getOperands()) {
      final ConstantObjectInspector elementInspector = createConstantObjectInspector(operand);
      if (elementInspector == null) {
        // Constant could not be generated
        return null;
      }
      constants.add(elementInspector.getWritableConstantValue());
    }
    return ObjectInspectorFactory.getStandardConstantListObjectInspector(
        ObjectInspectorUtils.getStandardObjectInspector(
            TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(elementType),
            ObjectInspectorCopyOption.WRITABLE),
        constants);
  }

  if (call.getOperator() == SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR) {
    // Map: operands are consumed pairwise as key, value, key, value, ...
    final MapTypeInfo mapType = (MapTypeInfo) TypeConverter.convert(expr.getType());
    final Map<Object, Object> constants = new HashMap<>();
    for (Iterator<RexNode> operands = call.getOperands().iterator(); operands.hasNext();) {
      final ConstantObjectInspector keyInspector = createConstantObjectInspector(operands.next());
      if (keyInspector == null) {
        // Constant could not be generated
        return null;
      }
      final ConstantObjectInspector valueInspector = createConstantObjectInspector(operands.next());
      if (valueInspector == null) {
        // Constant could not be generated
        return null;
      }
      constants.put(keyInspector.getWritableConstantValue(), valueInspector.getWritableConstantValue());
    }
    return ObjectInspectorFactory.getStandardConstantMapObjectInspector(
        ObjectInspectorUtils.getStandardObjectInspector(
            TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(mapType.getMapKeyTypeInfo()),
            ObjectInspectorCopyOption.WRITABLE),
        ObjectInspectorUtils.getStandardObjectInspector(
            TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(mapType.getMapValueTypeInfo()),
            ObjectInspectorCopyOption.WRITABLE),
        constants);
  }

  // Constant could not be generated
  return null;
}
Aggregations