Use of org.drools.core.base.accumulators.CountAccumulateFunction in the project drools by kiegroup.
Class KiePackagesBuilder, method buildAccumulate.
/**
 * Builds the runtime {@link Accumulate} for an accumulate (or group-by) pattern of the executable model.
 *
 * <p>A single accumulate function yields a {@code SingleAccumulate}; several functions (or a group-by)
 * yield a {@code MultiAccumulate} whose result is an {@code Object[]}. For a group-by, the grouping key
 * occupies the last slot of that array, and the accumulate is finally wrapped in a
 * {@code LambdaGroupByAccumulate}. A group-by without any accumulate function gets a synthetic count
 * function so that the result array still has a value slot before the key.
 *
 * @param ctx               rule-building context; receives the group-by declaration and is registered
 *                          as the accumulate source for every bound variable of {@code accPattern}
 * @param accPattern        the accumulate / group-by pattern being compiled
 * @param source            the condition element the accumulate reads from
 * @param pattern           the pattern holding the accumulate result; the group-by key declaration is added to it
 * @param usedVariableName  passed through to {@code processFunctions}
 * @param bindings          passed through to {@code processFunctions}
 * @return the built accumulate node
 */
private Accumulate buildAccumulate(RuleContext ctx, AccumulatePattern accPattern, RuleConditionElement source, Pattern pattern, Set<String> usedVariableName, Collection<Binding> bindings) {
    boolean isGroupBy = accPattern instanceof GroupByPattern;
    AccumulateFunction[] accFunctions = accPattern.getAccumulateFunctions();
    Declaration groupByDeclaration = null;
    // Several results (or a group-by, which also carries its key) are packed into an Object[];
    // a single plain accumulate exposes its function's result type directly.
    Class selfType = (isGroupBy || accFunctions.length > 1) ? Object[].class : accFunctions[0].getResult().getType();
    InternalReadAccessor selfReader = new SelfReferenceClassFieldReader(selfType);
    int arrayIndexOffset = 0;
    if (isGroupBy) {
        if (accFunctions.length == 0) {
            // In this situation the result is anonymous, but it still uses element position 0.
            // For this reason the index used to populate the result array must be offset by 1.
            accFunctions = new AccumulateFunction[] { new AccumulateFunction(null, CountAccumulateFunction::new) };
            arrayIndexOffset = 1;
        }
        // GroupBy key is always the last element in the result array
        Variable groupVar = ((GroupByPattern<?, ?>) accPattern).getVarKey();
        groupByDeclaration = new Declaration(groupVar.getName(), new ArrayElementReader(selfReader, accFunctions.length, groupVar.getType()), pattern, true);
        pattern.addDeclaration(groupByDeclaration);
    }
    Accumulator[] accumulators = new Accumulator[accFunctions.length];
    List<Declaration> requiredDeclarationList = new ArrayList<>();
    for (int i = 0; i < accFunctions.length; i++) {
        // Presumably populates accumulators[i] and requiredDeclarationList (both are passed in and read
        // below). Its return value is not needed here.
        processFunctions(ctx, accPattern, source, pattern, usedVariableName, bindings, isGroupBy, accFunctions[i], selfReader, accumulators, requiredDeclarationList, arrayIndexOffset, i);
        if (isGroupBy) {
            // Registered after each processFunctions call, which keeps this ordering relative to it.
            // NOTE(review): likely idempotent and hoistable out of the loop; confirm before moving.
            ctx.addGroupByDeclaration(((GroupByPattern) accPattern).getVarKey(), groupByDeclaration);
        }
    }
    Accumulate accumulate;
    if (accFunctions.length == 1) {
        accumulate = new SingleAccumulate(source, requiredDeclarationList.toArray(new Declaration[0]), accumulators[0]);
    } else {
        // A MultiAccumulate gets no required declarations of its own; they are attached to the source pattern instead.
        if (source instanceof Pattern) {
            requiredDeclarationList.forEach(((Pattern) source)::addDeclaration);
        }
        // If this is a group-by, reserve one extra result slot for the key.
        accumulate = new MultiAccumulate(source, new Declaration[0], accumulators, accumulators.length + ((isGroupBy) ? 1 : 0));
    }
    if (isGroupBy) {
        GroupByPatternImpl groupBy = (GroupByPatternImpl) accPattern;
        Declaration[] groupingDeclarations = new Declaration[groupBy.getVars().length];
        for (int i = 0; i < groupBy.getVars().length; i++) {
            groupingDeclarations[i] = ctx.getDeclaration(groupBy.getVars()[i]);
        }
        accumulate = new LambdaGroupByAccumulate(accumulate, groupingDeclarations, groupBy.getGroupingFunction());
    }
    for (Variable boundVar : accPattern.getBoundVariables()) {
        ctx.addAccumulateSource(boundVar, accumulate);
    }
    return accumulate;
}
Use of org.drools.core.base.accumulators.CountAccumulateFunction in the project drools by kiegroup.
Class GroupByTest, method testWithGroupByAfterExists.
/**
 * Groups integers by absolute value while at least one String is present, counting each group,
 * and checks that the counts land in the global map keyed by absolute value.
 */
