Search in sources:

Example 1 with Function1

Use of org.drools.model.functions.Function1 in project drools by kiegroup.

Class GroupByTest, method testWithNull.

/**
 * Verifies that a groupBy whose key function returns {@code getNested()} works when the grouped
 * pattern only admits facts whose nested object is non-null.
 * <p>
 * Of the two inserted facts, only {@code objectWithNestedObject} satisfies the
 * {@code getNested() != null} constraint. The single group key, and therefore the only value
 * collected by the consequence, is its nested object, {@code objectWithoutNestedObject}.
 */
@Test
public void testWithNull() {
    Variable<MyType> var = D.declarationOf(MyType.class);
    Variable<MyType> groupKey = D.declarationOf(MyType.class);
    Variable<Long> count = D.declarationOf(Long.class);
    // Counts how often the grouping-key function is invoked, for the diagnostic output below.
    AtomicInteger mappingFunctionCallCounter = new AtomicInteger(0);
    Function1<MyType, MyType> mappingFunction = (a) -> {
        mappingFunctionCallCounter.incrementAndGet();
        return a.getNested();
    };
    // Filter out facts without a nested object so the key function should never see them.
    D.PatternDef<MyType> onlyOnesWithNested = D.pattern(var).expr(myType -> myType.getNested() != null);
    ExprViewItem groupBy = D.groupBy(onlyOnesWithNested, var, groupKey, mappingFunction, D.accFunction(CountAccumulateFunction::new).as(count));
    List<MyType> result = new ArrayList<>();
    Rule rule = D.rule("R").build(groupBy, D.on(groupKey, count).execute((drools, key, acc) -> result.add(key)));
    Model model = new ModelImpl().addRule(rule);
    KieBase kieBase = KieBaseBuilder.createKieBaseFromModel(model);
    MyType objectWithoutNestedObject = new MyType(null);
    MyType objectWithNestedObject = new MyType(objectWithoutNestedObject);
    KieSession ksession = kieBase.newKieSession();
    try {
        ksession.insert(objectWithNestedObject);
        ksession.insert(objectWithoutNestedObject);
        ksession.fireAllRules();
    } finally {
        // Release the session's resources even if rule evaluation fails.
        ksession.dispose();
    }
    // Diagnostic only: the original author found this call count unusually high and suggested
    // that caching the key mapping might help.
    System.out.println("GroupKey mapping function was called " + mappingFunctionCallCounter.get() + " times.");
    Assertions.assertThat(result).containsOnly(objectWithoutNestedObject);
}
Also used : Arrays(java.util.Arrays) InternalFactHandle(org.drools.core.common.InternalFactHandle) Accumulator(org.drools.core.spi.Accumulator) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) Global(org.drools.model.Global) RuleEventListener(org.kie.internal.event.rule.RuleEventListener) DSL(org.drools.model.DSL) Child(org.drools.modelcompiler.domain.Child) Match(org.kie.api.runtime.rule.Match) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Map(java.util.Map) Declaration(org.drools.core.rule.Declaration) Assertions(org.assertj.core.api.Assertions) KieSession(org.kie.api.runtime.KieSession) Parent(org.drools.modelcompiler.domain.Parent) Set(java.util.Set) Index(org.drools.model.Index) ConsequenceBuilder(org.drools.model.consequences.ConsequenceBuilder) EvaluationUtil(org.drools.modelcompiler.util.EvaluationUtil) Objects(java.util.Objects) List(java.util.List) Tuple(org.drools.core.spi.Tuple) CountAccumulateFunction(org.drools.core.base.accumulators.CountAccumulateFunction) Person(org.drools.modelcompiler.domain.Person) DSL.from(org.drools.model.DSL.from) ModelImpl(org.drools.model.impl.ModelImpl) HashMap(java.util.HashMap) PatternDSL(org.drools.model.PatternDSL) ArrayList(java.util.ArrayList) Function1(org.drools.model.functions.Function1) CollectListAccumulateFunction(org.drools.core.base.accumulators.CollectListAccumulateFunction) RuleEventManager(org.kie.internal.event.rule.RuleEventManager) KieBase(org.kie.api.KieBase) LinkedHashSet(java.util.LinkedHashSet) Model(org.drools.model.Model) IntegerMaxAccumulateFunction(org.drools.core.base.accumulators.IntegerMaxAccumulateFunction) ReteEvaluator(org.drools.core.common.ReteEvaluator) ViewItem(org.drools.model.view.ViewItem) Pair(org.apache.commons.math3.util.Pair) Variable(org.drools.model.Variable) D(org.drools.modelcompiler.dsl.pattern.D) IntegerSumAccumulateFunction(org.drools.core.base.accumulators.IntegerSumAccumulateFunction) ToIntFunction(java.util.function.ToIntFunction) 
Assert.assertTrue(org.junit.Assert.assertTrue) Test(org.junit.Test) GroupKey(org.drools.model.functions.accumulate.GroupKey) FactHandle(org.kie.api.runtime.rule.FactHandle) KieBaseBuilder(org.drools.modelcompiler.builder.KieBaseBuilder) Assert.assertNull(org.junit.Assert.assertNull) Ignore(org.junit.Ignore) Rule(org.drools.model.Rule) ExprViewItem(org.drools.model.view.ExprViewItem) RuleTerminalNodeLeftTuple(org.drools.core.reteoo.RuleTerminalNodeLeftTuple) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) ExprViewItem(org.drools.model.view.ExprViewItem) D(org.drools.modelcompiler.dsl.pattern.D) ArrayList(java.util.ArrayList) CountAccumulateFunction(org.drools.core.base.accumulators.CountAccumulateFunction) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) KieBase(org.kie.api.KieBase) Model(org.drools.model.Model) KieSession(org.kie.api.runtime.KieSession) Rule(org.drools.model.Rule) ModelImpl(org.drools.model.impl.ModelImpl) Test(org.junit.Test)

