Use of org.kie.api.pmml.PMML4Result in project drools by kiegroup.
Class ScorecardTest, method testScorecard.
/**
 * Runs the "Sample Score" scorecard loaded from {@code source1} for a single request and
 * verifies the number of result variables, the computed score, and both the content and
 * the iteration order of the "ranking" map.
 *
 * @throws Exception if any helper call or the rule unit execution fails
 */
@Test
public void testScorecard() throws Exception {
    RuleUnitExecutor executor = createExecutor(source1);
    PMMLRequestData requestData = createRequest("123", "Sample Score", 33.0, "SKYDIVER", "KN", true);
    PMML4Result resultHolder = new PMML4Result();

    List<String> possiblePackages = calculatePossiblePackageNames("Sample Score", "org.drools.scorecards.example");
    Class<? extends RuleUnit> unitClass = getStartingRuleUnit("RuleUnitIndicator", (InternalKnowledgeBase) kbase, possiblePackages);
    assertNotNull(unitClass);

    // NOTE(review): the unit is run once before any data is inserted and again afterwards;
    // presumably the first run initialises the rule unit -- confirm against the executor contract.
    executor.run(unitClass);

    data.insert(requestData);
    resultData.insert(resultHolder);
    executor.run(unitClass);

    assertEquals(3, resultHolder.getResultVariables().size());
    Object scorecard = resultHolder.getResultValue("ScoreCard", null);
    assertNotNull(scorecard);

    Double score = resultHolder.getResultValue("ScoreCard", "score", Double.class).orElse(null);
    // Fail with a clear assertion instead of a NullPointerException from auto-unboxing below.
    assertNotNull(score);
    assertEquals(41.345, score, 0.000);

    Object ranking = resultHolder.getResultValue("ScoreCard", "ranking");
    assertNotNull(ranking);
    assertTrue(ranking instanceof LinkedHashMap);
    LinkedHashMap<?, ?> map = (LinkedHashMap<?, ?>) ranking;

    assertTrue(map.containsKey("LX00"));
    assertTrue(map.containsKey("RES"));
    assertTrue(map.containsKey("CX2"));
    assertEquals(-1.0, map.get("LX00"));
    assertEquals(-10.0, map.get("RES"));
    assertEquals(-30.0, map.get("CX2"));

    // The key iteration order of the ranking map is part of the expected result.
    Iterator<?> iter = map.keySet().iterator();
    assertEquals("LX00", iter.next());
    assertEquals("RES", iter.next());
    assertEquals("CX2", iter.next());
}
Use of org.kie.api.pmml.PMML4Result in project drools by kiegroup.
Class ScorecardTest, method testScorecardOutputs.
/**
 * Executes the "SampleScorecard" model loaded from {@code source2} with the inputs
 * cage/age/wage and verifies the result code, the number of result variables and the
 * "value" reported for each of the outputs OutRC1, OutRC2 and OutRC3.
 *
 * @throws Exception if any helper call or the rule unit execution fails
 */
@Test
public void testScorecardOutputs() throws Exception {
    RuleUnitExecutor ruleExecutor = createExecutor(source2);

    PMMLRequestData request = new PMMLRequestData("123", "SampleScorecard");
    request.addRequestParam("cage", "engineering");
    request.addRequestParam("age", 25);
    request.addRequestParam("wage", 500.0);
    PMML4Result outcome = new PMML4Result();

    List<String> packageCandidates = calculatePossiblePackageNames("SampleScorecard");
    Class<? extends RuleUnit> startingUnit = getStartingRuleUnit("RuleUnitIndicator", (InternalKnowledgeBase) kbase, packageCandidates);
    assertNotNull(startingUnit);

    // The unit is run once before and once after the request/result objects are inserted.
    ruleExecutor.run(startingUnit);
    data.insert(request);
    resultData.insert(outcome);
    ruleExecutor.run(startingUnit);

    assertEquals("OK", outcome.getResultCode());
    assertEquals(6, outcome.getResultVariables().size());

    // Output name paired with the value expected for it.
    String[][] expectedOutputs = {
        {"OutRC1", "RC2"},
        {"OutRC2", "RC1"},
        {"OutRC3", "RC1"}
    };
    for (String[] expected : expectedOutputs) {
        assertNotNull(outcome.getResultValue(expected[0], null));
    }
    for (String[] expected : expectedOutputs) {
        assertEquals(expected[1], outcome.getResultValue(expected[0], "value"));
    }
}
Use of org.kie.api.pmml.PMML4Result in project drools by kiegroup.
Class ScorecardTest, method testScorecardWithCompoundPredicate.
/**
 * Scores the "ScorecardCompoundPredicate" model for four different (param1, param2)
 * pairs and checks the resulting score of each; for the first pair the reason-code
 * ranking entries "reasonCh1" and "reasonCh2" are checked as well.
 */
