Use of org.kie.dmn.core.pmml.DMNImportPMMLInfo in project drools by kiegroup.
The class DMNRuntimePMMLTest, method testMultiOutputs.
/**
 * Loads {@code KiePMMLRegressionClax.dmn} together with {@code test_regression_clax.pmml},
 * evaluates the model with three inputs and verifies:
 * <ul>
 *   <li>the {@code "my decision"} result map ({@code RegOut}, {@code RegProb}, {@code RegProbA});</li>
 *   <li>the PMML import metadata exposed by the compiled model (import name, model name,
 *       and the namespace / FEEL type / base type of every input and output field).</li>
 * </ul>
 */
@Test
public void testMultiOutputs() {
    final DMNRuntime runtime = DMNRuntimeUtil.createRuntimeWithAdditionalResources("KiePMMLRegressionClax.dmn", DMNRuntimePMMLTest.class, "test_regression_clax.pmml");
    final DMNModel dmnModel = runtime.getModel("http://www.trisotech.com/definitions/_ca466dbe-20b4-4e88-a43f-4ce3aff26e4f", "KiePMMLRegressionClax");
    assertThat(dmnModel, notNullValue());
    assertThat(DMNRuntimeUtil.formatMessages(dmnModel.getMessages()), dmnModel.hasErrors(), is(false));

    // Evaluate the model with the three inputs set below.
    final DMNContext inputContext = DMNFactory.newContext();
    inputContext.set("fld1", 1.0);
    inputContext.set("fld2", 1.0);
    inputContext.set("fld3", "x");
    final DMNResult dmnResult = runtime.evaluateAll(dmnModel, inputContext);
    LOG.debug("{}", dmnResult);
    assertThat(DMNRuntimeUtil.formatMessages(dmnResult.getMessages()), dmnResult.hasErrors(), is(false));

    // The decision result is a map carrying all three PMML outputs.
    final Map<String, Object> decision = (Map<String, Object>) dmnResult.getContext().get("my decision");
    assertEquals("catD", (String) decision.get("RegOut"));
    assertEquals(0.8279559384018024, ((BigDecimal) decision.get("RegProb")).doubleValue(), COMPARISON_DELTA);
    assertEquals(0.0022681396056233208, ((BigDecimal) decision.get("RegProbA")).doubleValue(), COMPARISON_DELTA);

    // Reference FEEL types the PMML field types are compared against.
    final DMNModelImpl modelImpl = (DMNModelImpl) dmnModel;
    final String feelNamespace = dmnModel.getDefinitions().getURIFEEL();
    final DMNType feelNumber = modelImpl.getTypeRegistry().resolveType(feelNamespace, BuiltInType.NUMBER.getName());
    final DMNType feelString = modelImpl.getTypeRegistry().resolveType(dmnModel.getDefinitions().getURIFEEL(), BuiltInType.STRING.getName());

    // additional import info.
    final Map<String, DMNImportPMMLInfo> importInfoByName = modelImpl.getPmmlImportInfo();
    assertThat(importInfoByName.keySet(), hasSize(1));
    final DMNImportPMMLInfo importInfo = importInfoByName.values().iterator().next();
    assertThat(importInfo.getImportName(), is("test_regression_clax"));
    assertThat(importInfo.getModels(), hasSize(1));
    final DMNPMMLModelInfo modelInfo = importInfo.getModels().iterator().next();
    assertThat(modelInfo.getName(), is("LinReg"));

    final Map<String, DMNType> inputFields = modelInfo.getInputFields();
    assertSimpleType(inputFields, "fld1", BuiltInType.NUMBER, feelNumber);
    assertSimpleType(inputFields, "fld2", BuiltInType.NUMBER, feelNumber);
    assertSimpleType(inputFields, "fld3", BuiltInType.STRING, feelString);

    final Map<String, DMNType> outputFields = modelInfo.getOutputFields();
    final CompositeTypeImpl linRegOutput = (CompositeTypeImpl) outputFields.get("LinReg");
    assertEquals("test_regression_clax", linRegOutput.getNamespace());
    final Map<String, DMNType> linRegFields = linRegOutput.getFields();
    assertSimpleType(linRegFields, "RegOut", BuiltInType.STRING, feelString);
    assertSimpleType(linRegFields, "RegProb", BuiltInType.NUMBER, feelNumber);
    assertSimpleType(linRegFields, "RegProbA", BuiltInType.NUMBER, feelNumber);
}

