Search in sources:

Example 1 with DMNImportPMMLInfo

Use of org.kie.dmn.core.pmml.DMNImportPMMLInfo in project drools by kiegroup.

Class DMNRuntimePMMLTest, method testMultiOutputs.

/**
 * Evaluates the "KiePMMLRegressionClax" DMN model, which imports the "test_regression_clax" PMML
 * document, and checks both the multi-field decision result and the PMML import metadata
 * (input and output field types) recorded on the compiled model.
 */
@Test
public void testMultiOutputs() {
    // The PMML import name is also the namespace of every type derived from that PMML document.
    final String pmmlImportName = "test_regression_clax";
    final DMNRuntime runtime = DMNRuntimeUtil.createRuntimeWithAdditionalResources("KiePMMLRegressionClax.dmn", DMNRuntimePMMLTest.class, "test_regression_clax.pmml");
    final DMNModel dmnModel = runtime.getModel("http://www.trisotech.com/definitions/_ca466dbe-20b4-4e88-a43f-4ce3aff26e4f", "KiePMMLRegressionClax");
    assertThat(dmnModel, notNullValue());
    assertThat(DMNRuntimeUtil.formatMessages(dmnModel.getMessages()), dmnModel.hasErrors(), is(false));
    final DMNContext dmnContext = DMNFactory.newContext();
    dmnContext.set("fld1", 1.0);
    dmnContext.set("fld2", 1.0);
    dmnContext.set("fld3", "x");
    final DMNResult dmnResult = runtime.evaluateAll(dmnModel, dmnContext);
    LOG.debug("{}", dmnResult);
    assertThat(DMNRuntimeUtil.formatMessages(dmnResult.getMessages()), dmnResult.hasErrors(), is(false));
    // The decision yields a map keyed by PMML output field name; the cast cannot be checked at runtime.
    @SuppressWarnings("unchecked")
    final Map<String, Object> result = (Map<String, Object>) dmnResult.getContext().get("my decision");
    assertThat("missing decision result", result, notNullValue());
    assertEquals("catD", result.get("RegOut"));
    assertEquals(0.8279559384018024, ((BigDecimal) result.get("RegProb")).doubleValue(), COMPARISON_DELTA);
    assertEquals(0.0022681396056233208, ((BigDecimal) result.get("RegProbA")).doubleValue(), COMPARISON_DELTA);

    final DMNModelImpl modelImpl = (DMNModelImpl) dmnModel;
    final String feelUri = dmnModel.getDefinitions().getURIFEEL();
    final DMNType dmnFEELNumber = modelImpl.getTypeRegistry().resolveType(feelUri, BuiltInType.NUMBER.getName());
    final DMNType dmnFEELString = modelImpl.getTypeRegistry().resolveType(feelUri, BuiltInType.STRING.getName());

    // additional import info.
    final Map<String, DMNImportPMMLInfo> pmmlImportInfo = modelImpl.getPmmlImportInfo();
    assertThat(pmmlImportInfo.keySet(), hasSize(1));
    final DMNImportPMMLInfo p0 = pmmlImportInfo.values().iterator().next();
    assertThat(p0.getImportName(), is(pmmlImportName));
    assertThat(p0.getModels(), hasSize(1));
    final DMNPMMLModelInfo m0 = p0.getModels().iterator().next();
    assertThat(m0.getName(), is("LinReg"));

    final Map<String, DMNType> inputFields = m0.getInputFields();
    assertSimpleField(inputFields, "fld1", pmmlImportName, BuiltInType.NUMBER, dmnFEELNumber);
    assertSimpleField(inputFields, "fld2", pmmlImportName, BuiltInType.NUMBER, dmnFEELNumber);
    assertSimpleField(inputFields, "fld3", pmmlImportName, BuiltInType.STRING, dmnFEELString);

    // The outputs are expected to be grouped in one composite type keyed by the model name.
    final CompositeTypeImpl output = (CompositeTypeImpl) m0.getOutputFields().get("LinReg");
    assertThat("missing output type LinReg", output, notNullValue());
    assertEquals(pmmlImportName, output.getNamespace());
    final Map<String, DMNType> outputFields = output.getFields();
    assertSimpleField(outputFields, "RegOut", pmmlImportName, BuiltInType.STRING, dmnFEELString);
    assertSimpleField(outputFields, "RegProb", pmmlImportName, BuiltInType.NUMBER, dmnFEELNumber);
    assertSimpleField(outputFields, "RegProbA", pmmlImportName, BuiltInType.NUMBER, dmnFEELNumber);
}