@Test
public void testWithGroupByAfterExists() {
    Global<Map> glob = D.globalOf(Map.class, "defaultPkg", "glob");
    Variable<Integer> number = D.declarationOf(Integer.class);
    Variable<String> marker = D.declarationOf(String.class);
    Variable<Integer> absKey = D.declarationOf(Integer.class);
    Variable<Long> countResult = D.declarationOf(Long.class);

    // An Integer pattern combined with "exists some String".
    ViewItem numbersWhileStringExists = D.and(D.pattern(number), D.exists(D.pattern(marker)));
    ViewItem groupBy = D.groupBy(numbersWhileStringExists, number, absKey, Math::abs, DSL.accFunction(CountAccumulateFunction::new).as(countResult));
    ConsequenceBuilder._3 consequence = D.on(absKey, countResult, glob).execute((k, c, sink) -> {
        sink.put(k, c.intValue());
    });

    Model model = new ModelImpl().addRule(D.rule("R").build(groupBy, consequence)).addGlobal(glob);
    KieSession session = KieBaseBuilder.createKieBaseFromModel(model).newKieSession();

    Map<Integer, Integer> global = new HashMap<>();
    session.setGlobal("glob", global);
    for (Object fact : new Object[] { "Something", -1, 1, 2 }) {
        session.insert(fact);
    }
    session.fireAllRules();

    assertEquals(2, global.size());
    // -1 and 1 will map to the same key, and count twice.
    assertEquals(2, (int) global.get(1));
    // 2 maps to a key, and counts once.
    assertEquals(1, (int) global.get(2));
}
Use of org.drools.core.base.accumulators.CountAccumulateFunction in the project drools by kiegroup.
Class GroupByTest, method testWithNull.
/**
 * Groups MyType facts by their nested object, keeping only facts whose nested object is non-null,
 * and checks that only the non-null nested object shows up as a group key.
 */
@Test
public void testWithNull() {
    Variable<MyType> element = D.declarationOf(MyType.class);
    Variable<MyType> nestedKey = D.declarationOf(MyType.class);
    Variable<Long> total = D.declarationOf(Long.class);

    // Counts how often the grouping-key function is invoked (reported below, not asserted).
    AtomicInteger keyMappingCalls = new AtomicInteger(0);
    Function1<MyType, MyType> toNested = (item) -> {
        keyMappingCalls.incrementAndGet();
        return item.getNested();
    };

    D.PatternDef<MyType> withNestedOnly = D.pattern(element).expr(candidate -> candidate.getNested() != null);
    ExprViewItem groupBy = D.groupBy(withNestedOnly, element, nestedKey, toNested, D.accFunction(CountAccumulateFunction::new).as(total));

    List<MyType> collectedKeys = new ArrayList<>();
    Rule rule = D.rule("R").build(groupBy, D.on(nestedKey, total).execute((kcontext, groupKeyValue, groupCount) -> collectedKeys.add(groupKeyValue)));
    KieBase kieBase = KieBaseBuilder.createKieBaseFromModel(new ModelImpl().addRule(rule));

    MyType leaf = new MyType(null);
    MyType parent = new MyType(leaf);

    KieSession ksession = kieBase.newKieSession();
    ksession.insert(parent);
    ksession.insert(leaf);
    ksession.fireAllRules();

    // Side issue: this number is unusually high. Perhaps we should try to implement some cache for this?
    System.out.println("GroupKey mapping function was called " + keyMappingCalls.get() + " times.");
    Assertions.assertThat(collectedKeys).containsOnly(leaf);
}
Use of org.drools.core.base.accumulators.CountAccumulateFunction in the project drools by kiegroup.
Class GroupByTest, method testWithGroupByAfterExistsWithFrom.
/**
 * Same scenario as the plain exists/group-by test, but the Long count is re-bound to an Integer
 * through a follow-up pattern before reaching the consequence.
 */