/**
 * Asserts that {@code fieldName} in {@code fields} is a {@link SimpleTypeImpl} in the
 * {@code test_regression_clax} namespace with the given FEEL type and base type.
 */
private static void assertSimpleType(final Map<String, DMNType> fields, final String fieldName, final BuiltInType expectedFeelType, final DMNType expectedBaseType) {
    final SimpleTypeImpl fieldType = (SimpleTypeImpl) fields.get(fieldName);
    assertEquals("test_regression_clax", fieldType.getNamespace());
    assertEquals(expectedFeelType, fieldType.getFeelType());
    assertEquals(expectedBaseType, fieldType.getBaseType());
}
Use of org.kie.dmn.core.pmml.DMNImportPMMLInfo in project drools by kiegroup.
The class DMNRuntimePMMLTest, method runDMNModelInvokingPMML.
/**
 * Evaluates the {@code KiePMMLScoreCard} model from the given runtime with an empty context,
 * checks that {@code "my decision"} equals {@code 41.345}, and verifies the PMML import
 * metadata (import name, model name, and which field names are listed as inputs vs. outputs).
 *
 * @param runtime runtime that already contains the {@code KiePMMLScoreCard} model
 */
static void runDMNModelInvokingPMML(final DMNRuntime runtime) {
    final DMNModel scoreCardModel = runtime.getModel("http://www.trisotech.com/definitions/_ca466dbe-20b4-4e88-a43f-4ce3aff26e4f", "KiePMMLScoreCard");
    assertThat(scoreCardModel, notNullValue());
    assertThat(DMNRuntimeUtil.formatMessages(scoreCardModel.getMessages()), scoreCardModel.hasErrors(), is(false));

    // No inputs are supplied: the model is evaluated against an empty context.
    final DMNResult evaluation = runtime.evaluateAll(scoreCardModel, DMNFactory.newContext());
    LOG.debug("{}", evaluation);
    assertThat(DMNRuntimeUtil.formatMessages(evaluation.getMessages()), evaluation.hasErrors(), is(false));
    assertThat(evaluation.getContext().get("my decision"), is(new BigDecimal("41.345")));

    // additional import info.
    final Map<String, DMNImportPMMLInfo> importInfoByName = ((DMNModelImpl) scoreCardModel).getPmmlImportInfo();
    assertThat(importInfoByName.keySet(), hasSize(1));
    final DMNImportPMMLInfo importInfo = importInfoByName.values().iterator().next();
    assertThat(importInfo.getImportName(), is("iris"));
    assertThat(importInfo.getModels(), hasSize(1));
    final DMNPMMLModelInfo modelInfo = importInfo.getModels().iterator().next();
    assertThat(modelInfo.getName(), is("Sample Score"));

    // Names that must be listed as inputs.
    for (final String expectedInput : new String[] {"age", "occupation", "residenceState", "validLicense"}) {
        assertThat(modelInfo.getInputFields(), hasEntry(is(expectedInput), anything()));
    }
    // Names that must not be listed as inputs.
    for (final String notAnInput : new String[] {"overallScore", "calculatedScore"}) {
        assertThat(modelInfo.getInputFields(), not(hasEntry(is(notAnInput), anything())));
    }
    assertThat(modelInfo.getOutputFields(), hasEntry(is("calculatedScore"), anything()));
}
Use of org.kie.dmn.core.pmml.DMNImportPMMLInfo in project drools by kiegroup.
The class DMNEvaluatorCompiler, method compileFunctionDefinitionPMML.
/**
 * Compiles a PMML-backed function definition into an evaluator.
 * <p>
 * The function body must be a {@code Context}. Its literal-expression entries named
 * {@code document} and {@code model} supply, respectively, the name of the import to look up
 * in the model's definitions and the name of the PMML model to invoke. Both values are
 * trimmed and passed through {@code stripQuotes}.
 * <p>
 * Problems are reported through {@code MsgUtil.reportMessage} instead of being thrown:
 * <ul>
 *   <li>ERROR when the body is not a {@code Context};</li>
 *   <li>ERROR when no import matches the {@code document} entry;</li>
 *   <li>WARN when no model name is given and the import info lists at least one named model.</li>
 * </ul>
 *
 * @param ctx          compiler context; supplies the relative resolver used to locate the PMML resource
 * @param model        DMN model that declares the imports and holds the PMML import info
 * @param node         node the resulting function evaluator is created for
 * @param functionName function name used in the missing-entry error message
 * @param funcDef      the function definition being compiled
 * @return a function evaluator whose evaluator is the PMML invoker; or, after an ERROR was
 *         reported, a function evaluator with no evaluator set
 */