/**
 * Asserts that {@code fields} maps {@code fieldName} to a {@link SimpleTypeImpl} with the
 * expected namespace, FEEL type and base type.
 *
 * @param fields    field-name to type map taken from the PMML model info
 * @param fieldName the field to look up
 * @param namespace expected namespace of the field's type
 * @param feelType  expected FEEL built-in type
 * @param baseType  expected base type, as resolved from the FEEL type registry
 */
private static void assertSimpleField(final Map<String, DMNType> fields, final String fieldName, final String namespace, final BuiltInType feelType, final DMNType baseType) {
    final SimpleTypeImpl type = (SimpleTypeImpl) fields.get(fieldName);
    // Fail with a readable message instead of an NPE when the field is absent.
    assertThat("missing field " + fieldName, type, notNullValue());
    assertEquals(namespace, type.getNamespace());
    assertEquals(feelType, type.getFeelType());
    assertEquals(baseType, type.getBaseType());
}
Also used : DMNResult(org.kie.dmn.api.core.DMNResult) DMNContext(org.kie.dmn.api.core.DMNContext) DMNPMMLModelInfo(org.kie.dmn.core.pmml.DMNPMMLModelInfo) DMNModelImpl(org.kie.dmn.core.impl.DMNModelImpl) DMNRuntime(org.kie.dmn.api.core.DMNRuntime) SimpleTypeImpl(org.kie.dmn.core.impl.SimpleTypeImpl) DMNImportPMMLInfo(org.kie.dmn.core.pmml.DMNImportPMMLInfo) Map(java.util.Map) DMNModel(org.kie.dmn.api.core.DMNModel) CompositeTypeImpl(org.kie.dmn.core.impl.CompositeTypeImpl) DMNType(org.kie.dmn.api.core.DMNType) Test(org.junit.Test)

Example 2 with DMNImportPMMLInfo

Use of org.kie.dmn.core.pmml.DMNImportPMMLInfo in project drools by kiegroup.

Class DMNRuntimePMMLTest, method runDMNModelInvokingPMML.

/**
 * Evaluates the "KiePMMLScoreCard" DMN model with an empty input context and checks two things.
 * First, the decision "my decision" must equal 41.345. Second, the PMML import metadata recorded
 * on the model must describe a single import "iris" with a single model "Sample Score".
 *
 * @param runtime DMN runtime expected to contain the KiePMMLScoreCard model
 */
static void runDMNModelInvokingPMML(final DMNRuntime runtime) {
    final DMNModel scoreCardModel = runtime.getModel("http://www.trisotech.com/definitions/_ca466dbe-20b4-4e88-a43f-4ce3aff26e4f", "KiePMMLScoreCard");
    assertThat(scoreCardModel, notNullValue());
    assertThat(DMNRuntimeUtil.formatMessages(scoreCardModel.getMessages()), scoreCardModel.hasErrors(), is(false));

    // No inputs are supplied: the score card is evaluated with an empty context.
    final DMNResult evaluation = runtime.evaluateAll(scoreCardModel, DMNFactory.newContext());
    LOG.debug("{}", evaluation);
    assertThat(DMNRuntimeUtil.formatMessages(evaluation.getMessages()), evaluation.hasErrors(), is(false));
    assertThat(evaluation.getContext().get("my decision"), is(new BigDecimal("41.345")));

    // Metadata recorded for the (single) PMML import.
    final Map<String, DMNImportPMMLInfo> importsByName = ((DMNModelImpl) scoreCardModel).getPmmlImportInfo();
    assertThat(importsByName.keySet(), hasSize(1));
    final DMNImportPMMLInfo importInfo = importsByName.values().iterator().next();
    assertThat(importInfo.getImportName(), is("iris"));
    assertThat(importInfo.getModels(), hasSize(1));
    final DMNPMMLModelInfo modelInfo = importInfo.getModels().iterator().next();
    assertThat(modelInfo.getName(), is("Sample Score"));

    for (final String inputName : new String[] {"age", "occupation", "residenceState", "validLicense"}) {
        assertThat(modelInfo.getInputFields(), hasEntry(is(inputName), anything()));
    }
    // Score fields must not be reported as inputs.
    for (final String scoreName : new String[] {"overallScore", "calculatedScore"}) {
        assertThat(modelInfo.getInputFields(), not(hasEntry(is(scoreName), anything())));
    }
    assertThat(modelInfo.getOutputFields(), hasEntry(is("calculatedScore"), anything()));
}
Also used : DMNResult(org.kie.dmn.api.core.DMNResult) DMNContext(org.kie.dmn.api.core.DMNContext) DMNPMMLModelInfo(org.kie.dmn.core.pmml.DMNPMMLModelInfo) DMNImportPMMLInfo(org.kie.dmn.core.pmml.DMNImportPMMLInfo) DMNModelImpl(org.kie.dmn.core.impl.DMNModelImpl) DMNModel(org.kie.dmn.api.core.DMNModel) BigDecimal(java.math.BigDecimal)

