Search in sources:

Example 1 with CountAccumulateFunction

Use of org.drools.core.base.accumulators.CountAccumulateFunction in the drools project by kiegroup.

Class KiePackagesBuilder, method buildAccumulate.

/**
 * Builds the {@link Accumulate} condition element for an accumulate or groupBy pattern of the
 * executable model.
 *
 * <p>A pattern with exactly one accumulate function becomes a {@link SingleAccumulate}; otherwise a
 * {@link MultiAccumulate} is built. For a {@link GroupByPattern} the result is additionally wrapped
 * in a {@link LambdaGroupByAccumulate}, and the group key is exposed as a declaration on
 * {@code pattern} that reads the last element of the {@code Object[]} result. The returned
 * accumulate is registered in {@code ctx} as the source of every variable bound by
 * {@code accPattern}.
 *
 * @param ctx              rule build context; receives the group-by declaration and accumulate sources
 * @param accPattern       the model accumulate (or groupBy) pattern being translated
 * @param source           the condition element the accumulate reads from
 * @param pattern          the result pattern; receives the group-key declaration for a groupBy
 * @param usedVariableName forwarded unchanged to {@code processFunctions}
 * @param bindings         forwarded unchanged to {@code processFunctions}
 * @return the built accumulate (wrapped in a {@link LambdaGroupByAccumulate} for a groupBy)
 */