Example 2 with Function1

Use of org.drools.model.functions.Function1 in project drools by kiegroup.

Class GroupByTest, method testTwoGroupByUsingBindings.

/**
 * Chains two groupBy accumulations (DROOLS-5697).
 * <p>
 * The first groupBy groups persons by the first three characters of their name and sums their ages.
 * Each (key, sum) pair is then bound into a {@code Group}. The second groupBy groups those Groups by
 * the first two characters of their key and takes the maximum of their values. The consequence
 * stores "key -> max" into the {@code results} global.
 * <p>
 * The test then deletes and inserts persons, checking which groups are re-reported.
 */
@Test
// FIXME: this does not work yet. Declarations created for bindings only support single-input
// (Function1) binding functions; the two-input bind(...) that builds $g1 below is presumably the
// unsupported case.
@Ignore
public void testTwoGroupByUsingBindings() {
    // DROOLS-5697
    Global<Map> var_results = D.globalOf(Map.class, "defaultpkg", "results");
    Variable<String> var_$key_1 = D.declarationOf(String.class);
    Variable<Person> var_$p = D.declarationOf(Person.class);
    Variable<Integer> var_$age = D.declarationOf(Integer.class);
    Variable<Integer> var_$sumOfAges = D.declarationOf(Integer.class);
    // Alternative formulation of $g1, kept for reference:
    // "$g1", D.from(var_$key_1, var_$sumOfAges, ($k, $v) -> new Group($k, $v)));
    Variable<Group> var_$g1 = D.declarationOf(Group.class);
    Variable<Integer> var_$g1_value = D.declarationOf(Integer.class);
    Variable<String> var_$key_2 = D.declarationOf(String.class);
    Variable<Integer> var_$maxOfValues = D.declarationOf(Integer.class);
    Rule rule1 = D.rule("R1").build(D.groupBy(D.and(D.groupBy(D.pattern(var_$p).bind(var_$age, person -> person.getAge()), var_$p, var_$key_1, person -> person.getName().substring(0, 3), D.accFunction(IntegerSumAccumulateFunction::new, var_$age).as(var_$sumOfAges)), // NOTE(review): presumably refers to the two-input bind() on the next line, which currently does not work
    D.pattern(var_$key_1).bind(var_$g1, var_$sumOfAges, ($k, $v) -> new Group($k, $v)), D.pattern(var_$g1).bind(var_$g1_value, group -> (Integer) group.getValue())), var_$g1, var_$key_2, groupResult -> ((String) groupResult.getKey()).substring(0, 2), D.accFunction(IntegerMaxAccumulateFunction::new, var_$g1_value).as(var_$maxOfValues)), D.on(var_$key_2, var_results, var_$maxOfValues).execute(($key, results, $maxOfValues) -> {
        System.out.println($key + " -> " + $maxOfValues);
        results.put($key, $maxOfValues);
    }));
    Model model = new ModelImpl().addRule(rule1).addGlobal(var_results);
    KieSession ksession = KieBaseBuilder.createKieBaseFromModel(model).newKieSession();
    Map results = new HashMap();
    ksession.setGlobal("results", results);
    ksession.insert(new Person("Mark", 42));
    ksession.insert(new Person("Edoardo", 33));
    FactHandle meFH = ksession.insert(new Person("Mario", 45));
    ksession.insert(new Person("Maciej", 39));
    ksession.insert(new Person("Edson", 38));
    FactHandle geoffreyFH = ksession.insert(new Person("Geoffrey", 35));
    ksession.fireAllRules();
    System.out.println("-----");
    /*
         * In the first groupBy:
         *   Mark+Mario become "(Mar, 87)"
         *   Maciej becomes "(Mac, 39)"
         *   Geoffrey becomes "(Geo, 35)"
         *   Edson becomes "(Eds, 38)"
         *   Edoardo becomes "(Edo, 33)"
         *
         * Then in the second groupBy:
         *   "(Mar, 87)" and "(Mac, 39)" become "(Ma, 87)"
         *   "(Eds, 38)" and "(Edo, 33)" become "(Ed, 38)"
         *   "(Geo, 35)" becomes "(Ge, 35)"
         */
    assertEquals(3, results.size());
    assertEquals(87, results.get("Ma"));
    assertEquals(38, results.get("Ed"));
    assertEquals(35, results.get("Ge"));
    results.clear();
    ksession.delete(meFH);
    ksession.fireAllRules();
    System.out.println("-----");
    // Mario is gone, so the first groupBy yields "(Mar, 42)" instead of "(Mar, 87)".
    // The max over "Mar" (42) and "Mac" (39) is therefore "(Ma, 42)"; only this group is re-reported.
    assertEquals(1, results.size());
    assertEquals(42, results.get("Ma"));
    results.clear();
    // Deleting Geoffrey removes "(Geo, 35)".
    // Adding Matteo creates "(Mat, 38)" in the "Ma" group, but Mark's 42 is still the maximum, so "(Ma, 42)" stays.
    ksession.delete(geoffreyFH);
    ksession.insert(new Person("Matteo", 38));
    ksession.fireAllRules();
    assertEquals(1, results.size());
    assertEquals(42, results.get("Ma"));
}
Also used : Arrays(java.util.Arrays) InternalFactHandle(org.drools.core.common.InternalFactHandle) Accumulator(org.drools.core.spi.Accumulator) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) Global(org.drools.model.Global) RuleEventListener(org.kie.internal.event.rule.RuleEventListener) DSL(org.drools.model.DSL) Child(org.drools.modelcompiler.domain.Child) Match(org.kie.api.runtime.rule.Match) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Map(java.util.Map) Declaration(org.drools.core.rule.Declaration) Assertions(org.assertj.core.api.Assertions) KieSession(org.kie.api.runtime.KieSession) Parent(org.drools.modelcompiler.domain.Parent) Set(java.util.Set) Index(org.drools.model.Index) ConsequenceBuilder(org.drools.model.consequences.ConsequenceBuilder) EvaluationUtil(org.drools.modelcompiler.util.EvaluationUtil) Objects(java.util.Objects) List(java.util.List) Tuple(org.drools.core.spi.Tuple) CountAccumulateFunction(org.drools.core.base.accumulators.CountAccumulateFunction) Person(org.drools.modelcompiler.domain.Person) DSL.from(org.drools.model.DSL.from) ModelImpl(org.drools.model.impl.ModelImpl) HashMap(java.util.HashMap) PatternDSL(org.drools.model.PatternDSL) ArrayList(java.util.ArrayList) Function1(org.drools.model.functions.Function1) CollectListAccumulateFunction(org.drools.core.base.accumulators.CollectListAccumulateFunction) RuleEventManager(org.kie.internal.event.rule.RuleEventManager) KieBase(org.kie.api.KieBase) LinkedHashSet(java.util.LinkedHashSet) Model(org.drools.model.Model) IntegerMaxAccumulateFunction(org.drools.core.base.accumulators.IntegerMaxAccumulateFunction) ReteEvaluator(org.drools.core.common.ReteEvaluator) ViewItem(org.drools.model.view.ViewItem) Pair(org.apache.commons.math3.util.Pair) Variable(org.drools.model.Variable) D(org.drools.modelcompiler.dsl.pattern.D) IntegerSumAccumulateFunction(org.drools.core.base.accumulators.IntegerSumAccumulateFunction) ToIntFunction(java.util.function.ToIntFunction) 
Assert.assertTrue(org.junit.Assert.assertTrue) Test(org.junit.Test) GroupKey(org.drools.model.functions.accumulate.GroupKey) FactHandle(org.kie.api.runtime.rule.FactHandle) KieBaseBuilder(org.drools.modelcompiler.builder.KieBaseBuilder) Assert.assertNull(org.junit.Assert.assertNull) Ignore(org.junit.Ignore) Rule(org.drools.model.Rule) ExprViewItem(org.drools.model.view.ExprViewItem) RuleTerminalNodeLeftTuple(org.drools.core.reteoo.RuleTerminalNodeLeftTuple) Collections(java.util.Collections) Assert.assertEquals(org.junit.Assert.assertEquals) HashMap(java.util.HashMap) InternalFactHandle(org.drools.core.common.InternalFactHandle) FactHandle(org.kie.api.runtime.rule.FactHandle) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Model(org.drools.model.Model) KieSession(org.kie.api.runtime.KieSession) Rule(org.drools.model.Rule) ModelImpl(org.drools.model.impl.ModelImpl) Map(java.util.Map) HashMap(java.util.HashMap) Person(org.drools.modelcompiler.domain.Person) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 3 with Function1