@Test
public void testScorecardWithCompoundPredicate() {
    KieBase kieBase = PMMLKieBaseUtil.createKieBaseWithPMML(SOURCE_COMPOUND_PREDICATE_SCORECARD);
    PMMLExecutor executor = new PMMLExecutor(kieBase);

    PMML4Result firstResult = runCompoundPredicateScorecard(executor, 41.0, 21.0);
    Assertions.assertThat(extractScorecardScore(firstResult)).isEqualTo(120.8);
    // The result holder exposes the ranking as Object; the assertions below rely on it
    // being a map from reason code to Double.
    @SuppressWarnings("unchecked")
    Map<String, Double> rankingMap = (Map<String, Double>) firstResult.getResultValue("ScoreCard", "ranking");
    Assertions.assertThat(rankingMap.get("reasonCh1")).isEqualTo(50);
    Assertions.assertThat(rankingMap.get("reasonCh2")).isEqualTo(5);

    PMML4Result secondResult = runCompoundPredicateScorecard(executor, 40.0, 25.0);
    Assertions.assertThat(extractScorecardScore(secondResult)).isEqualTo(120.8);

    PMML4Result thirdResult = runCompoundPredicateScorecard(executor, 40.0, 55.0);
    Assertions.assertThat(extractScorecardScore(thirdResult)).isEqualTo(210.8);

    PMML4Result fourthResult = runCompoundPredicateScorecard(executor, 4.0, -25.0);
    Assertions.assertThat(extractScorecardScore(fourthResult)).isEqualTo(30.8);
}

/** Builds a fresh "ScorecardCompoundPredicate" request with the two inputs and runs it on the executor. */
private static PMML4Result runCompoundPredicateScorecard(PMMLExecutor executor, double param1, double param2) {
    PMMLRequestData requestData = new PMMLRequestData("123", "ScorecardCompoundPredicate");
    requestData.addRequestParam("param1", param1);
    requestData.addRequestParam("param2", param2);
    return executor.run(requestData);
}

/** Reads the "score" field of the "ScoreCard" result, failing with a descriptive error when it is absent. */
private static double extractScorecardScore(PMML4Result result) {
    return result.getResultValue("ScoreCard", "score", Double.class)
            .orElseThrow(() -> new AssertionError("ScoreCard result does not contain a 'score' value"));
}
Use of org.kie.api.pmml.PMML4Result in project drools by kiegroup.
Class ScorecardTest, method testMultipleInputData.
/**
 * Builds a knowledge base from {@code source1}, binds three separate executors to it and
 * feeds each executor its own request. Every executor's result is then handed to
 * {@code checkResult} together with the score and ranking expected for that request.
 *
 * @throws Exception if any helper call or the rule unit execution fails
 */