private Accumulate buildAccumulate(RuleContext ctx, AccumulatePattern accPattern, RuleConditionElement source, Pattern pattern, Set<String> usedVariableName, Collection<Binding> bindings) {
    boolean isGroupBy = accPattern instanceof GroupByPattern;
    AccumulateFunction[] accFunctions = accPattern.getAccumulateFunctions();
    Declaration groupByDeclaration = null;
    // groupBy and multi-function accumulates produce an Object[] result; a single function yields its own type.
    // NOTE(review): a non-groupBy pattern with no accumulate functions fails here with
    // ArrayIndexOutOfBoundsException on accFunctions[0].
    Class<?> selfType = (isGroupBy || accFunctions.length > 1) ? Object[].class : accFunctions[0].getResult().getType();
    InternalReadAccessor selfReader = new SelfReferenceClassFieldReader(selfType);
    int arrayIndexOffset = 0;
    if (isGroupBy) {
        if (accFunctions.length == 0) {
            // In this situation the result is anonymous, but it still uses element position 0.
            // For this reason the i used to populate the array index must be offset by 1.
            accFunctions = new AccumulateFunction[] { new AccumulateFunction(null, CountAccumulateFunction::new) };
            arrayIndexOffset = 1;
        }
        // GroupBy key is always the last element in the result array
        Variable groupVar = ((GroupByPattern<?, ?>) accPattern).getVarKey();
        groupByDeclaration = new Declaration(groupVar.getName(), new ArrayElementReader(selfReader, accFunctions.length, groupVar.getType()), pattern, true);
        pattern.addDeclaration(groupByDeclaration);
    }

    Accumulator[] accumulators = new Accumulator[accFunctions.length];
    List<Declaration> requiredDeclarationList = new ArrayList<>();
    for (int i = 0; i < accFunctions.length; i++) {
        // processFunctions receives the accumulators array and requiredDeclarationList (presumably
        // populating them); the bound Variable it returns is not needed here.
        processFunctions(ctx, accPattern, source, pattern, usedVariableName, bindings, isGroupBy, accFunctions[i], selfReader, accumulators, requiredDeclarationList, arrayIndexOffset, i);
        if (isGroupBy) {
            // NOTE(review): loop-invariant, but kept inside the loop so it still runs right after the
            // first processFunctions call, exactly as before.
            ctx.addGroupByDeclaration(((GroupByPattern<?, ?>) accPattern).getVarKey(), groupByDeclaration);
        }
    }

    Accumulate accumulate;
    if (accFunctions.length == 1) {
        accumulate = new SingleAccumulate(source, requiredDeclarationList.toArray(new Declaration[0]), accumulators[0]);
    } else {
        if (source instanceof Pattern) {
            // The MultiAccumulate gets no required declarations of its own; they are attached to the source pattern.
            // NOTE(review): when source is not a Pattern, requiredDeclarationList is not used at all — confirm intended.
            requiredDeclarationList.forEach(((Pattern) source)::addDeclaration);
        }
        // one result slot per accumulator, plus one for the groupBy key
        int resultArraySize = accumulators.length + (isGroupBy ? 1 : 0);
        accumulate = new MultiAccumulate(source, new Declaration[0], accumulators, resultArraySize);
    }

    if (isGroupBy) {
        GroupByPatternImpl groupBy = (GroupByPatternImpl) accPattern;
        Variable[] groupingVars = groupBy.getVars();
        Declaration[] groupingDeclarations = new Declaration[groupingVars.length];
        for (int i = 0; i < groupingVars.length; i++) {
            groupingDeclarations[i] = ctx.getDeclaration(groupingVars[i]);
        }
        accumulate = new LambdaGroupByAccumulate(accumulate, groupingDeclarations, groupBy.getGroupingFunction());
    }

    for (Variable boundVar : accPattern.getBoundVariables()) {
        ctx.addAccumulateSource(boundVar, accumulate);
    }
    return accumulate;
}
Also used : LambdaAccumulator(org.drools.modelcompiler.constraints.LambdaAccumulator) Accumulator(org.drools.core.spi.Accumulator) QueryCallPattern(org.drools.model.patterns.QueryCallPattern) AccumulatePattern(org.drools.model.AccumulatePattern) Pattern(org.drools.core.rule.Pattern) GroupByPattern(org.drools.model.GroupByPattern) PrototypeVariable(org.drools.model.PrototypeVariable) Variable(org.drools.model.Variable) MultiAccumulate(org.drools.core.rule.MultiAccumulate) ArrayList(java.util.ArrayList) SingleAccumulate(org.drools.core.rule.SingleAccumulate) SingleConstraint(org.drools.model.SingleConstraint) QueryNameConstraint(org.drools.core.rule.constraint.QueryNameConstraint) LambdaConstraint(org.drools.modelcompiler.constraints.LambdaConstraint) UnificationConstraint(org.drools.modelcompiler.constraints.UnificationConstraint) EntryPoint(org.drools.model.EntryPoint) AbstractConstraint(org.drools.modelcompiler.constraints.AbstractConstraint) Constraint(org.drools.model.Constraint) CombinedConstraint(org.drools.modelcompiler.constraints.CombinedConstraint) AbstractSingleConstraint(org.drools.model.constraints.AbstractSingleConstraint) DSL.entryPoint(org.drools.model.DSL.entryPoint) LambdaGroupByAccumulate(org.drools.modelcompiler.constraints.LambdaGroupByAccumulate) MultiAccumulate(org.drools.core.rule.MultiAccumulate) SingleAccumulate(org.drools.core.rule.SingleAccumulate) Accumulate(org.drools.core.rule.Accumulate) GroupByPattern(org.drools.model.GroupByPattern) LambdaGroupByAccumulate(org.drools.modelcompiler.constraints.LambdaGroupByAccumulate) SelfReferenceClassFieldReader(org.drools.core.base.extractors.SelfReferenceClassFieldReader) GroupByPatternImpl(org.drools.model.patterns.GroupByPatternImpl) InternalReadAccessor(org.drools.core.spi.InternalReadAccessor) ArrayElementReader(org.drools.core.base.extractors.ArrayElementReader) Declaration(org.drools.core.rule.Declaration) WindowDeclaration(org.drools.core.rule.WindowDeclaration) 
TypeDeclarationUtil.createTypeDeclaration(org.drools.modelcompiler.util.TypeDeclarationUtil.createTypeDeclaration) TypeDeclaration(org.drools.core.rule.TypeDeclaration) CountAccumulateFunction(org.drools.core.base.accumulators.CountAccumulateFunction) AccumulateFunction(org.drools.model.functions.accumulate.AccumulateFunction)

Example 2 with CountAccumulateFunction

Use of org.drools.core.base.accumulators.CountAccumulateFunction in the drools project by kiegroup.

Class GroupByTest, method testWithGroupByAfterExists.

/**
 * Groups Integer facts by absolute value, but only while a String fact exists, and checks the
 * per-key counts that the consequence writes into the "glob" global map.
 */
