Usage of org.kie.api.pmml.PMMLRequestData in project drools by kiegroup.
Class DecisionTreeTest, method testReturnNullNoTrueChildPredictionStrategy.
@Test
public void testReturnNullNoTrueChildPredictionStrategy() {
    // One executor over the tree fixture is shared by both requests below.
    final PMMLExecutor pmmlExecutor =
            new PMMLExecutor(PMMLKieBaseUtil.createKieBaseWithPMML(TREE_RETURN_NULL_NOTRUECHILD_STRATEGY));

    // Request 1: fld1 = 30.0 must produce a Fld2 value of "tgtY".
    final PMMLRequestData firstRequest = new PMMLRequestData("123", "TreeTest");
    firstRequest.addRequestParam("fld1", 30.0);
    final PMML4Result firstResult = pmmlExecutor.run(firstRequest);
    Assertions.assertThat(firstResult).isNotNull();
    final String fld2Value = firstResult.getResultValue("Fld2", "value", String.class).orElse(null);
    Assertions.assertThat(fld2Value).isEqualTo("tgtY");

    // Request 2: fld1 = 50.0 must produce no Fld2 value at all.
    // NOTE(review): presumably this is the "return null" no-true-child case; confirm against the fixture.
    final PMMLRequestData secondRequest = new PMMLRequestData("123", "TreeTest");
    secondRequest.addRequestParam("fld1", 50.0);
    final PMML4Result secondResult = pmmlExecutor.run(secondRequest);
    Assertions.assertThat(secondResult).isNotNull();
    Assertions.assertThat(secondResult.getResultValue("Fld2", "value", String.class)).isEmpty();
}
Usage of org.kie.api.pmml.PMMLRequestData in project drools by kiegroup.
Class DecisionTreeTest, method testSimpleTree.
@Test
public void testSimpleTree() throws Exception {
    RuleUnitExecutor executor = createExecutor(source1);

    // Input record for the "TreeTest" model.
    PMMLRequestData request = new PMMLRequestData("123", "TreeTest");
    request.addRequestParam("fld1", 30.0);
    request.addRequestParam("fld2", 60.0);
    request.addRequestParam("fld3", "false");
    request.addRequestParam("fld4", "optA");
    PMML4Result resultHolder = new PMML4Result();

    List<String> possiblePackages = calculatePossiblePackageNames("TreeTest");
    Class<? extends RuleUnit> unitClass = getStartingRuleUnit("RuleUnitIndicator", (InternalKnowledgeBase) kbase, possiblePackages);
    assertNotNull(unitClass);

    // First run happens before any request data is inserted; its int return value is not needed.
    // NOTE(review): presumably this initializes the model; confirm.
    executor.run(unitClass);

    // Second run evaluates the request and fills resultHolder.
    data.insert(request);
    resultData.insert(resultHolder);
    executor.run(unitClass);

    assertEquals("OK", resultHolder.getResultCode());
    assertNotNull(resultHolder.getResultValue("Fld5", null));
    String targetValue = resultHolder.getResultValue("Fld5", "value", String.class).orElse(null);
    assertEquals("tgtY", targetValue);
}
Usage of org.kie.api.pmml.PMMLRequestData in project drools by kiegroup.
Class DecisionTreeTest, method testNullPredictionMissingValueStrategy.
@Test
public void testNullPredictionMissingValueStrategy() {
    // One executor over the tree fixture is shared by both requests below.
    final PMMLExecutor pmmlExecutor =
            new PMMLExecutor(PMMLKieBaseUtil.createKieBaseWithPMML(TREE_RETURN_NULL_MISSING_STRATEGY));

    // Request 1: fld1 = 30.0 must yield "tgtY" for Fld3.
    final PMMLRequestData firstRequest = new PMMLRequestData("123", "TreeTest");
    firstRequest.addRequestParam("fld1", 30.0);
    final PMML4Result firstResult = pmmlExecutor.run(firstRequest);
    Assertions.assertThat(firstResult).isNotNull();
    final String firstFld3 = firstResult.getResultValue("Fld3", "value", String.class).orElse(null);
    Assertions.assertThat(firstFld3).isEqualTo("tgtY");

    // Request 2: fld1 = 100.0 must yield no Fld3 value (null after orElse).
    // NOTE(review): presumably the null-prediction missing-value strategy is what applies; confirm against the fixture.
    final PMMLRequestData secondRequest = new PMMLRequestData("123", "TreeTest");
    secondRequest.addRequestParam("fld1", 100.0);
    final PMML4Result secondResult = pmmlExecutor.run(secondRequest);
    Assertions.assertThat(secondResult).isNotNull();
    final String secondFld3 = secondResult.getResultValue("Fld3", "value", String.class).orElse(null);
    Assertions.assertThat(secondFld3).isNull();
}
Usage of org.kie.api.pmml.PMMLRequestData in project drools by kiegroup.
Class DecisionTreeTest, method testMissingTree.
@Test
public void testMissingTree() throws Exception {
    RuleUnitExecutor executor = createExecutor(source2);

    // Input record for the "Missing" model, built from explicit ParameterInfo entries.
    PMMLRequestData requestData = new PMMLRequestData("123", "Missing");
    requestData.addRequestParam(new ParameterInfo<>("123", "fld1", Double.class, 45.0));
    requestData.addRequestParam(new ParameterInfo<>("123", "fld2", Double.class, 60.0));
    requestData.addRequestParam(new ParameterInfo<>("123", "fld3", String.class, "optA"));
    PMML4Result resultHolder = new PMML4Result();

    List<String> possiblePackages = calculatePossiblePackageNames("Missing");
    Class<? extends RuleUnit> unitClass = getStartingRuleUnit("RuleUnitIndicator", (InternalKnowledgeBase) kbase, possiblePackages);
    assertNotNull(unitClass);

    // Initializes the model; the int returned by run() is not needed.
    executor.run(unitClass);

    // Second run evaluates the request and fills resultHolder.
    data.insert(requestData);
    resultData.insert(resultHolder);
    executor.run(unitClass);

    // The tree token must be present, with confidence 0.6 and "current" equal to the string "null".
    AbstractTreeToken missingTreeToken = resultHolder.getResultValue("MissingTreeToken", null, AbstractTreeToken.class).orElse(null);
    assertNotNull(missingTreeToken);
    Double tokVal = resultHolder.getResultValue("MissingTreeToken", "confidence", Double.class).orElse(null);
    assertNotNull(tokVal);
    assertEquals(0.6, tokVal, 0.0);
    String current = resultHolder.getResultValue("MissingTreeToken", "current", String.class).orElse(null);
    assertNotNull(current);
    assertEquals("null", current);

    // The predicted target field Fld9 must hold "tgtZ".
    Object fld9 = resultHolder.getResultValue("Fld9", null);
    assertNotNull(fld9);
    String fld9Val = resultHolder.getResultValue("Fld9", "value", String.class).orElse(null);
    assertNotNull(fld9Val);
    assertEquals("tgtZ", fld9Val);
}
Usage of org.kie.api.pmml.PMMLRequestData in project drools by kiegroup.
Class MiningmodelTest, method testWithRegression.
@Test
public void testWithRegression() {
    RuleUnitExecutor executor = createExecutor(source2);

    // Input record for the "SampleScorecardMine" mining model.
    PMMLRequestData request = new PMMLRequestData("123", "SampleScorecardMine");
    request.addRequestParam("fld1r", 1.0);
    request.addRequestParam("fld2r", 1.0);
    request.addRequestParam("fld3r", "x");
    PMML4Result resultHolder = new PMML4Result();
    resultHolder.setCorrelationId(request.getCorrelationId());

    // These data sources are created on the executor but never read back by this test:
    // childModelRequest (PMMLRequestData), childModelResults (PMML4Result),
    // miningModelPojo (AbstractPMMLData subtype).
    // NOTE(review): presumably the mining rule units need them registered; confirm before removing.
    executor.newDataSource("childModelRequest");
    executor.newDataSource("childModelResults");
    DataSource<SegmentExecution> childModelSegments = executor.newDataSource("childModelSegments");
    executor.newDataSource("miningModelPojo");

    List<String> possiblePackages = this.calculatePossiblePackageNames("SampleScorecardMine");
    Class<? extends RuleUnit> ruleUnitClass = this.getStartingRuleUnit("Start Mining - SampleScorecardMine", (InternalKnowledgeBase) kbase, possiblePackages);
    assertNotNull(ruleUnitClass);

    data.insert(request);
    resultData.insert(resultHolder);
    executor.run(ruleUnitClass);

    // Every result must share the request's correlation id and be OK.
    // Results without a segmentation id additionally carry the regression outputs.
    resultData.forEach(rd -> {
        assertEquals(request.getCorrelationId(), rd.getCorrelationId());
        assertEquals("OK", rd.getResultCode());
        if (rd.getSegmentationId() == null) {
            assertNotNull(rd.getResultValue("RegOut", null));
            String regOutValue = rd.getResultValue("RegOut", "value", String.class).orElse(null);
            assertEquals("catC", regOutValue);

            assertNotNull(rd.getResultValue("RegProb", null));
            Double regProbValue = rd.getResultValue("RegProb", "value", Double.class).orElse(null);
            // Guard before unboxing so a missing value fails as an assertion, not an NPE.
            assertNotNull(regProbValue);
            assertEquals(0.709228, regProbValue, 1e-6);

            assertNotNull(rd.getResultValue("RegProbA", null));
            Double regProbValueA = rd.getResultValue("RegProbA", "value", Double.class).orElse(null);
            assertNotNull(regProbValueA);
            assertEquals(0.010635, regProbValueA, 1e-6);
        }
    });

    // Exactly one segment is expected to reach the COMPLETE state for this request.
    int segmentsExecuted = 0;
    for (Iterator<SegmentExecution> iter = childModelSegments.iterator(); iter.hasNext(); ) {
        SegmentExecution cms = iter.next();
        assertEquals(request.getCorrelationId(), cms.getCorrelationId());
        if (cms.getState() == SegmentExecutionState.COMPLETE) {
            segmentsExecuted++;
        }
    }
    assertEquals(1, segmentsExecuted);
}
Aggregations