Usage of org.drools.core.base.accumulators.CountAccumulateFunction in the drools project (kiegroup).
Class KiePackagesBuilder, method buildAccumulate.
/**
 * Builds the {@link Accumulate} element for an accumulate or groupBy pattern.
 *
 * <p>The accumulate result is read as an {@code Object[]} when the pattern is a groupBy or declares
 * more than one accumulate function. Otherwise it has the result type of the single accumulate
 * function.
 *
 * <p>For a groupBy, the grouping key is exposed through a declaration that reads array index
 * {@code accFunctions.length}, which is the last element of the result array. A groupBy without
 * accumulate functions is given a synthetic count function, so the key sits at index 1.
 *
 * @param ctx              rule build context; receives the groupBy key declaration and, for every
 *                         bound variable of {@code accPattern}, the built accumulate as its source
 * @param accPattern       the accumulate (or groupBy) pattern being compiled
 * @param source           the condition element the accumulate reads from
 * @param pattern          the pattern carrying the accumulate result; for a groupBy the key
 *                         declaration is added to it
 * @param usedVariableName passed through to {@code processFunctions}
 * @param bindings         passed through to {@code processFunctions}
 * @return a {@link SingleAccumulate} for exactly one function, otherwise a {@link MultiAccumulate};
 *         wrapped in a {@link LambdaGroupByAccumulate} when {@code accPattern} is a groupBy
 */
