Usage of `org.apache.hadoop.hive.ql.exec.vector.VectorizedUDAFs` in the Apache Hive project.
From class `Vectorizer`, method `getVectorAggregationDesc`:
/**
 * Builds the vectorized description of a single aggregation, or explains why it cannot be
 * vectorized.
 *
 * <p>The evaluator's class must carry a {@link VectorizedUDAFs} annotation listing candidate
 * {@link VectorAggregateExpression} classes; {@code findVecAggrClass} is used to pick one that
 * matches the input column vector type, output column vector type and evaluator mode. Only
 * aggregations with zero parameters (e.g. {@code COUNT(*)}) or exactly one parameter are
 * supported.
 *
 * <p>When the single input produces a DECIMAL_64 column vector, lookups are tried in this order:
 * DECIMAL_64 output (only when the output is DECIMAL and its precision fits DECIMAL_64), then the
 * regular output type with DECIMAL_64 input; if neither matches, the input expression is wrapped
 * with a DECIMAL_64-to-DECIMAL conversion and the normal lookup is done with DECIMAL input.
 *
 * <p>Note: this calls {@code init} on the evaluator held by {@code aggrDesc} in order to learn
 * the aggregation's return type.
 *
 * @param aggrDesc the aggregation to vectorize
 * @param vContext the vectorization context used to build the input vector expression
 * @return on success, a pair of (vector aggregation descriptor, {@code null}); on failure, a pair
 *     of ({@code null}, human-readable reason vectorization is not supported)
 * @throws HiveException propagated from the evaluator / vectorization-context calls made here
 */