@Test
public void testWithGroupByAfterExists() {
    Global<Map> resultsGlobal = D.globalOf(Map.class, "defaultPkg", "glob");
    Variable<Integer> number = D.declarationOf(Integer.class);
    Variable<String> marker = D.declarationOf(String.class);
    Variable<Integer> absKey = D.declarationOf(Integer.class);
    Variable<Long> occurrences = D.declarationOf(Long.class);

    // Integer facts AND EXISTS(String fact), then grouped by Math::abs and counted
    ViewItem guardedNumbers = D.and(D.pattern(number), D.exists(D.pattern(marker)));
    ViewItem grouped = D.groupBy(guardedNumbers, number, absKey, Math::abs,
            DSL.accFunction(CountAccumulateFunction::new).as(occurrences));

    ConsequenceBuilder._3 recordCount = D.on(absKey, occurrences, resultsGlobal)
            .execute((key, count, results) -> {
                results.put(key, count.intValue());
            });
    Rule rule = D.rule("R").build(grouped, recordCount);

    Model model = new ModelImpl().addRule(rule).addGlobal(resultsGlobal);
    KieSession session = KieBaseBuilder.createKieBaseFromModel(model).newKieSession();
    Map<Integer, Integer> counts = new HashMap<>();
    session.setGlobal("glob", counts);
    for (Object fact : new Object[] { "Something", -1, 1, 2 }) {
        session.insert(fact);
    }
    session.fireAllRules();

    assertEquals(2, counts.size());
    // -1 and 1 share the key 1, so they are counted together.
    assertEquals(2, (int) counts.get(1));
    // 2 is the only value under its key.
    assertEquals(1, (int) counts.get(2));
}
Also used : D(org.drools.modelcompiler.dsl.pattern.D) ViewItem(org.drools.model.view.ViewItem) ExprViewItem(org.drools.model.view.ExprViewItem) HashMap(java.util.HashMap) CountAccumulateFunction(org.drools.core.base.accumulators.CountAccumulateFunction) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ConsequenceBuilder(org.drools.model.consequences.ConsequenceBuilder) KieBase(org.kie.api.KieBase) Model(org.drools.model.Model) KieSession(org.kie.api.runtime.KieSession) Rule(org.drools.model.Rule) ModelImpl(org.drools.model.impl.ModelImpl) Map(java.util.Map) HashMap(java.util.HashMap) Test(org.junit.Test)

Example 3 with CountAccumulateFunction

Use of org.drools.core.base.accumulators.CountAccumulateFunction in the drools project by kiegroup.

Class GroupByTest, method testWithNull.

/**
 * Groups MyType facts that carry a nested object by that nested object, and checks that only the
 * nested object of the single qualifying fact is collected as a group key.
 */
