use of com.hazelcast.jet.sql.impl.aggregate.ValueSqlAggregation in project hazelcast by hazelcast.
the class AggregateAbstractPhysicalRule method aggregateOperation.
protected static AggregateOperation<?, JetSqlRow> aggregateOperation(RelDataType inputType, ImmutableBitSet groupSet, List<AggregateCall> aggregateCalls) {
List<QueryDataType> operandTypes = OptUtils.schema(inputType).getTypes();
List<SupplierEx<SqlAggregation>> aggregationProviders = new ArrayList<>();
List<FunctionEx<JetSqlRow, Object>> valueProviders = new ArrayList<>();
for (Integer groupIndex : groupSet.toList()) {
aggregationProviders.add(ValueSqlAggregation::new);
// getMaybeSerialized is safe for ValueAggr because it only passes the value on
valueProviders.add(new RowGetMaybeSerializedFn(groupIndex));
}
for (AggregateCall aggregateCall : aggregateCalls) {
boolean distinct = aggregateCall.isDistinct();
List<Integer> aggregateCallArguments = aggregateCall.getArgList();
SqlKind kind = aggregateCall.getAggregation().getKind();
switch(kind) {
case COUNT:
if (distinct) {
int countIndex = aggregateCallArguments.get(0);
aggregationProviders.add(new AggregateCountSupplier(true, true));
// getMaybeSerialized is safe for COUNT because the aggregation only looks whether it is null or not
valueProviders.add(new RowGetMaybeSerializedFn(countIndex));
} else if (aggregateCallArguments.size() == 1) {
int countIndex = aggregateCallArguments.get(0);
aggregationProviders.add(new AggregateCountSupplier(true, false));
valueProviders.add(new RowGetMaybeSerializedFn(countIndex));
} else {
aggregationProviders.add(new AggregateCountSupplier(false, false));
valueProviders.add(NullFunction.INSTANCE);
}
break;
case MIN:
int minIndex = aggregateCallArguments.get(0);
aggregationProviders.add(MinSqlAggregation::new);
valueProviders.add(new RowGetFn(minIndex));
break;
case MAX:
int maxIndex = aggregateCallArguments.get(0);
aggregationProviders.add(MaxSqlAggregation::new);
valueProviders.add(new RowGetFn(maxIndex));
break;
case SUM:
int sumIndex = aggregateCallArguments.get(0);
QueryDataType sumOperandType = operandTypes.get(sumIndex);
aggregationProviders.add(new AggregateSumSupplier(distinct, sumOperandType));
valueProviders.add(new RowGetFn(sumIndex));
break;
case AVG:
int avgIndex = aggregateCallArguments.get(0);
QueryDataType avgOperandType = operandTypes.get(avgIndex);
aggregationProviders.add(new AggregateAvgSupplier(distinct, avgOperandType));
valueProviders.add(new RowGetFn(avgIndex));
break;
default:
throw QueryException.error("Unsupported aggregation function: " + kind);
}
}
return AggregateOperation.withCreate(new AggregateCreateSupplier(aggregationProviders)).andAccumulate(new AggregateAccumulateFunction(valueProviders)).andCombine(AggregateCombineFunction.INSTANCE).andExportFinish(AggregateExportFinishFunction.INSTANCE);
}
Aggregations