Example 3 with DMNImportPMMLInfo

Use of org.kie.dmn.core.pmml.DMNImportPMMLInfo in project drools by kiegroup.

Class DMNEvaluatorCompiler, method compileFunctionDefinitionPMML.

/**
 * Compiles a PMML-kind function definition into an expression evaluator.
 * <p>
 * The function body must be a {@link Context}. Its literal entry {@code document} names a DMN
 * import of the PMML resource, and its optional literal entry {@code model} names the PMML model
 * to invoke. When {@code model} is missing or empty, a warning lists the model names known for
 * the import, if any.
 *
 * @param ctx          compiler context, used for its relative resource resolver
 * @param model        the DMN model being compiled; messages are reported on it
 * @param node         the node owning the function definition
 * @param functionName name of the function, used in error messages
 * @param funcDef      the PMML function definition
 * @return an evaluator wrapping a PMML invoker when the referenced import is found; otherwise a
 *         bare {@link DMNFunctionDefinitionEvaluator}, after an ERROR message has been reported
 */
private DMNExpressionEvaluator compileFunctionDefinitionPMML(DMNCompilerContext ctx, DMNModelImpl model, DMNBaseNode node, String functionName, FunctionDefinition funcDef) {
    if (!(funcDef.getExpression() instanceof Context)) {
        // error, PMML function definitions require a context
        MsgUtil.reportMessage(logger, DMNMessage.Severity.ERROR, funcDef, model, null, null, Msg.FUNC_DEF_BODY_NOT_CONTEXT, node.getIdentifierString());
        return new DMNFunctionDefinitionEvaluator(node, funcDef);
    }
    Context context = (Context) funcDef.getExpression();
    String pmmlDocument = null;
    String pmmlModel = null;
    for (ContextEntry ce : context.getContextEntry()) {
        if (ce.getVariable() == null || ce.getVariable().getName() == null || !(ce.getExpression() instanceof LiteralExpression)) {
            continue;
        }
        String entryName = ce.getVariable().getName();
        String entryText = ((LiteralExpression) ce.getExpression()).getText();
        if (entryText == null) {
            continue;
        }
        if (entryName.equals("document")) {
            pmmlDocument = stripQuotes(entryText.trim());
        } else if (entryName.equals("model")) {
            pmmlModel = stripQuotes(entryText.trim());
        }
    }
    final String nameLookup = pmmlDocument;
    // Null-safe comparison: the "document" entry may be absent, and an Import may lack a name.
    Optional<Import> lookupImport = model.getDefinitions().getImport().stream()
            .filter(x -> nameLookup != null && nameLookup.equals(x.getName()))
            .findFirst();
    if (!lookupImport.isPresent()) {
        MsgUtil.reportMessage(logger, DMNMessage.Severity.ERROR, funcDef, model, null, null, Msg.FUNC_DEF_PMML_MISSING_ENTRY, functionName, node.getIdentifierString());
        return new DMNFunctionDefinitionEvaluator(node, funcDef);
    }
    Import theImport = lookupImport.get();
    logger.trace("theImport: {}", theImport);
    Resource pmmlResource = DMNCompilerImpl.resolveRelativeResource(getRootClassLoader(), model, theImport, funcDef, ctx.getRelativeResolver());
    logger.trace("pmmlResource: {}", pmmlResource);
    // Map.get yields null when no PMML import info is registered under this name; without this
    // guard the model-name lookup below would throw a NullPointerException.
    DMNImportPMMLInfo pmmlInfo = model.getPmmlImportInfo().get(pmmlDocument);
    logger.trace("pmmlInfo: {}", pmmlInfo);
    if ((pmmlModel == null || pmmlModel.isEmpty()) && pmmlInfo != null) {
        List<String> pmmlModelNames = pmmlInfo.getModels().stream().map(PMMLModelInfo::getName).filter(x -> x != null).collect(Collectors.toList());
        if (!pmmlModelNames.isEmpty()) {
            MsgUtil.reportMessage(logger, DMNMessage.Severity.WARN, funcDef, model, null, null, Msg.FUNC_DEF_PMML_MISSING_MODEL_NAME, String.join(",", pmmlModelNames));
        }
    }
    AbstractPMMLInvocationEvaluator invoker = PMMLInvocationEvaluatorFactory.newInstance(model, getRootClassLoader(), funcDef, pmmlResource, pmmlModel, pmmlInfo);
    DMNFunctionDefinitionEvaluator func = new DMNFunctionDefinitionEvaluator(node, funcDef);
    // Register each formal parameter, with its resolved type, on both the function and the invoker.
    for (InformationItem p : funcDef.getFormalParameter()) {
        DMNCompilerHelper.checkVariableName(model, p, p.getName());
        DMNType dmnType = compiler.resolveTypeRef(model, p, p, p.getTypeRef());
        func.addParameter(p.getName(), dmnType);
        invoker.addParameter(p.getName(), dmnType);
    }
    func.setEvaluator(invoker);
    return func;
}
Also used : Context(org.kie.dmn.model.api.Context) PMMLModelInfo(org.kie.dmn.core.pmml.PMMLModelInfo) DMNConditionalEvaluator(org.kie.dmn.core.ast.DMNConditionalEvaluator) DecisionTable(org.kie.dmn.model.api.DecisionTable) DMNMessage(org.kie.dmn.api.core.DMNMessage) Quantified(org.kie.dmn.model.api.Quantified) LoggerFactory(org.slf4j.LoggerFactory) DMNExpressionEvaluator(org.kie.dmn.core.api.DMNExpressionEvaluator) LiteralExpression(org.kie.dmn.model.api.LiteralExpression) DMNElement(org.kie.dmn.model.api.DMNElement) DTDecisionRule(org.kie.dmn.feel.runtime.decisiontables.DTDecisionRule) EvaluatorResult(org.kie.dmn.core.api.EvaluatorResult) DMNIteratorEvaluator(org.kie.dmn.core.ast.DMNIteratorEvaluator) UnaryTest(org.kie.dmn.feel.runtime.UnaryTest) DMNNode(org.kie.dmn.api.core.ast.DMNNode) BaseDMNTypeImpl(org.kie.dmn.core.impl.BaseDMNTypeImpl) OutputClause(org.kie.dmn.model.api.OutputClause) DMNModelImpl(org.kie.dmn.core.impl.DMNModelImpl) Import(org.kie.dmn.model.api.Import) UUID(java.util.UUID) FunctionKind(org.kie.dmn.model.api.FunctionKind) DMNRelationEvaluator(org.kie.dmn.core.ast.DMNRelationEvaluator) Collectors(java.util.stream.Collectors) BusinessKnowledgeModelNode(org.kie.dmn.api.core.ast.BusinessKnowledgeModelNode) HitPolicy(org.kie.dmn.model.api.HitPolicy) Objects(java.util.Objects) Resource(org.kie.api.io.Resource) List(java.util.List) DMNDTExpressionEvaluator(org.kie.dmn.core.ast.DMNDTExpressionEvaluator) Filter(org.kie.dmn.model.api.Filter) CompiledExpression(org.kie.dmn.feel.lang.CompiledExpression) Expression(org.kie.dmn.model.api.Expression) Entry(java.util.Map.Entry) Optional(java.util.Optional) QName(javax.xml.namespace.QName) InformationItem(org.kie.dmn.model.api.InformationItem) Iterator(org.kie.dmn.model.api.Iterator) DMNLiteralExpressionEvaluator(org.kie.dmn.core.ast.DMNLiteralExpressionEvaluator) RootExecutionFrame(org.kie.dmn.feel.lang.impl.RootExecutionFrame) Relation(org.kie.dmn.model.api.Relation) FEEL(org.kie.dmn.feel.FEEL) 
MsgUtil(org.kie.dmn.core.util.MsgUtil) DMNType(org.kie.dmn.api.core.DMNType) DMNContextEvaluator(org.kie.dmn.core.ast.DMNContextEvaluator) Binding(org.kie.dmn.model.api.Binding) InputClause(org.kie.dmn.model.api.InputClause) DTOutputClause(org.kie.dmn.feel.runtime.decisiontables.DTOutputClause) EvaluatorResultImpl(org.kie.dmn.core.ast.EvaluatorResultImpl) DTInputClause(org.kie.dmn.feel.runtime.decisiontables.DTInputClause) ArrayList(java.util.ArrayList) DecisionRule(org.kie.dmn.model.api.DecisionRule) DecisionNode(org.kie.dmn.api.core.ast.DecisionNode) DMNFilterEvaluator(org.kie.dmn.core.ast.DMNFilterEvaluator) FEELFunction(org.kie.dmn.feel.runtime.FEELFunction) For(org.kie.dmn.model.api.For) DMNBaseNode(org.kie.dmn.core.ast.DMNBaseNode) Decision(org.kie.dmn.model.api.Decision) FunctionDefinition(org.kie.dmn.model.api.FunctionDefinition) DMNInvocationEvaluator(org.kie.dmn.core.ast.DMNInvocationEvaluator) Logger(org.slf4j.Logger) DMNListEvaluator(org.kie.dmn.core.ast.DMNListEvaluator) AbstractPMMLInvocationEvaluator(org.kie.dmn.core.pmml.AbstractPMMLInvocationEvaluator) DTInvokerFunction(org.kie.dmn.feel.runtime.functions.DTInvokerFunction) ContextEntry(org.kie.dmn.model.api.ContextEntry) Invocation(org.kie.dmn.model.api.Invocation) DMNAlphaNetworkEvaluatorCompiler(org.kie.dmn.core.compiler.alphanetbased.DMNAlphaNetworkEvaluatorCompiler) Collectors.toList(java.util.stream.Collectors.toList) PMMLInvocationEvaluatorFactory(org.kie.dmn.core.pmml.AbstractPMMLInvocationEvaluator.PMMLInvocationEvaluatorFactory) DMNImportPMMLInfo(org.kie.dmn.core.pmml.DMNImportPMMLInfo) BusinessKnowledgeModel(org.kie.dmn.model.api.BusinessKnowledgeModel) DMNFunctionDefinitionEvaluator(org.kie.dmn.core.ast.DMNFunctionDefinitionEvaluator) UnaryTests(org.kie.dmn.model.api.UnaryTests) CompositeTypeImpl(org.kie.dmn.core.impl.CompositeTypeImpl) Conditional(org.kie.dmn.model.api.Conditional) Msg(org.kie.dmn.core.util.Msg) 
DecisionTableImpl(org.kie.dmn.feel.runtime.decisiontables.DecisionTableImpl) Collections(java.util.Collections) Context(org.kie.dmn.model.api.Context) BaseFEELFunction(org.kie.dmn.feel.runtime.functions.BaseFEELFunction) Import(org.kie.dmn.model.api.Import) LiteralExpression(org.kie.dmn.model.api.LiteralExpression) Resource(org.kie.api.io.Resource) InformationItem(org.kie.dmn.model.api.InformationItem) ContextEntry(org.kie.dmn.model.api.ContextEntry) AbstractPMMLInvocationEvaluator(org.kie.dmn.core.pmml.AbstractPMMLInvocationEvaluator) PMMLModelInfo(org.kie.dmn.core.pmml.PMMLModelInfo) DMNFunctionDefinitionEvaluator(org.kie.dmn.core.ast.DMNFunctionDefinitionEvaluator) DMNImportPMMLInfo(org.kie.dmn.core.pmml.DMNImportPMMLInfo) DMNType(org.kie.dmn.api.core.DMNType)

Aggregations

DMNModelImpl (org.kie.dmn.core.impl.DMNModelImpl)3 DMNImportPMMLInfo (org.kie.dmn.core.pmml.DMNImportPMMLInfo)3 DMNContext (org.kie.dmn.api.core.DMNContext)2 DMNModel (org.kie.dmn.api.core.DMNModel)2 DMNResult (org.kie.dmn.api.core.DMNResult)2 DMNType (org.kie.dmn.api.core.DMNType)2 CompositeTypeImpl (org.kie.dmn.core.impl.CompositeTypeImpl)2 DMNPMMLModelInfo (org.kie.dmn.core.pmml.DMNPMMLModelInfo)2 BigDecimal (java.math.BigDecimal)1 ArrayList (java.util.ArrayList)1 Collections (java.util.Collections)1 List (java.util.List)1 Map (java.util.Map)1 Entry (java.util.Map.Entry)1 Objects (java.util.Objects)1 Optional (java.util.Optional)1 UUID (java.util.UUID)1 Collectors (java.util.stream.Collectors)1 Collectors.toList (java.util.stream.Collectors.toList)1 QName (javax.xml.namespace.QName)1