@Test
public void testWithNull() {
    Variable<MyType> item = D.declarationOf(MyType.class);
    Variable<MyType> nestedKey = D.declarationOf(MyType.class);
    Variable<Long> total = D.declarationOf(Long.class);

    AtomicInteger keyFunctionCalls = new AtomicInteger();
    Function1<MyType, MyType> nestedOf = candidate -> {
        keyFunctionCalls.incrementAndGet();
        return candidate.getNested();
    };

    // only facts that actually have a nested object take part in the grouping
    D.PatternDef<MyType> withNested = D.pattern(item).expr(candidate -> candidate.getNested() != null);
    ExprViewItem grouped = D.groupBy(withNested, item, nestedKey, nestedOf,
            D.accFunction(CountAccumulateFunction::new).as(total));

    List<MyType> collectedKeys = new ArrayList<>();
    Rule rule = D.rule("R").build(grouped,
            D.on(nestedKey, total).execute((drools, key, acc) -> collectedKeys.add(key)));
    KieBase kieBase = KieBaseBuilder.createKieBaseFromModel(new ModelImpl().addRule(rule));

    MyType leaf = new MyType(null);
    MyType parent = new MyType(leaf);
    KieSession ksession = kieBase.newKieSession();
    ksession.insert(parent);
    ksession.insert(leaf);
    ksession.fireAllRules();

    // Side issue: this number is unusually high. Perhaps we should try to implement some cache for this?
    System.out.println("GroupKey mapping function was called " + keyFunctionCalls.get() + " times.");
    Assertions.assertThat(collectedKeys).containsOnly(leaf);
}
Also used : Arrays(java.util.Arrays) InternalFactHandle(org.drools.core.common.InternalFactHandle) Accumulator(org.drools.core.spi.Accumulator) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) Global(org.drools.model.Global) RuleEventListener(org.kie.internal.event.rule.RuleEventListener) DSL(org.drools.model.DSL) Child(org.drools.modelcompiler.domain.Child) Match(org.kie.api.runtime.rule.Match) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Map(java.util.Map) Declaration(org.drools.core.rule.Declaration) Assertions(org.assertj.core.api.Assertions) KieSession(org.kie.api.runtime.KieSession) Parent(org.drools.modelcompiler.domain.Parent) Set(java.util.Set) Index(org.drools.model.Index) ConsequenceBuilder(org.drools.model.consequences.ConsequenceBuilder) EvaluationUtil(org.drools.modelcompiler.util.EvaluationUtil) Objects(java.util.Objects) List(java.util.List) Tuple(org.drools.core.spi.Tuple) CountAccumulateFunction(org.drools.core.base.accumulators.CountAccumulateFunction) Person(org.drools.modelcompiler.domain.Person) DSL.from(org.drools.model.DSL.from) ModelImpl(org.drools.model.impl.ModelImpl) HashMap(java.util.HashMap) PatternDSL(org.drools.model.PatternDSL) ArrayList(java.util.ArrayList) Function1(org.drools.model.functions.Function1) CollectListAccumulateFunction(org.drools.core.base.accumulators.CollectListAccumulateFunction) RuleEventManager(org.kie.internal.event.rule.RuleEventManager) KieBase(org.kie.api.KieBase) LinkedHashSet(java.util.LinkedHashSet) Model(org.drools.model.Model) IntegerMaxAccumulateFunction(org.drools.core.base.accumulators.IntegerMaxAccumulateFunction) ReteEvaluator(org.drools.core.common.ReteEvaluator) ViewItem(org.drools.model.view.ViewItem) Pair(org.apache.commons.math3.util.Pair) Variable(org.drools.model.Variable) D(org.drools.modelcompiler.dsl.pattern.D) IntegerSumAccumulateFunction(org.drools.core.base.accumulators.IntegerSumAccumulateFunction) ToIntFunction(java.util.function.ToIntFunction) 
Assert.assertTrue(org.junit.Assert.assertTrue) Test(org.junit.Test) GroupKey(org.drools.model.functions.accumulate.GroupKey) FactHandle(org.kie.api.runtime.rule.FactHandle) KieBaseBuilder(org.drools.modelcompiler.builder.KieBaseBuilder) Assert.assertNull(org.junit.Assert.assertNull) Ignore(org.junit.Ignore) Rule(org.drools.model.Rule) ExprViewItem(org.drools.model.view.ExprViewItem) RuleTerminalNodeLeftTuple(org.drools.core.reteoo.RuleTerminalNodeLeftTuple) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) ExprViewItem(org.drools.model.view.ExprViewItem) D(org.drools.modelcompiler.dsl.pattern.D) ArrayList(java.util.ArrayList) CountAccumulateFunction(org.drools.core.base.accumulators.CountAccumulateFunction) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) KieBase(org.kie.api.KieBase) Model(org.drools.model.Model) KieSession(org.kie.api.runtime.KieSession) Rule(org.drools.model.Rule) ModelImpl(org.drools.model.impl.ModelImpl) Test(org.junit.Test)

Example 4 with CountAccumulateFunction

Use of org.drools.core.base.accumulators.CountAccumulateFunction in the drools project by kiegroup.

Class GroupByTest, method testWithGroupByAfterExistsWithFrom.

/**
 * Same grouping as testWithGroupByAfterExists, but the Long count is matched by a follow-up
 * pattern that binds it through Long::intValue before the consequence stores it.
 */