@Test
public void testWithGroupByAfterExistsWithFrom() {
    Global<Map> glob = D.globalOf(Map.class, "defaultPkg", "glob");
    Variable<Integer> number = D.declarationOf(Integer.class);
    Variable<String> marker = D.declarationOf(String.class);
    Variable<Integer> absKey = D.declarationOf(Integer.class);
    Variable<Long> countResult = D.declarationOf(Long.class);
    Variable<Integer> countAsInt = D.declarationOf(Integer.class);

    // An Integer pattern combined with "exists some String".
    ViewItem numbersWhileStringExists = D.and(D.pattern(number), D.exists(D.pattern(marker)));
    ViewItem groupBy = D.groupBy(numbersWhileStringExists, number, absKey, Math::abs, DSL.accFunction(CountAccumulateFunction::new).as(countResult));
    // Re-expose the Long count as an Integer binding.
    PatternDSL.PatternDef countPattern = D.pattern(countResult).bind(countAsInt, Long::intValue);
    ConsequenceBuilder._3 consequence = D.on(absKey, countAsInt, glob).execute((k, c, sink) -> {
        sink.put(k, c);
    });

    Model model = new ModelImpl().addRule(D.rule("R").build(groupBy, countPattern, consequence)).addGlobal(glob);
    KieSession session = KieBaseBuilder.createKieBaseFromModel(model).newKieSession();

    Map<Integer, Integer> global = new HashMap<>();
    session.setGlobal("glob", global);
    for (Object fact : new Object[] { "Something", -1, 1, 2 }) {
        session.insert(fact);
    }
    session.fireAllRules();

    assertEquals(2, global.size());
    // -1 and 1 will map to the same key, and count twice.
    assertEquals(2, (int) global.get(1));
    // 2 maps to a key, and counts once.
    assertEquals(1, (int) global.get(2));
}
Use of org.drools.core.base.accumulators.CountAccumulateFunction in the project drools by kiegroup.
Class GroupByTest, method testFromAfterGroupBy.
/**
 * Groups persons by name, then re-binds the group key and count through {@code from(...)} declarations.
 * The consequence fails if the re-bound key is not a String; the test requires at least one firing.
 */
@Test
public void testFromAfterGroupBy() {
    Global<Set> resultsGlobal = D.globalOf(Set.class, "defaultpkg", "results");
    Variable personVar = D.declarationOf(Person.class);
    Variable nameKey = D.declarationOf(String.class);
    Variable personCount = D.declarationOf(Long.class);
    Variable keyFromGroup = D.declarationOf(Object.class, from(nameKey));
    Variable countFromGroup = D.declarationOf(Long.class, from(personCount));

    PatternDSL.PatternDef<Person> namedPersons = D.pattern(personVar).expr(candidate -> ((Person) candidate).getName() != null);
    Rule rule = D.rule("R1").build(
            D.groupBy(namedPersons, personVar, nameKey, Person::getName, DSL.accFunction(CountAccumulateFunction::new, personVar).as(personCount)),
            D.pattern(keyFromGroup),
            D.pattern(countFromGroup),
            D.on(keyFromGroup, countFromGroup).execute((drools, key, total) -> {
                if (!(key instanceof String)) {
                    throw new IllegalStateException("Name not String, but " + key.getClass());
                }
            }));

    KieSession ksession = KieBaseBuilder.createKieBaseFromModel(new ModelImpl().addRule(rule).addGlobal(resultsGlobal)).newKieSession();
    Set<Integer> collected = new LinkedHashSet<>();
    ksession.setGlobal("results", collected);
    ksession.insert(new Person("Mark", 42));
    ksession.insert(new Person("Edson", 38));
    ksession.insert(new Person("Edoardo", 33));
    assertThat(ksession.fireAllRules()).isGreaterThan(0);
}
Aggregations