Usage of `org.kie.dmn.core.pmml.DMNPMMLModelInfo` in the kiegroup/drools project:
class `DMNRuntimePMMLTest`, method `testMultiOutputs`.
/**
 * Evaluates the {@code KiePMMLRegressionClax} DMN model, which imports {@code test_regression_clax.pmml},
 * and checks two things: the values of the multi-output decision {@code "my decision"}, and the PMML
 * import metadata (model name, input field types, output field types) recorded on the {@link DMNModelImpl}.
 */
@Test
public void testMultiOutputs() {
    final DMNRuntime runtime = DMNRuntimeUtil.createRuntimeWithAdditionalResources("KiePMMLRegressionClax.dmn", DMNRuntimePMMLTest.class, "test_regression_clax.pmml");
    final DMNModel dmnModel = runtime.getModel("http://www.trisotech.com/definitions/_ca466dbe-20b4-4e88-a43f-4ce3aff26e4f", "KiePMMLRegressionClax");
    assertThat(dmnModel, notNullValue());
    assertThat(DMNRuntimeUtil.formatMessages(dmnModel.getMessages()), dmnModel.hasErrors(), is(false));

    final DMNContext dmnContext = DMNFactory.newContext();
    dmnContext.set("fld1", 1.0);
    dmnContext.set("fld2", 1.0);
    dmnContext.set("fld3", "x");
    final DMNResult dmnResult = runtime.evaluateAll(dmnModel, dmnContext);
    LOG.debug("{}", dmnResult);
    assertThat(DMNRuntimeUtil.formatMessages(dmnResult.getMessages()), dmnResult.hasErrors(), is(false));

    final DMNContext resultContext = dmnResult.getContext();
    // The decision output is read as a map from output-field name to value; the cast cannot be checked at runtime.
    @SuppressWarnings("unchecked")
    final Map<String, Object> result = (Map<String, Object>) resultContext.get("my decision");
    assertThat("decision 'my decision' produced no result", result, notNullValue());
    assertEquals("catD", result.get("RegOut"));
    assertEquals(0.8279559384018024, ((BigDecimal) result.get("RegProb")).doubleValue(), COMPARISON_DELTA);
    assertEquals(0.0022681396056233208, ((BigDecimal) result.get("RegProbA")).doubleValue(), COMPARISON_DELTA);

    // Resolve the built-in FEEL types once; every PMML field below is expected to derive from one of them.
    final DMNModelImpl dmnModelImpl = (DMNModelImpl) dmnModel;
    final String feelUri = dmnModel.getDefinitions().getURIFEEL();
    final DMNType dmnFEELNumber = dmnModelImpl.getTypeRegistry().resolveType(feelUri, BuiltInType.NUMBER.getName());
    final DMNType dmnFEELString = dmnModelImpl.getTypeRegistry().resolveType(feelUri, BuiltInType.STRING.getName());

    // additional import info.
    final String importName = "test_regression_clax";
    final Map<String, DMNImportPMMLInfo> pmmlImportInfo = dmnModelImpl.getPmmlImportInfo();
    assertThat(pmmlImportInfo.keySet(), hasSize(1));
    final DMNImportPMMLInfo p0 = pmmlImportInfo.values().iterator().next();
    assertThat(p0.getImportName(), is(importName));
    assertThat(p0.getModels(), hasSize(1));
    final DMNPMMLModelInfo m0 = p0.getModels().iterator().next();
    assertThat(m0.getName(), is("LinReg"));

    final Map<String, DMNType> inputFields = m0.getInputFields();
    assertPmmlSimpleField(inputFields, "fld1", importName, BuiltInType.NUMBER, dmnFEELNumber);
    assertPmmlSimpleField(inputFields, "fld2", importName, BuiltInType.NUMBER, dmnFEELNumber);
    assertPmmlSimpleField(inputFields, "fld3", importName, BuiltInType.STRING, dmnFEELString);

    // The model's outputs are grouped under a single composite type named after the PMML model.
    final DMNType outputType = m0.getOutputFields().get("LinReg");
    assertThat("missing PMML output type 'LinReg'", outputType, notNullValue());
    final CompositeTypeImpl output = (CompositeTypeImpl) outputType;
    assertEquals(importName, output.getNamespace());
    final Map<String, DMNType> fields = output.getFields();
    assertPmmlSimpleField(fields, "RegOut", importName, BuiltInType.STRING, dmnFEELString);
    assertPmmlSimpleField(fields, "RegProb", importName, BuiltInType.NUMBER, dmnFEELNumber);
    assertPmmlSimpleField(fields, "RegProbA", importName, BuiltInType.NUMBER, dmnFEELNumber);
}