@Test
public void testWithGroupByAfterExistsWithFrom() {
    Global<Map> resultsGlobal = D.globalOf(Map.class, "defaultPkg", "glob");
    Variable<Integer> number = D.declarationOf(Integer.class);
    Variable<String> marker = D.declarationOf(String.class);
    Variable<Integer> absKey = D.declarationOf(Integer.class);
    Variable<Long> occurrences = D.declarationOf(Long.class);
    Variable<Integer> occurrencesAsInt = D.declarationOf(Integer.class);

    // Integer facts AND EXISTS(String fact), then grouped by Math::abs and counted
    ViewItem guardedNumbers = D.and(D.pattern(number), D.exists(D.pattern(marker)));
    ViewItem grouped = D.groupBy(guardedNumbers, number, absKey, Math::abs,
            DSL.accFunction(CountAccumulateFunction::new).as(occurrences));
    // bind the Long count converted via Long::intValue
    PatternDSL.PatternDef asInt = D.pattern(occurrences).bind(occurrencesAsInt, Long::intValue);

    ConsequenceBuilder._3 recordCount = D.on(absKey, occurrencesAsInt, resultsGlobal)
            .execute((key, count, results) -> {
                results.put(key, count);
            });
    Rule rule = D.rule("R").build(grouped, asInt, recordCount);

    Model model = new ModelImpl().addRule(rule).addGlobal(resultsGlobal);
    KieSession session = KieBaseBuilder.createKieBaseFromModel(model).newKieSession();
    Map<Integer, Integer> counts = new HashMap<>();
    session.setGlobal("glob", counts);
    for (Object fact : new Object[] { "Something", -1, 1, 2 }) {
        session.insert(fact);
    }
    session.fireAllRules();

    assertEquals(2, counts.size());
    // -1 and 1 share the key 1, so they are counted together.
    assertEquals(2, (int) counts.get(1));
    // 2 is the only value under its key.
    assertEquals(1, (int) counts.get(2));
}
Also used : D(org.drools.modelcompiler.dsl.pattern.D) HashMap(java.util.HashMap) CountAccumulateFunction(org.drools.core.base.accumulators.CountAccumulateFunction) ConsequenceBuilder(org.drools.model.consequences.ConsequenceBuilder) KieBase(org.kie.api.KieBase) KieSession(org.kie.api.runtime.KieSession) ModelImpl(org.drools.model.impl.ModelImpl) PatternDSL(org.drools.model.PatternDSL) ViewItem(org.drools.model.view.ViewItem) ExprViewItem(org.drools.model.view.ExprViewItem) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Model(org.drools.model.Model) Rule(org.drools.model.Rule) Map(java.util.Map) HashMap(java.util.HashMap) Test(org.junit.Test)

Example 5 with CountAccumulateFunction

Use of org.drools.core.base.accumulators.CountAccumulateFunction in the drools project by kiegroup.

Class GroupByTest, method testFromAfterGroupBy.

/**
 * Groups Person facts by name and re-declares the group key and count through {@code from(...)};
 * the consequence fails if the re-declared key is not a String, and the rule must fire at least once.
 */
