Use of org.drools.modelcompiler.constraints.LambdaGroupByAccumulate in the drools project by kiegroup.
From the class KiePackagesBuilder, method buildAccumulate.
/**
 * Builds the {@link Accumulate} for the given accumulate pattern and registers it
 * in the rule context as the source of every variable bound by that pattern.
 *
 * <p>With exactly one accumulate function a {@link SingleAccumulate} is created;
 * otherwise a {@link MultiAccumulate} is created. When {@code accPattern} is a
 * {@link GroupByPattern} the grouping key is exposed through a declaration added to
 * {@code pattern}, and the accumulate is wrapped in a {@link LambdaGroupByAccumulate}.
 *
 * @param ctx              rule context that receives the group-by declaration and the accumulate sources
 * @param accPattern       accumulate (or group-by) pattern being translated
 * @param source           condition element the accumulate reads from
 * @param pattern          pattern that receives the group-by key declaration
 * @param usedVariableName forwarded unchanged to {@code processFunctions}
 * @param bindings         forwarded unchanged to {@code processFunctions}
 * @return the accumulate, wrapped in a {@link LambdaGroupByAccumulate} for group-by patterns
 */
private Accumulate buildAccumulate(RuleContext ctx, AccumulatePattern accPattern, RuleConditionElement source, Pattern pattern, Set<String> usedVariableName, Collection<Binding> bindings) {
    final boolean isGroupBy = accPattern instanceof GroupByPattern;
    AccumulateFunction[] functions = accPattern.getAccumulateFunctions();

    // A group-by, or more than one function, yields an Object[] result; a lone
    // function yields the type of its own result variable.
    Class resultType;
    if (isGroupBy || functions.length > 1) {
        resultType = Object[].class;
    } else {
        resultType = functions[0].getResult().getType();
    }
    InternalReadAccessor resultReader = new SelfReferenceClassFieldReader(resultType);

    int indexOffset = 0;
    Declaration keyDeclaration = null;
    if (isGroupBy) {
        if (functions.length == 0) {
            // No function was declared: an anonymous count is used instead. Its result is
            // anonymous but still occupies element position 0, so the index used to
            // populate the array must be offset by 1.
            functions = new AccumulateFunction[] { new AccumulateFunction(null, CountAccumulateFunction::new) };
            indexOffset = 1;
        }
        // The group-by key is always the last element in the result array.
        Variable keyVar = ((GroupByPattern<?, ?>) accPattern).getVarKey();
        ArrayElementReader keyReader = new ArrayElementReader(resultReader, functions.length, keyVar.getType());
        keyDeclaration = new Declaration(keyVar.getName(), keyReader, pattern, true);
        pattern.addDeclaration(keyDeclaration);
    }

    Accumulator[] accumulators = new Accumulator[functions.length];
    List<Declaration> requiredDeclarations = new ArrayList<>();
    for (int index = 0; index < functions.length; index++) {
        processFunctions(ctx, accPattern, source, pattern, usedVariableName, bindings, isGroupBy, functions[index], resultReader, accumulators, requiredDeclarations, indexOffset, index);
        if (isGroupBy) {
            // Kept inside the loop, after each function is processed, to preserve call order.
            ctx.addGroupByDeclaration(((GroupByPattern) accPattern).getVarKey(), keyDeclaration);
        }
    }

    Accumulate result;
    if (functions.length != 1) {
        if (source instanceof Pattern) {
            Pattern sourcePattern = (Pattern) source;
            for (Declaration required : requiredDeclarations) {
                sourcePattern.addDeclaration(required);
            }
        }
        // A group-by needs one extra slot in the result array for the key.
        int resultArraySize = isGroupBy ? accumulators.length + 1 : accumulators.length;
        result = new MultiAccumulate(source, new Declaration[0], accumulators, resultArraySize);
    } else {
        result = new SingleAccumulate(source, requiredDeclarations.toArray(new Declaration[0]), accumulators[0]);
    }

    if (isGroupBy) {
        GroupByPatternImpl groupBy = (GroupByPatternImpl) accPattern;
        Variable[] groupingVars = groupBy.getVars();
        Declaration[] groupingDeclarations = new Declaration[groupingVars.length];
        for (int index = 0; index < groupingVars.length; index++) {
            groupingDeclarations[index] = ctx.getDeclaration(groupingVars[index]);
        }
        result = new LambdaGroupByAccumulate(result, groupingDeclarations, groupBy.getGroupingFunction());
    }

    for (Variable bound : accPattern.getBoundVariables()) {
        ctx.addAccumulateSource(bound, result);
    }
    return result;
}
Aggregations