Use of org.drools.model.functions.Function1 in project drools by kiegroup.

Class ToStringTest, method testToString.

/**
 * Users may depend on seeing {@link Rule#toString()} in log files giving useful information, in order to understand
 * the rules that are being created. The format is not required to be backwards compatible - this test merely checks
 * that it does not change unknowingly.
 * <p>
 * The rule under test has three parts: an alpha-indexed pattern for "Mark"; a pattern with one alpha index and one
 * beta index; and an average accumulate. Lambdas appear in the expected string through their identity hash codes.
 */
@Test
public void testToString() {
    Variable<Person> markV = declarationOf(Person.class);
    Variable<Integer> markAge = declarationOf(Integer.class);
    Variable<Person> olderV = declarationOf(Person.class);
    Variable<Double> resultAvg = declarationOf(Double.class);
    Variable<Integer> age = declarationOf(Integer.class);
    String person = "Mark";
    Function1<Person, String> nameGetter = Person::getName;
    Function1<Person, Integer> ageGetter = Person::getAge;
    Predicate1<Person> markPredicate = p -> p.getName().equals(person);
    PatternDSL.PatternDef<Person> pattern1 = pattern(markV).expr("exprA", markPredicate, alphaIndexedBy(String.class, Index.ConstraintType.EQUAL, 0, nameGetter, person)).bind(markAge, ageGetter);
    Predicate1<Person> notMarkPredicate = markPredicate.negate();
    Predicate2<Person, Integer> agePredicate = (p1, someAge) -> p1.getAge() > someAge;
    // Use the wrapper class: Class.cast on a primitive class such as int.class throws ClassCastException
    // for every non-null argument, because isInstance always returns false for primitive classes.
    Function1<Integer, Integer> ageCaster = Integer.class::cast;
    PatternDSL.PatternDef<Person> pattern2 = pattern(olderV).expr("exprB", notMarkPredicate, alphaIndexedBy(String.class, Index.ConstraintType.NOT_EQUAL, 1, nameGetter, person)).expr("exprC", markAge, agePredicate, betaIndexedBy(int.class, Index.ConstraintType.GREATER_THAN, 0, ageGetter, ageCaster));
    AccumulateFunction<AverageAccumulateFunction.AverageData> f = new AverageAccumulateFunction();
    Supplier<AccumulateFunction<AverageAccumulateFunction.AverageData>> accumulateSupplier = () -> f;
    org.drools.model.functions.accumulate.AccumulateFunction actualAccumulate = accFunction(accumulateSupplier, age);
    ExprViewItem<Person> accumulate = accumulate(pattern(olderV).expr("exprD", notMarkPredicate).bind(age, ageGetter), actualAccumulate.as(resultAvg));
    Rule rule = rule("beta").build(pattern1, pattern2, accumulate, on(olderV, markV).execute((drools, p1, p2) -> drools.insert(p1.getName() + " is older than " + p2.getName())));
    // Expected renderings of each rule element; lambdas print as "lambda <identityHashCode>".
    String pattern1toString = "PatternImpl (type: PATTERN, inputVars: null, " + "outputVar: " + markV + ", " + "constraint: Constraint for 'exprA' (index: AlphaIndex #0 (EQUAL, left: lambda " + System.identityHashCode(nameGetter) + ", right: " + person + ")))";
    String pattern2toString = "PatternImpl (type: PATTERN, inputVars: null, " + "outputVar: " + olderV + ", " + "constraint: MultipleConstraints (constraints: [" + "Constraint for 'exprB' (index: AlphaIndex #1 (NOT_EQUAL, left: lambda " + System.identityHashCode(nameGetter) + ", right: " + person + ")), " + "Constraint for 'exprC' (index: BetaIndex #0 (GREATER_THAN, left: lambda " + System.identityHashCode(ageGetter) + ", right: lambda " + System.identityHashCode(ageCaster) + "))]))";
    String accumulatePatternToString = "PatternImpl (type: PATTERN, inputVars: null, " + "outputVar: " + olderV + ", " + "constraint: Constraint for 'exprD' (index: null))";
    String accumulateToString = "AccumulatePatternImpl (functions: [" + actualAccumulate + "], " + "condition: " + accumulatePatternToString + ", " + "pattern: " + accumulatePatternToString + ")";
    String consequenceToString = "ConsequenceImpl (variables: [" + olderV + ", " + markV + "], language: java, breaking: false)";
    String expectedToString = "Rule: defaultpkg.beta (" + "view: CompositePatterns of AND (vars: null, patterns: [" + pattern1toString + ", " + pattern2toString + ", " + accumulateToString + ", NamedConsequenceImpl 'default' (breaking: false)], " + "consequences: {default=" + consequenceToString + "}), consequences: {default=" + consequenceToString + "})";
    Assertions.assertThat(rule).hasToString(expectedToString);
}
Also used : PatternDSL.rule(org.drools.model.PatternDSL.rule) Variable(org.drools.model.Variable) AverageAccumulateFunction(org.drools.core.base.accumulators.AverageAccumulateFunction) PatternDSL(org.drools.model.PatternDSL) Test(org.junit.Test) DSL.on(org.drools.model.DSL.on) Index(org.drools.model.Index) Predicate1(org.drools.model.functions.Predicate1) Supplier(java.util.function.Supplier) PatternDSL.pattern(org.drools.model.PatternDSL.pattern) Predicate2(org.drools.model.functions.Predicate2) DSL.accumulate(org.drools.model.DSL.accumulate) Function1(org.drools.model.functions.Function1) DSL.accFunction(org.drools.model.DSL.accFunction) PatternDSL.alphaIndexedBy(org.drools.model.PatternDSL.alphaIndexedBy) AccumulateFunction(org.kie.api.runtime.rule.AccumulateFunction) Rule(org.drools.model.Rule) Assertions(org.assertj.core.api.Assertions) ExprViewItem(org.drools.model.view.ExprViewItem) Person(org.drools.modelcompiler.domain.Person) PatternDSL.betaIndexedBy(org.drools.model.PatternDSL.betaIndexedBy) DSL.declarationOf(org.drools.model.DSL.declarationOf) DSL.accumulate(org.drools.model.DSL.accumulate) AverageAccumulateFunction(org.drools.core.base.accumulators.AverageAccumulateFunction) PatternDSL(org.drools.model.PatternDSL) Rule(org.drools.model.Rule) AverageAccumulateFunction(org.drools.core.base.accumulators.AverageAccumulateFunction) AccumulateFunction(org.kie.api.runtime.rule.AccumulateFunction) Person(org.drools.modelcompiler.domain.Person) Test(org.junit.Test)