@Test
public void testFromAfterGroupBy() {
    Global<Set> resultsGlobal = D.globalOf(Set.class, "defaultpkg", "results");
    Variable person = D.declarationOf(Person.class);
    Variable nameKey = D.declarationOf(String.class);
    Variable nameCount = D.declarationOf(Long.class);
    // the groupBy outputs, re-declared via from(...) so later patterns can match them
    Variable remappedKey = D.declarationOf(Object.class, from(nameKey));
    Variable remappedCount = D.declarationOf(Long.class, from(nameCount));

    PatternDSL.PatternDef<Person> namedPersons = D.pattern(person).expr(p -> ((Person) p).getName() != null);
    Rule rule = D.rule("R1").build(
            D.groupBy(namedPersons, person, nameKey, Person::getName,
                    DSL.accFunction(CountAccumulateFunction::new, person).as(nameCount)),
            D.pattern(remappedKey),
            D.pattern(remappedCount),
            D.on(remappedKey, remappedCount).execute((ctx, name, count) -> {
                if (name instanceof String) {
                    return;
                }
                throw new IllegalStateException("Name not String, but " + name.getClass());
            }));

    Model model = new ModelImpl().addRule(rule).addGlobal(resultsGlobal);
    KieSession ksession = KieBaseBuilder.createKieBaseFromModel(model).newKieSession();
    ksession.setGlobal("results", new LinkedHashSet<Integer>());
    Person[] people = { new Person("Mark", 42), new Person("Edson", 38), new Person("Edoardo", 33) };
    for (Person fact : people) {
        ksession.insert(fact);
    }

    assertThat(ksession.fireAllRules()).isGreaterThan(0);
}
Also used : Arrays(java.util.Arrays) InternalFactHandle(org.drools.core.common.InternalFactHandle) Accumulator(org.drools.core.spi.Accumulator) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) Global(org.drools.model.Global) RuleEventListener(org.kie.internal.event.rule.RuleEventListener) DSL(org.drools.model.DSL) Child(org.drools.modelcompiler.domain.Child) Match(org.kie.api.runtime.rule.Match) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Map(java.util.Map) Declaration(org.drools.core.rule.Declaration) Assertions(org.assertj.core.api.Assertions) KieSession(org.kie.api.runtime.KieSession) Parent(org.drools.modelcompiler.domain.Parent) Set(java.util.Set) Index(org.drools.model.Index) ConsequenceBuilder(org.drools.model.consequences.ConsequenceBuilder) EvaluationUtil(org.drools.modelcompiler.util.EvaluationUtil) Objects(java.util.Objects) List(java.util.List) Tuple(org.drools.core.spi.Tuple) CountAccumulateFunction(org.drools.core.base.accumulators.CountAccumulateFunction) Person(org.drools.modelcompiler.domain.Person) DSL.from(org.drools.model.DSL.from) ModelImpl(org.drools.model.impl.ModelImpl) HashMap(java.util.HashMap) PatternDSL(org.drools.model.PatternDSL) ArrayList(java.util.ArrayList) Function1(org.drools.model.functions.Function1) CollectListAccumulateFunction(org.drools.core.base.accumulators.CollectListAccumulateFunction) RuleEventManager(org.kie.internal.event.rule.RuleEventManager) KieBase(org.kie.api.KieBase) LinkedHashSet(java.util.LinkedHashSet) Model(org.drools.model.Model) IntegerMaxAccumulateFunction(org.drools.core.base.accumulators.IntegerMaxAccumulateFunction) ReteEvaluator(org.drools.core.common.ReteEvaluator) ViewItem(org.drools.model.view.ViewItem) Pair(org.apache.commons.math3.util.Pair) Variable(org.drools.model.Variable) D(org.drools.modelcompiler.dsl.pattern.D) IntegerSumAccumulateFunction(org.drools.core.base.accumulators.IntegerSumAccumulateFunction) ToIntFunction(java.util.function.ToIntFunction) 
Assert.assertTrue(org.junit.Assert.assertTrue) Test(org.junit.Test) GroupKey(org.drools.model.functions.accumulate.GroupKey) FactHandle(org.kie.api.runtime.rule.FactHandle) KieBaseBuilder(org.drools.modelcompiler.builder.KieBaseBuilder) Assert.assertNull(org.junit.Assert.assertNull) Ignore(org.junit.Ignore) Rule(org.drools.model.Rule) ExprViewItem(org.drools.model.view.ExprViewItem) RuleTerminalNodeLeftTuple(org.drools.core.reteoo.RuleTerminalNodeLeftTuple) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) LinkedHashSet(java.util.LinkedHashSet) PatternDSL(org.drools.model.PatternDSL) Set(java.util.Set) LinkedHashSet(java.util.LinkedHashSet) Variable(org.drools.model.Variable) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Model(org.drools.model.Model) KieSession(org.kie.api.runtime.KieSession) Rule(org.drools.model.Rule) ModelImpl(org.drools.model.impl.ModelImpl) Person(org.drools.modelcompiler.domain.Person) Test(org.junit.Test)

Aggregations

CountAccumulateFunction (org.drools.core.base.accumulators.CountAccumulateFunction)9 HashMap (java.util.HashMap)8 Map (java.util.Map)8 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)8 Model (org.drools.model.Model)8 Rule (org.drools.model.Rule)8 ConsequenceBuilder (org.drools.model.consequences.ConsequenceBuilder)8 ModelImpl (org.drools.model.impl.ModelImpl)8 ExprViewItem (org.drools.model.view.ExprViewItem)8 ViewItem (org.drools.model.view.ViewItem)8 ArrayList (java.util.ArrayList)7 Declaration (org.drools.core.rule.Declaration)7 Accumulator (org.drools.core.spi.Accumulator)7 PatternDSL (org.drools.model.PatternDSL)7 Variable (org.drools.model.Variable)7 Arrays (java.util.Arrays)6 Collections (java.util.Collections)6 LinkedHashSet (java.util.LinkedHashSet)6 List (java.util.List)6 Objects (java.util.Objects)6