private Accumulate buildAccumulate(RuleContext ctx, AccumulatePattern accPattern, RuleConditionElement source, Pattern pattern, Set<String> usedVariableName, Collection<Binding> bindings) {
    boolean isGroupBy = accPattern instanceof GroupByPattern;
    AccumulateFunction[] accFunctions = accPattern.getAccumulateFunctions();
    Declaration groupByDeclaration = null;
    // A groupBy or a multi-function accumulate yields an Object[] result;
    // a single function yields its own result type.
    Class<?> selfType = (isGroupBy || accFunctions.length > 1) ? Object[].class : accFunctions[0].getResult().getType();
    InternalReadAccessor selfReader = new SelfReferenceClassFieldReader(selfType);
    int arrayIndexOffset = 0;
    if (isGroupBy) {
        if (accFunctions.length == 0) {
            // The result is anonymous in this case, but it still occupies element position 0,
            // so the index i used to populate the result array must be offset by 1.
            accFunctions = new AccumulateFunction[] { new AccumulateFunction(null, CountAccumulateFunction::new) };
            arrayIndexOffset = 1;
        }
        // The groupBy key is always the last element of the result array.
        Variable groupVar = ((GroupByPattern<?, ?>) accPattern).getVarKey();
        groupByDeclaration = new Declaration(groupVar.getName(), new ArrayElementReader(selfReader, accFunctions.length, groupVar.getType()), pattern, true);
        pattern.addDeclaration(groupByDeclaration);
    }

    Accumulator[] accumulators = new Accumulator[accFunctions.length];
    List<Declaration> requiredDeclarationList = new ArrayList<>();
    for (int i = 0; i < accFunctions.length; i++) {
        // The returned bound variable is not needed here. accumulators and requiredDeclarationList
        // are handed to processFunctions and read after the loop.
        processFunctions(ctx, accPattern, source, pattern, usedVariableName, bindings, isGroupBy, accFunctions[i], selfReader, accumulators, requiredDeclarationList, arrayIndexOffset, i);
        if (isGroupBy) {
            // NOTE(review): re-registered with identical arguments on every iteration; kept inside
            // the loop to preserve its ordering relative to processFunctions.
            ctx.addGroupByDeclaration(((GroupByPattern<?, ?>) accPattern).getVarKey(), groupByDeclaration);
        }
    }

    Accumulate accumulate;
    if (accFunctions.length == 1) {
        accumulate = new SingleAccumulate(source, requiredDeclarationList.toArray(new Declaration[0]), accumulators[0]);
    } else {
        if (source instanceof Pattern) {
            requiredDeclarationList.forEach(((Pattern) source)::addDeclaration);
        }
        // For a groupBy the result array has one extra slot for the key.
        accumulate = new MultiAccumulate(source, new Declaration[0], accumulators, accumulators.length + (isGroupBy ? 1 : 0));
    }

    if (isGroupBy) {
        GroupByPatternImpl groupBy = (GroupByPatternImpl) accPattern;
        Declaration[] groupingDeclarations = new Declaration[groupBy.getVars().length];
        for (int i = 0; i < groupBy.getVars().length; i++) {
            groupingDeclarations[i] = ctx.getDeclaration(groupBy.getVars()[i]);
        }
        accumulate = new LambdaGroupByAccumulate(accumulate, groupingDeclarations, groupBy.getGroupingFunction());
    }

    for (Variable boundVar : accPattern.getBoundVariables()) {
        ctx.addAccumulateSource(boundVar, accumulate);
    }
    return accumulate;
}
Usage of org.drools.core.base.accumulators.CountAccumulateFunction in the drools project (kiegroup).
Class GroupByTest, method testNestedGroupBy3.
@Test
// FIXME(@mario): still ignored. The filter on the inner count previously cast it to Integer, but the
// count is a Long (the assertion below expects 1L), so it now casts to Long.
// Re-enable once the nested groupBy is verified to pass.
@Ignore
public void testNestedGroupBy3() throws Exception {
    // DROOLS-6045
    final Global<List> var_results = D.globalOf(List.class, "defaultpkg", "results");
    final Variable<Object> var_$key = D.declarationOf(Object.class);
    final Variable<Object> var_$keyOuter = D.declarationOf(Object.class);
    final Variable<Person> var_$p = D.declarationOf(Person.class);
    final Variable<Object> var_$accresult = D.declarationOf(Object.class);

    final Rule rule1 = PatternDSL.rule("R1").build(
            D.groupBy(
                    D.and(
                            D.groupBy(
                                    // Patterns
                                    D.pattern(var_$p),
                                    // Grouping Function
                                    var_$p, var_$key, Person::getName,
                                    D.accFunction(CountAccumulateFunction::new).as(var_$accresult)),
                            // Bindings: the inner count is a Long; an Integer cast would throw ClassCastException.
                            D.pattern(var_$accresult).expr(c -> ((Long) c) > 0)),
                    // Outer grouping: key each (inner key, count) pair into a single Pair.
                    var_$key, var_$accresult, var_$keyOuter, Pair::create),
            // Consequence
            D.on(var_$keyOuter, var_results).execute(($outerKey, results) -> {
                results.add($outerKey);
            }));

    final Model model = new ModelImpl().addRule(rule1).addGlobal(var_results);
    final KieSession ksession = KieBaseBuilder.createKieBaseFromModel(model).newKieSession();
    final List<Object> results = new ArrayList<>();
    ksession.setGlobal("results", results);
    ksession.insert("A");
    ksession.insert("test");
    ksession.insert(new Person("Mark", 42));
    assertThat(ksession.fireAllRules()).isEqualTo(1);
    assertThat(results).containsOnly(Pair.create("Mark", 1L));
}
Usage of org.drools.core.base.accumulators.CountAccumulateFunction in the drools project (kiegroup).
Class GroupByTest, method testDecomposedGroupByKeyAndAccumulate.
@Test
public void testDecomposedGroupByKeyAndAccumulate() throws Exception {
    // DROOLS-6031
    final Global<List> var_results = D.globalOf(List.class, "defaultpkg", "results");
    final Variable<Pair<String, String>> var_$key = (Variable) D.declarationOf(Pair.class);
    final Variable<Person> var_$p = D.declarationOf(Person.class);
    final Variable<String> var_$subkeyA = D.declarationOf(String.class);
    final Variable<String> var_$subkeyB = D.declarationOf(String.class);
    final Variable<Long> var_$accresult = D.declarationOf(Long.class);

    final Rule rule1 = PatternDSL.rule("R1").build(
            D.groupBy(
                    // Patterns
                    PatternDSL.pattern(var_$p),
                    // Grouping Function: key is (first char of name, second char of name)
                    var_$p, var_$key,
                    person -> Pair.create(person.getName().substring(0, 1), person.getName().substring(1, 2)),
                    D.accFunction(CountAccumulateFunction::new).as(var_$accresult)),
            // Bindings: decompose the Pair key into its two parts
            D.pattern(var_$key)
                    .bind(var_$subkeyA, Pair::getKey)
                    .bind(var_$subkeyB, Pair::getValue),
            D.pattern(var_$accresult).expr(l -> l > 0),
            // Consequence
            D.on(var_$subkeyA, var_$subkeyB, var_$accresult, var_results)
                    .execute(($a, $b, $accresult, results) -> {
                        results.add($a);
                        results.add($b);
                        results.add($accresult);
                    }));

    final Model model = new ModelImpl().addRule(rule1).addGlobal(var_results);
    final KieSession ksession = KieBaseBuilder.createKieBaseFromModel(model).newKieSession();
    final List<Object> results = new ArrayList<>();
    ksession.setGlobal("results", results);
    ksession.insert("A");
    ksession.insert("test");
    ksession.insert(new Person("Mark", 42));
    assertThat(ksession.fireAllRules()).isEqualTo(1);
    // "Mark" decomposes into ("M", "a"); the single Person yields a count of 1.
    assertThat(results).containsExactly("M", "a", 1L);
}
Usage of org.drools.core.base.accumulators.CountAccumulateFunction in the drools project (kiegroup).
Class GroupByTest, method testUnexpectedRuleMatch.
@Test
public void testUnexpectedRuleMatch() {
    final Global<List> var_results = D.globalOf(List.class, "defaultpkg", "results");
    // $a: Parent()
    Variable<Parent> patternVar = D.declarationOf(Parent.class);
    PatternDSL.PatternDef<Parent> pattern = D.pattern(patternVar);
    // exists Child($a.getChild() == this)
    Variable<Child> existsPatternVar = D.declarationOf(Child.class);
    PatternDSL.PatternDef<Child> existsPattern = D.pattern(existsPatternVar)
            .expr(patternVar, (child, parent) -> Objects.equals(parent.getChild(), child));
    // groupBy Parent::getChild, accumulating count()
    Variable<Child> groupKeyVar = D.declarationOf(Child.class);
    Variable<Long> accumulateResult = D.declarationOf(Long.class);
    ExprViewItem groupBy = PatternDSL.groupBy(D.and(pattern, D.exists(existsPattern)), patternVar, groupKeyVar,
            Parent::getChild, DSL.accFunction(CountAccumulateFunction::new).as(accumulateResult));
    Rule rule1 = D.rule("R1").build(groupBy,
            D.on(var_results, groupKeyVar, accumulateResult)
                    .execute((results, $child, $count) -> results.add(Arrays.asList($child, $count))));

    Model model = new ModelImpl().addRule(rule1).addGlobal(var_results);
    KieSession ksession = KieBaseBuilder.createKieBaseFromModel(model).newKieSession();
    List<Object> results = new ArrayList<>();
    ksession.setGlobal("results", results);

    Child child1 = new Child("Child1", 1);
    Parent parent1 = new Parent("Parent1", child1);
    Child child2 = new Child("Child2", 2);
    Parent parent2 = new Parent("Parent2", child2);
    ksession.insert(parent1);
    ksession.insert(parent2);
    FactHandle toRemove = ksession.insert(child1);
    ksession.insert(child2);

    // Deleting child1 makes the exists() fail for parent1, so no group keyed by child1 should remain.
    ksession.delete(toRemove);
    // Regression check: a spurious (Child1, 0) match was previously produced here.
    ksession.fireAllRules();
    Assertions.assertThat(results).containsOnly(Arrays.asList(child2, 1L));
}
Aggregations