Example 4 with Function1

Use of org.drools.model.functions.Function1 in project drools by kiegroup.

Class KiePackagesBuilder, method buildPattern.

/**
 * Translates an executable-model pattern into a core {@link Pattern} and registers it in {@code group}.
 * <p>
 * The translation proceeds in this order:
 * <ol>
 *   <li>copy the watched properties and the passive flag;</li>
 *   <li>turn every model binding into a {@link Declaration} (also registered on {@code ctx});</li>
 *   <li>add a unification constraint when {@code ctx} holds a query declaration for the pattern variable;</li>
 *   <li>add the model's constraints and reactive masks.</li>
 * </ol>
 *
 * @param ctx          build context that receives the binding declarations
 * @param group        group element the new pattern is added to
 * @param modelPattern executable-model pattern to translate
 * @return the created core pattern
 */
private RuleConditionElement buildPattern(RuleContext ctx, GroupElement group, org.drools.model.Pattern<?> modelPattern) {
    Variable sourceVar = modelPattern.getPatternVariable();
    Pattern corePattern = addPatternForVariable(ctx, group, sourceVar, modelPattern.getType());
    Arrays.asList(modelPattern.getWatchedProps()).forEach(corePattern::addWatchedProperty);
    corePattern.setPassive(modelPattern.isPassive());

    for (Binding modelBinding : modelPattern.getBindings()) {
        // FIXME this is returning null for BindViewItem2, BindViewItem3 etc (mdp)
        Function1 readFunction = getBindingFunction(ctx, sourceVar, modelBinding);
        Variable target = modelBinding.getBoundVariable();
        LambdaReadAccessor accessor = new LambdaReadAccessor(target.getType(), readFunction);
        Declaration bindingDecl = new Declaration(target.getName(), accessor, corePattern, true);
        corePattern.addDeclaration(bindingDecl);
        // Properties this binding reacts on are optional.
        if (modelBinding.getReactOn() != null) {
            Arrays.asList(modelBinding.getReactOn()).forEach(corePattern::addBoundProperty);
        }
        ctx.addDeclaration(target, bindingDecl);
    }

    Declaration unificationSource = ctx.getQueryDeclaration(sourceVar);
    if (unificationSource != null) {
        corePattern.addConstraint(new UnificationConstraint(unificationSource));
    }
    addConstraintsToPattern(ctx, corePattern, modelPattern.getConstraint());
    addReactiveMasksToPattern(corePattern, modelPattern.getPatternClassMetadata());
    return corePattern;
}
Also used : Binding(org.drools.model.Binding) QueryCallPattern(org.drools.model.patterns.QueryCallPattern) AccumulatePattern(org.drools.model.AccumulatePattern) Pattern(org.drools.core.rule.Pattern) GroupByPattern(org.drools.model.GroupByPattern) PrototypeVariable(org.drools.model.PrototypeVariable) Variable(org.drools.model.Variable) Function1(org.drools.model.functions.Function1) Declaration(org.drools.core.rule.Declaration) WindowDeclaration(org.drools.core.rule.WindowDeclaration) TypeDeclarationUtil.createTypeDeclaration(org.drools.modelcompiler.util.TypeDeclarationUtil.createTypeDeclaration) TypeDeclaration(org.drools.core.rule.TypeDeclaration) LambdaReadAccessor(org.drools.modelcompiler.constraints.LambdaReadAccessor) UnificationConstraint(org.drools.modelcompiler.constraints.UnificationConstraint)

Aggregations

Variable (org.drools.model.Variable): 4; Function1 (org.drools.model.functions.Function1): 4; Assertions (org.assertj.core.api.Assertions): 3; Declaration (org.drools.core.rule.Declaration): 3; Index (org.drools.model.Index): 3; PatternDSL (org.drools.model.PatternDSL): 3; Rule (org.drools.model.Rule): 3; ExprViewItem (org.drools.model.view.ExprViewItem): 3; ArrayList (java.util.ArrayList): 2; Arrays (java.util.Arrays): 2; Collections (java.util.Collections): 2; HashMap (java.util.HashMap): 2; LinkedHashSet (java.util.LinkedHashSet): 2; List (java.util.List): 2; Map (java.util.Map): 2; Objects (java.util.Objects): 2; Set (java.util.Set): 2; AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 2; ToIntFunction (java.util.function.ToIntFunction): 2; Pair (org.apache.commons.math3.util.Pair): 2