Search in sources :

Example 1 with VectorizedUDAFs

Use of org.apache.hadoop.hive.ql.exec.vector.VectorizedUDAFs in the Apache Hive project.

The method getVectorAggregationDesc of the class Vectorizer.

/**
 * Attempts to build a {@link VectorAggregationDesc} for the given aggregation.
 *
 * <p>The result is a pair in which exactly one side is non-null: on success the left side holds
 * the vector aggregation descriptor and the right side is {@code null}; when the aggregation
 * cannot be vectorized the left side is {@code null} and the right side holds a human-readable
 * explanation.
 *
 * <p>Only aggregations with zero parameters (treated here as COUNT(*)) or exactly one parameter
 * are considered; anything with more parameters is reported as unsupported.
 *
 * @param aggrDesc the aggregation to vectorize; supplies the UDAF name, parameters, mode and
 *                 evaluator
 * @param vContext the vectorization context used to build the input vector expression and,
 *                 when needed, the DECIMAL_64 to DECIMAL conversion wrapper
 * @return a pair of (descriptor, {@code null}) on success or ({@code null}, issue) on failure
 * @throws HiveException declared by this method; presumably raised by evaluator initialization
 *                       or vector expression creation -- confirm against those callees
 */
private static ImmutablePair<VectorAggregationDesc, String> getVectorAggregationDesc(AggregationDesc aggrDesc, VectorizationContext vContext) throws HiveException {
    String aggregateName = aggrDesc.getGenericUDAFName();
    ArrayList<ExprNodeDesc> parameterList = aggrDesc.getParameters();
    final int parameterCount = parameterList.size();
    final GenericUDAFEvaluator.Mode udafEvaluatorMode = aggrDesc.getMode();
    /*
     * Look at evaluator to get output type info.
     */
    GenericUDAFEvaluator evaluator = aggrDesc.getGenericUDAFEvaluator();
    ObjectInspector[] parameterObjectInspectors = new ObjectInspector[parameterCount];
    for (int i = 0; i < parameterCount; i++) {
        TypeInfo typeInfo = parameterList.get(i).getTypeInfo();
        parameterObjectInspectors[i] = TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(typeInfo);
    }
    // The only way to get the return object inspector (and its return type) is to
    // initialize it...
    ObjectInspector returnOI = evaluator.init(udafEvaluatorMode, parameterObjectInspectors);
    VectorizedUDAFs annotation = AnnotationUtils.getAnnotation(evaluator.getClass(), VectorizedUDAFs.class);
    if (annotation == null) {
        String issue = "Evaluator " + evaluator.getClass().getSimpleName() + " does not have a " + "vectorized UDAF annotation (aggregation: \"" + aggregateName + "\"). " + "Vectorization not supported";
        return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
    }
    final Class<? extends VectorAggregateExpression>[] vecAggrClasses = annotation.value();
    final TypeInfo outputTypeInfo = TypeInfoUtils.getTypeInfoFromTypeString(returnOI.getTypeName());
    final ColumnVector.Type outputColVectorType = VectorizationContext.getColumnVectorTypeFromTypeInfo(outputTypeInfo);
    /*
     * Determine input type info.
     */
    final TypeInfo inputTypeInfo;
    // Not final since they may change later due to DECIMAL_64.
    VectorExpression inputExpression;
    ColumnVector.Type inputColVectorType;
    if (parameterCount == 0) {
        // COUNT(*)
        inputTypeInfo = null;
        inputColVectorType = null;
        inputExpression = null;
    } else if (parameterCount == 1) {
        ExprNodeDesc exprNodeDesc = parameterList.get(0);
        inputTypeInfo = exprNodeDesc.getTypeInfo();
        if (inputTypeInfo == null) {
            String issue = "Aggregations with null parameter type not supported " + aggregateName + "(" + parameterList.toString() + ")";
            return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
        }
        /*
         * Determine an *initial* input vector expression.
         *
         * Note: we may have to convert it later from DECIMAL_64 to regular decimal.
         */
        inputExpression = vContext.getVectorExpression(exprNodeDesc, VectorExpressionDescriptor.Mode.PROJECTION);
        if (inputExpression == null) {
            String issue = "Parameter expression " + exprNodeDesc.toString() + " not supported " + aggregateName + "(" + parameterList.toString() + ")";
            return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
        }
        if (inputExpression.getOutputTypeInfo() == null) {
            String issue = "Parameter expression " + exprNodeDesc.toString() + " with null type not supported " + aggregateName + "(" + parameterList.toString() + ")";
            return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
        }
        inputColVectorType = inputExpression.getOutputColumnVectorType();
    } else {
        // No multi-parameter aggregations supported.
        String issue = "Aggregations with > 1 parameter are not supported " + aggregateName + "(" + parameterList.toString() + ")";
        return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
    }
    /*
     * When we have DECIMAL_64 as the input parameter then we have to see if there is a special
     * vector UDAF for it.  If not we will need to convert the input parameter.
     */
    if (inputTypeInfo != null && inputColVectorType == ColumnVector.Type.DECIMAL_64) {
        if (outputColVectorType == ColumnVector.Type.DECIMAL) {
            DecimalTypeInfo outputDecimalTypeInfo = (DecimalTypeInfo) outputTypeInfo;
            if (HiveDecimalWritable.isPrecisionDecimal64(outputDecimalTypeInfo.getPrecision())) {
                // Try with DECIMAL_64 input and DECIMAL_64 output.
                final Class<? extends VectorAggregateExpression> decimal64OutputClass = findVecAggrClass(vecAggrClasses, aggregateName, inputColVectorType, ColumnVector.Type.DECIMAL_64, udafEvaluatorMode);
                if (decimal64OutputClass != null) {
                    final VectorAggregationDesc vecAggrDesc = new VectorAggregationDesc(aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression, outputTypeInfo, ColumnVector.Type.DECIMAL_64, decimal64OutputClass);
                    return new ImmutablePair<VectorAggregationDesc, String>(vecAggrDesc, null);
                }
            }
        }
        // Try with DECIMAL_64 input and the evaluator's own output type (regular DECIMAL or
        // whatever other type the evaluator reported).
        final Class<? extends VectorAggregateExpression> decimal64InputClass = findVecAggrClass(vecAggrClasses, aggregateName, inputColVectorType, outputColVectorType, udafEvaluatorMode);
        if (decimal64InputClass != null) {
            final VectorAggregationDesc vecAggrDesc = new VectorAggregationDesc(aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression, outputTypeInfo, outputColVectorType, decimal64InputClass);
            return new ImmutablePair<VectorAggregationDesc, String>(vecAggrDesc, null);
        }
        // No support for DECIMAL_64 input.  We must convert, then fall through to the
        // normal match below using the regular DECIMAL input type.
        inputExpression = vContext.wrapWithDecimal64ToDecimalConversion(inputExpression);
        inputColVectorType = ColumnVector.Type.DECIMAL;
    }
    /*
     * Look for normal match.
     */
    final Class<? extends VectorAggregateExpression> vecAggrClass = findVecAggrClass(vecAggrClasses, aggregateName, inputColVectorType, outputColVectorType, udafEvaluatorMode);
    if (vecAggrClass != null) {
        final VectorAggregationDesc vecAggrDesc = new VectorAggregationDesc(aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression, outputTypeInfo, outputColVectorType, vecAggrClass);
        return new ImmutablePair<VectorAggregationDesc, String>(vecAggrDesc, null);
    }
    // No match?
    // The input type is only quoted when there is one; the COUNT(*) case prints a bare "any"
    // (previously it produced a dangling closing quote: any").
    final String inputTypeText = inputColVectorType == null ? "any" : "\"" + inputColVectorType + "\"";
    String issue = "Vector aggregation : \"" + aggregateName + "\" " + "for input type: " + inputTypeText + " " + "and output type: \"" + outputColVectorType + "\" " + "and mode: " + udafEvaluatorMode + " not supported for " + "evaluator " + evaluator.getClass().getSimpleName();
    return new ImmutablePair<VectorAggregationDesc, String>(null, issue);
}
Also used : StructObjectInspector(org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector) ObjectInspector(org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector) VectorAggregateExpression(org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression) UDFToString(org.apache.hadoop.hive.ql.udf.UDFToString) VectorizedUDAFs(org.apache.hadoop.hive.ql.exec.vector.VectorizedUDAFs) TypeInfo(org.apache.hadoop.hive.serde2.typeinfo.TypeInfo) StructTypeInfo(org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo) DecimalTypeInfo(org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo) PrimitiveTypeInfo(org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo) ColumnVector(org.apache.hadoop.hive.ql.exec.vector.ColumnVector) VectorAggregationDesc(org.apache.hadoop.hive.ql.exec.vector.VectorAggregationDesc) DecimalTypeInfo(org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Type(org.apache.hadoop.hive.ql.exec.vector.ColumnVector.Type) VectorExpression(org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression) ExprNodeDesc(org.apache.hadoop.hive.ql.plan.ExprNodeDesc)

Aggregations

ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair)1 ColumnVector (org.apache.hadoop.hive.ql.exec.vector.ColumnVector)1 Type (org.apache.hadoop.hive.ql.exec.vector.ColumnVector.Type)1 VectorAggregationDesc (org.apache.hadoop.hive.ql.exec.vector.VectorAggregationDesc)1 VectorizedUDAFs (org.apache.hadoop.hive.ql.exec.vector.VectorizedUDAFs)1 VectorExpression (org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression)1 VectorAggregateExpression (org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression)1 ExprNodeDesc (org.apache.hadoop.hive.ql.plan.ExprNodeDesc)1 UDFToString (org.apache.hadoop.hive.ql.udf.UDFToString)1 ObjectInspector (org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector)1 StructObjectInspector (org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector)1 DecimalTypeInfo (org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo)1 PrimitiveTypeInfo (org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo)1 StructTypeInfo (org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo)1 TypeInfo (org.apache.hadoop.hive.serde2.typeinfo.TypeInfo)1