Search in sources :

Example 1 with DMNPMMLModelInfo

use of org.kie.dmn.core.pmml.DMNPMMLModelInfo in project drools by kiegroup.

In the class DMNRuntimePMMLTest, method testMultiOutputs:

/**
 * Evaluates the "KiePMMLRegressionClax" DMN model together with the
 * "test_regression_clax.pmml" resource and verifies two things:
 * the values produced under "my decision", and the PMML import metadata
 * (import name, model name, input and output field types) recorded on the model.
 */
@Test
public void testMultiOutputs() {
    final DMNRuntime dmnRuntime = DMNRuntimeUtil.createRuntimeWithAdditionalResources("KiePMMLRegressionClax.dmn", DMNRuntimePMMLTest.class, "test_regression_clax.pmml");
    final DMNModel model = dmnRuntime.getModel("http://www.trisotech.com/definitions/_ca466dbe-20b4-4e88-a43f-4ce3aff26e4f", "KiePMMLRegressionClax");
    assertThat(model, notNullValue());
    assertThat(DMNRuntimeUtil.formatMessages(model.getMessages()), model.hasErrors(), is(false));

    // Populate the three input fields and evaluate every decision of the model.
    final DMNContext inputs = DMNFactory.newContext();
    inputs.set("fld1", 1.0);
    inputs.set("fld2", 1.0);
    inputs.set("fld3", "x");
    final DMNResult evaluation = dmnRuntime.evaluateAll(model, inputs);
    LOG.debug("{}", evaluation);
    assertThat(DMNRuntimeUtil.formatMessages(evaluation.getMessages()), evaluation.hasErrors(), is(false));

    // The decision result is a map holding one entry per PMML output.
    final Map<String, Object> decision = (Map<String, Object>) evaluation.getContext().get("my decision");
    final String regOutValue = (String) decision.get("RegOut");
    assertEquals("catD", regOutValue);
    final BigDecimal regProbValue = (BigDecimal) decision.get("RegProb");
    assertEquals(0.8279559384018024, regProbValue.doubleValue(), COMPARISON_DELTA);
    final BigDecimal regProbAValue = (BigDecimal) decision.get("RegProbA");
    assertEquals(0.0022681396056233208, regProbAValue.doubleValue(), COMPARISON_DELTA);

    // Resolve the FEEL built-in types that the PMML field types are compared against.
    final DMNModelImpl modelImpl = (DMNModelImpl) model;
    final DMNType feelNumber = modelImpl.getTypeRegistry().resolveType(model.getDefinitions().getURIFEEL(), BuiltInType.NUMBER.getName());
    final DMNType feelString = modelImpl.getTypeRegistry().resolveType(model.getDefinitions().getURIFEEL(), BuiltInType.STRING.getName());

    // additional import info.
    final Map<String, DMNImportPMMLInfo> importsByName = modelImpl.getPmmlImportInfo();
    assertThat(importsByName.keySet(), hasSize(1));
    final DMNImportPMMLInfo pmmlImport = importsByName.values().iterator().next();
    assertThat(pmmlImport.getImportName(), is("test_regression_clax"));
    assertThat(pmmlImport.getModels(), hasSize(1));
    final DMNPMMLModelInfo pmmlModel = pmmlImport.getModels().iterator().next();
    assertThat(pmmlModel.getName(), is("LinReg"));

    final Map<String, DMNType> inputFields = pmmlModel.getInputFields();
    assertRegressionClaxSimpleType(inputFields.get("fld1"), BuiltInType.NUMBER, feelNumber);
    assertRegressionClaxSimpleType(inputFields.get("fld2"), BuiltInType.NUMBER, feelNumber);
    assertRegressionClaxSimpleType(inputFields.get("fld3"), BuiltInType.STRING, feelString);

    final Map<String, DMNType> outputFields = pmmlModel.getOutputFields();
    final CompositeTypeImpl linRegOutput = (CompositeTypeImpl) outputFields.get("LinReg");
    assertEquals("test_regression_clax", linRegOutput.getNamespace());
    final Map<String, DMNType> linRegFields = linRegOutput.getFields();
    assertRegressionClaxSimpleType(linRegFields.get("RegOut"), BuiltInType.STRING, feelString);
    assertRegressionClaxSimpleType(linRegFields.get("RegProb"), BuiltInType.NUMBER, feelNumber);
    assertRegressionClaxSimpleType(linRegFields.get("RegProbA"), BuiltInType.NUMBER, feelNumber);
}

/**
 * Casts {@code candidate} to {@link SimpleTypeImpl} and asserts, in this order, that its
 * namespace is "test_regression_clax", its FEEL type is {@code expectedFeelType} and its
 * base type is {@code expectedBaseType}.
 */