/**
 * Asserts that {@code fields} maps {@code fieldName} to a {@link SimpleTypeImpl} with the expected
 * namespace, FEEL type and base type. A missing field fails with a descriptive assertion message
 * rather than a {@link NullPointerException}.
 *
 * @param fields            field-name to type map taken from the PMML model info
 * @param fieldName         the field to look up
 * @param expectedNamespace expected {@code getNamespace()} of the field type
 * @param expectedFeelType  expected {@code getFeelType()} of the field type
 * @param expectedBaseType  expected {@code getBaseType()} of the field type
 */
private static void assertPmmlSimpleField(final Map<String, DMNType> fields, final String fieldName, final String expectedNamespace, final BuiltInType expectedFeelType, final DMNType expectedBaseType) {
    final DMNType type = fields.get(fieldName);
    assertThat("missing PMML field '" + fieldName + "'", type, notNullValue());
    final SimpleTypeImpl simpleType = (SimpleTypeImpl) type;
    assertEquals(expectedNamespace, simpleType.getNamespace());
    assertEquals(expectedFeelType, simpleType.getFeelType());
    assertEquals(expectedBaseType, simpleType.getBaseType());
}
Usage of `org.kie.dmn.core.pmml.DMNPMMLModelInfo` in the kiegroup/drools project:
class `DMNRuntimePMMLTest`, method `runDMNModelInvokingPMML`.
/**
 * Evaluates the {@code KiePMMLScoreCard} DMN model from {@code runtime} with an empty input context.
 * It checks that the decision {@code "my decision"} yields {@code 41.345}, then inspects the PMML
 * import metadata recorded on the model.
 *
 * @param runtime a DMN runtime that is expected to already contain the {@code KiePMMLScoreCard} model
 *                together with its {@code iris} PMML import
 */
static void runDMNModelInvokingPMML(final DMNRuntime runtime) {
    final DMNModel scoreCardModel = runtime.getModel("http://www.trisotech.com/definitions/_ca466dbe-20b4-4e88-a43f-4ce3aff26e4f", "KiePMMLScoreCard");
    assertThat(scoreCardModel, notNullValue());
    assertThat(DMNRuntimeUtil.formatMessages(scoreCardModel.getMessages()), scoreCardModel.hasErrors(), is(false));

    // No inputs are supplied: the model is evaluated against a fresh, empty context.
    final DMNResult evaluation = runtime.evaluateAll(scoreCardModel, DMNFactory.newContext());
    LOG.debug("{}", evaluation);
    assertThat(DMNRuntimeUtil.formatMessages(evaluation.getMessages()), evaluation.hasErrors(), is(false));
    assertThat(evaluation.getContext().get("my decision"), is(new BigDecimal("41.345")));

    // additional import info.
    final Map<String, DMNImportPMMLInfo> importsByName = ((DMNModelImpl) scoreCardModel).getPmmlImportInfo();
    assertThat(importsByName.keySet(), hasSize(1));
    final DMNImportPMMLInfo irisImport = importsByName.values().iterator().next();
    assertThat(irisImport.getImportName(), is("iris"));
    assertThat(irisImport.getModels(), hasSize(1));
    final DMNPMMLModelInfo sampleScore = irisImport.getModels().iterator().next();
    assertThat(sampleScore.getName(), is("Sample Score"));

    final Map<String, DMNType> declaredInputs = sampleScore.getInputFields();
    for (final String inputName : new String[]{"age", "occupation", "residenceState", "validLicense"}) {
        assertThat(declaredInputs, hasEntry(is(inputName), anything()));
    }
    // These names must be outputs only; they must not also appear among the inputs.
    for (final String outputOnlyName : new String[]{"overallScore", "calculatedScore"}) {
        assertThat(declaredInputs, not(hasEntry(is(outputOnlyName), anything())));
    }
    assertThat(sampleScore.getOutputFields(), hasEntry(is("calculatedScore"), anything()));
}
Aggregations