private DMNExpressionEvaluator compileFunctionDefinitionPMML(DMNCompilerContext ctx, DMNModelImpl model, DMNBaseNode node, String functionName, FunctionDefinition funcDef) {
    if (!(funcDef.getExpression() instanceof Context)) {
        // error, PMML function definitions require a context
        MsgUtil.reportMessage(logger, DMNMessage.Severity.ERROR, funcDef, model, null, null, Msg.FUNC_DEF_BODY_NOT_CONTEXT, node.getIdentifierString());
        return new DMNFunctionDefinitionEvaluator(node, funcDef);
    }

    // Extract the "document" and "model" literal entries from the context body.
    Context context = (Context) funcDef.getExpression();
    String pmmlDocument = null;
    String pmmlModel = null;
    for (ContextEntry ce : context.getContextEntry()) {
        if (ce.getVariable() == null || ce.getVariable().getName() == null || !(ce.getExpression() instanceof LiteralExpression)) {
            continue;
        }
        String entryName = ce.getVariable().getName();
        String entryText = ((LiteralExpression) ce.getExpression()).getText();
        if (entryText == null) {
            continue;
        }
        if ("document".equals(entryName)) {
            pmmlDocument = stripQuotes(entryText.trim());
        } else if ("model".equals(entryName)) {
            pmmlModel = stripQuotes(entryText.trim());
        }
    }

    // Null-safe match: an import without a name must not blow up the lookup, and a missing
    // "document" entry never matches any import.
    final String nameLookup = pmmlDocument;
    Optional<Import> lookupImport = model.getDefinitions().getImport().stream()
            .filter(x -> nameLookup != null && nameLookup.equals(x.getName()))
            .findFirst();
    if (!lookupImport.isPresent()) {
        MsgUtil.reportMessage(logger, DMNMessage.Severity.ERROR, funcDef, model, null, null, Msg.FUNC_DEF_PMML_MISSING_ENTRY, functionName, node.getIdentifierString());
        return new DMNFunctionDefinitionEvaluator(node, funcDef);
    }

    Import theImport = lookupImport.get();
    logger.trace("theImport: {}", theImport);
    Resource pmmlResource = DMNCompilerImpl.resolveRelativeResource(getRootClassLoader(), model, theImport, funcDef, ctx.getRelativeResolver());
    logger.trace("pmmlResource: {}", pmmlResource);
    DMNImportPMMLInfo pmmlInfo = model.getPmmlImportInfo().get(pmmlDocument);
    logger.trace("pmmlInfo: {}", pmmlInfo);

    // The map lookup above may yield null; only consult the import info when it exists.
    // The (possibly null) info is handed to the factory either way, exactly as it already
    // was whenever a model name had been specified.
    if ((pmmlModel == null || pmmlModel.isEmpty()) && pmmlInfo != null) {
        List<String> pmmlModelNames = pmmlInfo.getModels().stream().map(PMMLModelInfo::getName).filter(x -> x != null).collect(Collectors.toList());
        if (!pmmlModelNames.isEmpty()) {
            MsgUtil.reportMessage(logger, DMNMessage.Severity.WARN, funcDef, model, null, null, Msg.FUNC_DEF_PMML_MISSING_MODEL_NAME, String.join(",", pmmlModelNames));
        }
    }

    AbstractPMMLInvocationEvaluator invoker = PMMLInvocationEvaluatorFactory.newInstance(model, getRootClassLoader(), funcDef, pmmlResource, pmmlModel, pmmlInfo);
    DMNFunctionDefinitionEvaluator func = new DMNFunctionDefinitionEvaluator(node, funcDef);
    // Register every formal parameter on both the function and the PMML invoker.
    for (InformationItem p : funcDef.getFormalParameter()) {
        DMNCompilerHelper.checkVariableName(model, p, p.getName());
        DMNType dmnType = compiler.resolveTypeRef(model, p, p, p.getTypeRef());
        func.addParameter(p.getName(), dmnType);
        invoker.addParameter(p.getName(), dmnType);
    }
    func.setEvaluator(invoker);
    return func;
}
Aggregations