private static void assertRegressionClaxSimpleType(final DMNType candidate, final BuiltInType expectedFeelType, final DMNType expectedBaseType) {
    final SimpleTypeImpl simpleType = (SimpleTypeImpl) candidate;
    assertEquals("test_regression_clax", simpleType.getNamespace());
    assertEquals(expectedFeelType, simpleType.getFeelType());
    assertEquals(expectedBaseType, simpleType.getBaseType());
}
Also used : DMNResult(org.kie.dmn.api.core.DMNResult) DMNContext(org.kie.dmn.api.core.DMNContext) DMNPMMLModelInfo(org.kie.dmn.core.pmml.DMNPMMLModelInfo) DMNModelImpl(org.kie.dmn.core.impl.DMNModelImpl) DMNRuntime(org.kie.dmn.api.core.DMNRuntime) SimpleTypeImpl(org.kie.dmn.core.impl.SimpleTypeImpl) DMNImportPMMLInfo(org.kie.dmn.core.pmml.DMNImportPMMLInfo) Map(java.util.Map) DMNModel(org.kie.dmn.api.core.DMNModel) CompositeTypeImpl(org.kie.dmn.core.impl.CompositeTypeImpl) DMNType(org.kie.dmn.api.core.DMNType) Test(org.junit.Test)

Example 2 with DMNPMMLModelInfo

use of org.kie.dmn.core.pmml.DMNPMMLModelInfo in project drools by kiegroup.

In the class DMNRuntimePMMLTest, method runDMNModelInvokingPMML:

/**
 * Looks up the "KiePMMLScoreCard" DMN model in the given runtime, evaluates it with an
 * empty context, and verifies the value of "my decision" as well as the PMML import
 * metadata (import name, model name, input and output field names) recorded on the model.
 *
 * @param runtime the DMN runtime expected to contain the "KiePMMLScoreCard" model
 */
static void runDMNModelInvokingPMML(final DMNRuntime runtime) {
    final DMNModel model = runtime.getModel("http://www.trisotech.com/definitions/_ca466dbe-20b4-4e88-a43f-4ce3aff26e4f", "KiePMMLScoreCard");
    assertThat(model, notNullValue());
    assertThat(DMNRuntimeUtil.formatMessages(model.getMessages()), model.hasErrors(), is(false));

    // No inputs are supplied: the model is evaluated against a fresh, empty context.
    final DMNResult evaluation = runtime.evaluateAll(model, DMNFactory.newContext());
    LOG.debug("{}", evaluation);
    assertThat(DMNRuntimeUtil.formatMessages(evaluation.getMessages()), evaluation.hasErrors(), is(false));
    final Object decisionValue = evaluation.getContext().get("my decision");
    assertThat(decisionValue, is(new BigDecimal("41.345")));

    // additional import info.
    final Map<String, DMNImportPMMLInfo> importsByName = ((DMNModelImpl) model).getPmmlImportInfo();
    assertThat(importsByName.keySet(), hasSize(1));
    final DMNImportPMMLInfo pmmlImport = importsByName.values().iterator().next();
    assertThat(pmmlImport.getImportName(), is("iris"));
    assertThat(pmmlImport.getModels(), hasSize(1));
    final DMNPMMLModelInfo pmmlModel = pmmlImport.getModels().iterator().next();
    assertThat(pmmlModel.getName(), is("Sample Score"));

    // Names that must appear among the model's input fields.
    for (final String presentInputName : new String[] {"age", "occupation", "residenceState", "validLicense"}) {
        assertThat(pmmlModel.getInputFields(), hasEntry(is(presentInputName), anything()));
    }
    // Names that must NOT appear among the model's input fields.
    for (final String absentInputName : new String[] {"overallScore", "calculatedScore"}) {
        assertThat(pmmlModel.getInputFields(), not(hasEntry(is(absentInputName), anything())));
    }
    assertThat(pmmlModel.getOutputFields(), hasEntry(is("calculatedScore"), anything()));
}
Also used : DMNResult(org.kie.dmn.api.core.DMNResult) DMNContext(org.kie.dmn.api.core.DMNContext) DMNPMMLModelInfo(org.kie.dmn.core.pmml.DMNPMMLModelInfo) DMNImportPMMLInfo(org.kie.dmn.core.pmml.DMNImportPMMLInfo) DMNModelImpl(org.kie.dmn.core.impl.DMNModelImpl) DMNModel(org.kie.dmn.api.core.DMNModel) BigDecimal(java.math.BigDecimal)

Aggregations

DMNContext (org.kie.dmn.api.core.DMNContext)2 DMNModel (org.kie.dmn.api.core.DMNModel)2 DMNResult (org.kie.dmn.api.core.DMNResult)2 DMNModelImpl (org.kie.dmn.core.impl.DMNModelImpl)2 DMNImportPMMLInfo (org.kie.dmn.core.pmml.DMNImportPMMLInfo)2 DMNPMMLModelInfo (org.kie.dmn.core.pmml.DMNPMMLModelInfo)2 BigDecimal (java.math.BigDecimal)1 Map (java.util.Map)1 Test (org.junit.Test)1 DMNRuntime (org.kie.dmn.api.core.DMNRuntime)1 DMNType (org.kie.dmn.api.core.DMNType)1 CompositeTypeImpl (org.kie.dmn.core.impl.CompositeTypeImpl)1 SimpleTypeImpl (org.kie.dmn.core.impl.SimpleTypeImpl)1