@Test
public void testMultipleInputData() throws Exception {
    final int runs = 3;
    RuleUnitExecutor[] executors = new RuleUnitExecutor[runs];
    PMMLRequestData[] inputs = new PMMLRequestData[runs];
    PMML4Result[] holders = new PMML4Result[runs];

    Resource pmmlResource = ResourceFactory.newClassPathResource(source1);
    kbase = new KieHelper().addResource(pmmlResource, ResourceType.PMML).build();
    for (int i = 0; i < runs; i++) {
        executors[i] = RuleUnitExecutor.create().bind(kbase);
    }

    DataSource<PMMLRequestData>[] requestSources = new DataSource[runs];
    DataSource<PMML4Result>[] resultSources = new DataSource[runs];
    DataSource<PMML4Data>[] pmmlDataSources = new DataSource[runs];

    Double[] expectedScores = {41.345, 26.345, 39.345};

    // Expected rankings; insertion order is preserved by LinkedHashMap.
    LinkedHashMap<String, Double> firstRanking = new LinkedHashMap<>();
    firstRanking.put("LX00", -1.0);
    firstRanking.put("RES", -10.0);
    firstRanking.put("CX2", -30.0);

    LinkedHashMap<String, Double> secondRanking = new LinkedHashMap<>();
    secondRanking.put("RES", 10.0);
    secondRanking.put("LX00", -1.0);
    secondRanking.put("OCC", -10.0);
    secondRanking.put("ABZ", -25.0);

    LinkedHashMap<String, Double> thirdRanking = new LinkedHashMap<>();
    thirdRanking.put("LX00", 1.0);
    thirdRanking.put("OCC", -5.0);
    thirdRanking.put("RES", -5.0);
    thirdRanking.put("CX1", -30.0);

    LinkedHashMap<String, Double>[] expectedRankings = new LinkedHashMap[runs];
    expectedRankings[0] = firstRanking;
    expectedRankings[1] = secondRanking;
    expectedRankings[2] = thirdRanking;

    inputs[0] = createRequest("123", "Sample Score", 33.0, "SKYDIVER", "KN", true);
    inputs[1] = createRequest("124", "Sample Score", 50.0, "TEACHER", "AP", true);
    inputs[2] = createRequest("125", "Sample Score", 10.0, "STUDENT", "TN", false);

    // Each executor gets its own three data sources and a result holder keyed by
    // the correlation id of its request.
    for (int i = 0; i < runs; i++) {
        RuleUnitExecutor current = executors[i];
        requestSources[i] = current.newDataSource("request");
        resultSources[i] = current.newDataSource("results");
        pmmlDataSources[i] = current.newDataSource("pmmlData");
        holders[i] = new PMML4Result(inputs[i].getCorrelationId());
    }

    List<String> packageCandidates = calculatePossiblePackageNames("Sample Score", "org.drools.scorecards.example");
    Class<? extends RuleUnit> startingUnit = getStartingRuleUnit("RuleUnitIndicator", (InternalKnowledgeBase) kbase, packageCandidates);
    assertNotNull(startingUnit);

    // Phase 1: run every executor before any data has been inserted.
    for (RuleUnitExecutor current : executors) {
        current.run(startingUnit);
    }
    // Phase 2: insert each request and its result holder.
    for (int i = 0; i < runs; i++) {
        requestSources[i].insert(inputs[i]);
        resultSources[i].insert(holders[i]);
    }
    // Phase 3: run every executor again, now with data present.
    for (RuleUnitExecutor current : executors) {
        current.run(startingUnit);
    }
    // Phase 4: verify each result against its expectations.
    for (int i = 0; i < runs; i++) {
        checkResult(holders[i], expectedScores[i], expectedRankings[i]);
    }
}
Use of org.kie.api.pmml.PMML4Result in project drools by kiegroup.
Class SimpleRegressionTest, method testRegression.
/**
 * Runs the "LinReg" model loaded from {@code source1} with fld1=0.9, fld2=0.3 and fld3="x",
 * then checks that the result code is "OK" and that the "value" of output "Fld4" matches
 * a locally computed expectation within 1e-6.
 *
 * @throws Exception if any helper call or the rule unit execution fails
 */
@Test
public void testRegression() throws Exception {
    RuleUnitExecutor executor = createExecutor(source1);
    PMMLRequestData request = new PMMLRequestData("123", "LinReg");
    request.addRequestParam("fld1", 0.9);
    request.addRequestParam("fld2", 0.3);
    request.addRequestParam("fld3", "x");
    PMML4Result resultHolder = new PMML4Result();

    List<String> possiblePackages = calculatePossiblePackageNames("LinReg");
    Class<? extends RuleUnit> unitClass = getStartingRuleUnit("RuleUnitIndicator", (InternalKnowledgeBase) kbase, possiblePackages);
    assertNotNull(unitClass);

    // First run happens before any data is inserted; its return value is not needed.
    executor.run(unitClass);
    data.insert(request);
    resultData.insert(resultHolder);
    executor.run(unitClass);

    assertEquals("OK", resultHolder.getResultCode());
    assertNotNull(resultHolder.getResultValue("Fld4", null));
    Double value = resultHolder.getResultValue("Fld4", "value", Double.class).orElse(null);
    assertNotNull(value);

    // Expected value: a polynomial in the numeric inputs (0.9 and 0.3) passed through the
    // logistic function 1 / (1 + e^-x).
    // NOTE(review): the coefficients presumably mirror the regression table in the PMML
    // source -- confirm against that file.
    double chkVal = 0.5 + 5 * 0.9 * 0.9 + 2 * 0.3 - 3.0 + 0.4 * 0.9 * 0.3;
    chkVal = 1.0 / (1.0 + Math.exp(-chkVal));
    assertEquals(chkVal, value, 1e-6);
}
Aggregations