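Usage of org.kie.pmml.pmml_4_2.model.ScoreCard in the drools project by kiegroup.
Example from class PMML4ModelFactory, method getModels (note: this method works with the PMML `Scorecard` definition type rather than `ScoreCard` itself).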
/**
 * Wraps every supported model definition found in the owner's raw PMML document.
 *
 * <p>Entries of type Scorecard, RegressionModel, TreeModel and MiningModel become a
 * ScorecardModel, Regression, Treemodel and Miningmodel respectively. Each wrapper is
 * constructed from the entry's model name, the entry itself, a {@code null} third argument
 * and {@code owner}. Entries of any other type are ignored.
 *
 * @param owner the PMML unit whose raw document is scanned; also handed to every wrapper
 * @return a new list of wrappers, in the order the raw document's model list yields them
 */
public List<PMML4Model> getModels(PMML4Unit owner) {
    List<PMML4Model> models = new ArrayList<>();
    for (Object entry : owner.getRawPMML().getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (entry instanceof Scorecard) {
            Scorecard scorecard = (Scorecard) entry;
            models.add(new ScorecardModel(scorecard.getModelName(), scorecard, null, owner));
        } else if (entry instanceof RegressionModel) {
            RegressionModel regression = (RegressionModel) entry;
            models.add(new Regression(regression.getModelName(), regression, null, owner));
        } else if (entry instanceof TreeModel) {
            TreeModel tree = (TreeModel) entry;
            models.add(new Treemodel(tree.getModelName(), tree, null, owner));
        } else if (entry instanceof MiningModel) {
            MiningModel mining = (MiningModel) entry;
            models.add(new Miningmodel(mining.getModelName(), mining, null, owner));
        }
        // any other model type: intentionally skipped
    }
    return models;
}
Usage of org.kie.pmml.pmml_4_2.model.ScoreCard in the drools project by kiegroup.
Example from class MiningmodelTest, method testSelectAll.
/**
 * Runs the "SampleSelectAllMine" mining model and verifies the reported scorecard and the
 * number of completed child segments.
 *
 * <p>The test accepts either of two outcomes, distinguished by whether the scorecard ranking
 * contains "LX00" or "LC00". Each branch checks the reason-code values, their iteration order
 * and the final score. Exactly two child segments must end in
 * {@code SegmentExecutionState.COMPLETE}.
 */
@Test
public void testSelectAll() {
    RuleUnitExecutor executor = createExecutor(source4);
    // KieRuntimeLogger console = ((InternalRuleUnitExecutor)executor).addConsoleLogger();
    PMMLRequestData request = new PMMLRequestData("1234", "SampleSelectAllMine");
    request.addRequestParam("age", 33.0);
    request.addRequestParam("occupation", "SKYDIVER");
    request.addRequestParam("residenceState", "KN");
    request.addRequestParam("validLicense", true);
    PMML4Result resultHolder = new PMML4Result();
    resultHolder.setCorrelationId(request.getCorrelationId());
    // Only childModelSegments is read below. The others are still created, presumably so the
    // generated rule units can bind to them by name — TODO confirm.
    DataSource<PMMLRequestData> childModelRequest = executor.newDataSource("childModelRequest");
    DataSource<PMML4Result> childModelResults = executor.newDataSource("childModelResults");
    DataSource<SegmentExecution> childModelSegments = executor.newDataSource("childModelSegments");
    DataSource<? extends AbstractPMMLData> miningModelPojo = executor.newDataSource("miningModelPojo");
    List<String> possiblePackages = this.calculatePossiblePackageNames("SampleSelectAllMine");
    Class<? extends RuleUnit> ruleUnitClass = this.getStartingRuleUnit("Start Mining - SampleSelectAllMine", (InternalKnowledgeBase) kbase, possiblePackages);
    assertNotNull(ruleUnitClass);
    data.insert(request);
    resultData.insert(resultHolder);
    executor.run(ruleUnitClass);
    // console.close();
    resultData.forEach(rd -> {
        assertEquals("OK", rd.getResultCode());
        assertEquals(request.getCorrelationId(), rd.getCorrelationId());
        ScoreCard sc = rd.getResultValue("ScoreCard", null, ScoreCard.class).orElse(null);
        assertNotNull(sc);
        Map<?, ?> map = sc.getRanking();
        assertNotNull(map);
        // The key-order checks below rely on the ranking being an insertion-ordered map.
        assertTrue(map instanceof LinkedHashMap);
        LinkedHashMap<?, ?> ranking = (LinkedHashMap<?, ?>) map;
        assertTrue(ranking.containsKey("LX00") || ranking.containsKey("LC00"));
        if (ranking.containsKey("LX00")) {
            assertTrue(ranking.containsKey("RES"));
            assertTrue(ranking.containsKey("CX2"));
            assertEquals(-1.0, ranking.get("LX00"));
            assertEquals(-10.0, ranking.get("RES"));
            assertEquals(-30.0, ranking.get("CX2"));
            Iterator<?> iter = ranking.keySet().iterator();
            assertEquals("LX00", iter.next());
            assertEquals("RES", iter.next());
            assertEquals("CX2", iter.next());
            assertEquals(41.345, sc.getScore(), 1e-6);
        } else {
            assertTrue(ranking.containsKey("RST"));
            assertTrue(ranking.containsKey("DX2"));
            assertEquals(-1.0, ranking.get("LC00"));
            assertEquals(10.0, ranking.get("RST"));
            assertEquals(-30.0, ranking.get("DX2"));
            Iterator<?> iter = ranking.keySet().iterator();
            assertEquals("RST", iter.next());
            assertEquals("LC00", iter.next());
            assertEquals("DX2", iter.next());
            assertEquals(21.345, sc.getScore(), 1e-6);
        }
    });
    int segmentsExecuted = 0;
    for (Iterator<SegmentExecution> iter = childModelSegments.iterator(); iter.hasNext(); ) {
        SegmentExecution cms = iter.next();
        assertEquals(request.getCorrelationId(), cms.getCorrelationId());
        if (cms.getState() == SegmentExecutionState.COMPLETE) {
            segmentsExecuted++;
        }
    }
    assertEquals(2, segmentsExecuted);
}
Usage of org.kie.pmml.pmml_4_2.model.ScoreCard in the drools project by kiegroup.
Example from class MiningmodelTest, method testWithScorecard.
/**
 * Runs the "SampleScorecardMine" mining model and verifies the scorecard ranking and the
 * number of completed child segments.
 *
 * <p>Every result must carry the request's correlation id and an "OK" result code. The ranking
 * check applies only to results whose segmentation id is {@code null} (presumably the top-level
 * result — verify). It expects the reason codes "LX00", "RES" and "CX2", with their values,
 * in that iteration order. Exactly one child segment must end in
 * {@code SegmentExecutionState.COMPLETE}.
 */
@Test
public void testWithScorecard() {
    RuleUnitExecutor executor = createExecutor(source2);
    // KieRuntimeLogger console = ((InternalRuleUnitExecutor)executor).addConsoleLogger();
    PMMLRequestData request = new PMMLRequestData("1234", "SampleScorecardMine");
    request.addRequestParam("age", 33.0);
    request.addRequestParam("occupation", "SKYDIVER");
    request.addRequestParam("residenceState", "KN");
    request.addRequestParam("validLicense", true);
    PMML4Result resultHolder = new PMML4Result();
    resultHolder.setCorrelationId(request.getCorrelationId());
    // Only childModelSegments is read below. The others are still created, presumably so the
    // generated rule units can bind to them by name — TODO confirm.
    DataSource<PMMLRequestData> childModelRequest = executor.newDataSource("childModelRequest");
    DataSource<PMML4Result> childModelResults = executor.newDataSource("childModelResults");
    DataSource<SegmentExecution> childModelSegments = executor.newDataSource("childModelSegments");
    DataSource<? extends AbstractPMMLData> miningModelPojo = executor.newDataSource("miningModelPojo");
    List<String> possiblePackages = this.calculatePossiblePackageNames("SampleScorecardMine");
    Class<? extends RuleUnit> ruleUnitClass = this.getStartingRuleUnit("Start Mining - SampleScorecardMine", (InternalKnowledgeBase) kbase, possiblePackages);
    assertNotNull(ruleUnitClass);
    data.insert(request);
    resultData.insert(resultHolder);
    executor.run(ruleUnitClass);
    // console.close();
    resultData.forEach(rd -> {
        assertEquals(request.getCorrelationId(), rd.getCorrelationId());
        assertEquals("OK", rd.getResultCode());
        if (rd.getSegmentationId() == null) {
            ScoreCard sc = rd.getResultValue("ScoreCard", null, ScoreCard.class).orElse(null);
            assertNotNull(sc);
            Map<?, ?> map = sc.getRanking();
            assertNotNull(map);
            // The key-order checks below rely on the ranking being an insertion-ordered map.
            assertTrue(map instanceof LinkedHashMap);
            LinkedHashMap<?, ?> ranking = (LinkedHashMap<?, ?>) map;
            assertTrue(ranking.containsKey("LX00"));
            assertTrue(ranking.containsKey("RES"));
            assertTrue(ranking.containsKey("CX2"));
            assertEquals(-1.0, ranking.get("LX00"));
            assertEquals(-10.0, ranking.get("RES"));
            assertEquals(-30.0, ranking.get("CX2"));
            Iterator<?> iter = ranking.keySet().iterator();
            assertEquals("LX00", iter.next());
            assertEquals("RES", iter.next());
            assertEquals("CX2", iter.next());
        }
    });
    int segmentsExecuted = 0;
    for (Iterator<SegmentExecution> iter = childModelSegments.iterator(); iter.hasNext(); ) {
        SegmentExecution cms = iter.next();
        assertEquals(request.getCorrelationId(), cms.getCorrelationId());
        if (cms.getState() == SegmentExecutionState.COMPLETE) {
            segmentsExecuted++;
        }
    }
    assertEquals(1, segmentsExecuted);
}
Usage of org.kie.pmml.pmml_4_2.model.ScoreCard in the drools project by kiegroup.
Example from class ScorecardTest, method testScorecardWithSimpleSetPredicateWithSpaceValue.
/**
 * Scores the "SimpleSetScorecardWithSpaceValue" model with {@code param = "optA"} and expects
 * a score of 13.
 */
@Test
public void testScorecardWithSimpleSetPredicateWithSpaceValue() {
    KieBase kieBase = PMMLKieBaseUtil.createKieBaseWithPMML(SOURCE_SIMPLE_SET_SPACE_VALUE_SCORECARD);
    PMMLExecutor executor = new PMMLExecutor(kieBase);
    PMMLRequestData requestData = new PMMLRequestData("123", "SimpleSetScorecardWithSpaceValue");
    requestData.addRequestParam("param", "optA");
    PMML4Result resultHolder = executor.run(requestData);
    Assertions.assertThat(resultHolder).isNotNull();
    // Fail with a descriptive assertion error rather than a bare NoSuchElementException
    // when the result carries no score.
    double score = resultHolder.getResultValue("ScoreCard", "score", Double.class)
            .orElseThrow(() -> new AssertionError("Result has no ScoreCard.score value"));
    Assertions.assertThat(score).isEqualTo(13);
}
Usage of org.kie.pmml.pmml_4_2.model.ScoreCard in the drools project by kiegroup.
Example from class ScorecardTest, method testScorecardWithSimpleSetPredicate.
/**
 * Scores the "SimpleSetScorecard" model for several (param1, param2) combinations and checks
 * each expected score.
 *
 * <p>Cases, in execution order (param1, param2 -> score): (4, "optA") -> 113,
 * (5, "optA") -> 33, (-5, "optC") -> 123, (-5, "optA") -> 113. All runs reuse one executor
 * and the correlation id "123".
 */
@Test
public void testScorecardWithSimpleSetPredicate() {
    KieBase kieBase = PMMLKieBaseUtil.createKieBaseWithPMML(SOURCE_SIMPLE_SET_SCORECARD);
    PMMLExecutor executor = new PMMLExecutor(kieBase);
    // Parallel arrays: index i describes one request and its expected score.
    int[] param1Values = {4, 5, -5, -5};
    String[] param2Values = {"optA", "optA", "optC", "optA"};
    double[] expectedScores = {113, 33, 123, 113};
    for (int i = 0; i < expectedScores.length; i++) {
        PMMLRequestData requestData = new PMMLRequestData("123", "SimpleSetScorecard");
        requestData.addRequestParam("param1", param1Values[i]);
        requestData.addRequestParam("param2", param2Values[i]);
        PMML4Result resultHolder = executor.run(requestData);
        Assertions.assertThat(resultHolder).isNotNull();
        final int caseIndex = i;
        // Fail with a descriptive assertion error rather than a bare NoSuchElementException.
        double score = resultHolder.getResultValue("ScoreCard", "score", Double.class)
                .orElseThrow(() -> new AssertionError("Result has no ScoreCard.score value for param1="
                        + param1Values[caseIndex] + ", param2=" + param2Values[caseIndex]));
        Assertions.assertThat(score)
                .as("score for param1=%s, param2=%s", param1Values[i], param2Values[i])
                .isEqualTo(expectedScores[i]);
    }
}
Aggregations