private static ImmutablePair<VectorAggregationDesc, String> getVectorAggregationDesc(
    AggregationDesc aggrDesc, VectorizationContext vContext) throws HiveException {

  final String aggregateName = aggrDesc.getGenericUDAFName();
  final ArrayList<ExprNodeDesc> parameterList = aggrDesc.getParameters();
  final int parameterCount = parameterList.size();
  final GenericUDAFEvaluator.Mode udafEvaluatorMode = aggrDesc.getMode();

  /*
   * Look at evaluator to get output type info.
   */
  final GenericUDAFEvaluator evaluator = aggrDesc.getGenericUDAFEvaluator();
  final ObjectInspector[] parameterObjectInspectors = new ObjectInspector[parameterCount];
  for (int i = 0; i < parameterCount; i++) {
    TypeInfo typeInfo = parameterList.get(i).getTypeInfo();
    parameterObjectInspectors[i] =
        TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(typeInfo);
  }

  // The only way to get the return object inspector (and its return type) is to
  // initialize the evaluator.
  final ObjectInspector returnOI = evaluator.init(udafEvaluatorMode, parameterObjectInspectors);

  final VectorizedUDAFs annotation =
      AnnotationUtils.getAnnotation(evaluator.getClass(), VectorizedUDAFs.class);
  if (annotation == null) {
    String issue = "Evaluator " + evaluator.getClass().getSimpleName() + " does not have a "
        + "vectorized UDAF annotation (aggregation: \"" + aggregateName + "\"). "
        + "Vectorization not supported";
    return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
  }
  final Class<? extends VectorAggregateExpression>[] vecAggrClasses = annotation.value();

  final TypeInfo outputTypeInfo =
      TypeInfoUtils.getTypeInfoFromTypeString(returnOI.getTypeName());
  // Never reassigned: the DECIMAL_64 path below passes DECIMAL_64 explicitly when it
  // chooses a DECIMAL_64 output instead of the regular one.
  final ColumnVector.Type outputColVectorType =
      VectorizationContext.getColumnVectorTypeFromTypeInfo(outputTypeInfo);

  /*
   * Determine input type info.
   */
  final TypeInfo inputTypeInfo;
  // Not final since these may change later due to DECIMAL_64 conversion.
  VectorExpression inputExpression;
  ColumnVector.Type inputColVectorType;
  if (parameterCount == 0) {
    // COUNT(*)
    inputTypeInfo = null;
    inputColVectorType = null;
    inputExpression = null;
  } else if (parameterCount == 1) {
    ExprNodeDesc exprNodeDesc = parameterList.get(0);
    inputTypeInfo = exprNodeDesc.getTypeInfo();
    if (inputTypeInfo == null) {
      String issue = "Aggregations with null parameter type not supported "
          + aggregateName + "(" + parameterList.toString() + ")";
      return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
    }

    /*
     * Determine an *initial* input vector expression.
     *
     * Note: we may have to convert it later from DECIMAL_64 to regular decimal.
     */
    inputExpression =
        vContext.getVectorExpression(exprNodeDesc, VectorExpressionDescriptor.Mode.PROJECTION);
    if (inputExpression == null) {
      String issue = "Parameter expression " + exprNodeDesc.toString() + " not supported "
          + aggregateName + "(" + parameterList.toString() + ")";
      return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
    }
    if (inputExpression.getOutputTypeInfo() == null) {
      String issue = "Parameter expression " + exprNodeDesc.toString()
          + " with null type not supported "
          + aggregateName + "(" + parameterList.toString() + ")";
      return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
    }
    inputColVectorType = inputExpression.getOutputColumnVectorType();
  } else {
    // No multi-parameter aggregations supported.
    String issue = "Aggregations with > 1 parameter are not supported "
        + aggregateName + "(" + parameterList.toString() + ")";
    return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
  }

  /*
   * When we have DECIMAL_64 as the input parameter then we have to see if there is a special
   * vector UDAF for it. If not we will need to convert the input parameter.
   */
  if (inputTypeInfo != null && inputColVectorType == ColumnVector.Type.DECIMAL_64) {
    if (outputColVectorType == ColumnVector.Type.DECIMAL) {
      DecimalTypeInfo outputDecimalTypeInfo = (DecimalTypeInfo) outputTypeInfo;
      if (HiveDecimalWritable.isPrecisionDecimal64(outputDecimalTypeInfo.getPrecision())) {
        // Try with DECIMAL_64 input and DECIMAL_64 output.
        final Class<? extends VectorAggregateExpression> decimal64OutputClass =
            findVecAggrClass(vecAggrClasses, aggregateName, inputColVectorType,
                ColumnVector.Type.DECIMAL_64, udafEvaluatorMode);
        if (decimal64OutputClass != null) {
          final VectorAggregationDesc vecAggrDesc = new VectorAggregationDesc(aggrDesc,
              evaluator, inputTypeInfo, inputColVectorType, inputExpression, outputTypeInfo,
              ColumnVector.Type.DECIMAL_64, decimal64OutputClass);
          return new ImmutablePair<VectorAggregationDesc, String>(vecAggrDesc, null);
        }
      }
    }

    // Try with DECIMAL_64 input and the desired (regular) output type.
    final Class<? extends VectorAggregateExpression> decimal64InputClass =
        findVecAggrClass(vecAggrClasses, aggregateName, inputColVectorType,
            outputColVectorType, udafEvaluatorMode);
    if (decimal64InputClass != null) {
      final VectorAggregationDesc vecAggrDesc = new VectorAggregationDesc(aggrDesc, evaluator,
          inputTypeInfo, inputColVectorType, inputExpression, outputTypeInfo,
          outputColVectorType, decimal64InputClass);
      return new ImmutablePair<VectorAggregationDesc, String>(vecAggrDesc, null);
    }

    // No support for DECIMAL_64 input. We must convert.
    inputExpression = vContext.wrapWithDecimal64ToDecimalConversion(inputExpression);
    inputColVectorType = ColumnVector.Type.DECIMAL;
    // Fall through to the normal lookup below.
  }

  /*
   * Look for normal match.
   */
  final Class<? extends VectorAggregateExpression> vecAggrClass =
      findVecAggrClass(vecAggrClasses, aggregateName, inputColVectorType,
          outputColVectorType, udafEvaluatorMode);
  if (vecAggrClass != null) {
    final VectorAggregationDesc vecAggrDesc = new VectorAggregationDesc(aggrDesc, evaluator,
        inputTypeInfo, inputColVectorType, inputExpression, outputTypeInfo,
        outputColVectorType, vecAggrClass);
    return new ImmutablePair<VectorAggregationDesc, String>(vecAggrDesc, null);
  }

  // No match. Quote the input type only when there is one, so the "any" case
  // (e.g. COUNT(*)) does not get a dangling closing quote.
  final String inputTypeText =
      (inputColVectorType == null ? "any" : "\"" + inputColVectorType + "\"");
  String issue = "Vector aggregation : \"" + aggregateName + "\" "
      + "for input type: " + inputTypeText + " "
      + "and output type: \"" + outputColVectorType + "\" "
      + "and mode: " + udafEvaluatorMode + " not supported for "
      + "evaluator " + evaluator.getClass().getSimpleName();
  return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
}
Aggregations