Usage of `org.dmg.pmml.pmml_4_2.descr.Scorecard` in the kiegroup drools project.
Class `PMML4Compiler`, method `checkBuildingResources`:
/**
 * Prepares the rule templates needed by the models in {@code pmml} and picks the
 * classpath KieBase that should evaluate them.
 * <p>
 * Each supported model type has its template set prepared at most once, guarded by
 * the matching {@code *Loaded} flag. The KieBase name depends on how many models were
 * recognised:
 * <ul>
 *   <li>none: {@code "KiePMML-Base"}</li>
 *   <li>exactly one: the model-specific KieBase (e.g. {@code "KiePMML-Tree"})</li>
 *   <li>more than one: the combined {@code "KiePMML"} KieBase</li>
 * </ul>
 *
 * @param pmml the parsed PMML document
 * @return the KieBase from the classpath container with the chosen name
 * @throws IOException if preparing a template fails
 */
private static KieBase checkBuildingResources(PMML pmml) throws IOException {
    final KieServices services = KieServices.Factory.get();
    final KieContainer container = services.getKieClasspathContainer();
    if (registry == null) {
        initRegistry();
    }

    // Number of recognised models, and the specific KieBase name of the last one seen.
    // The specific name only matters when exactly one model was recognised.
    int matchedModels = 0;
    String specificKieBase = null;

    for (Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (model instanceof NaiveBayesModel) {
            if (!naiveBayesLoaded) {
                for (String template : NAIVE_BAYES_TEMPLATES) {
                    prepareTemplate(template);
                }
                naiveBayesLoaded = true;
            }
            matchedModels++;
            specificKieBase = "KiePMML-Bayes";
        }
        if (model instanceof NeuralNetwork) {
            if (!neuralLoaded) {
                for (String template : NEURAL_TEMPLATES) {
                    prepareTemplate(template);
                }
                neuralLoaded = true;
            }
            matchedModels++;
            specificKieBase = "KiePMML-Neural";
        }
        if (model instanceof ClusteringModel) {
            if (!clusteringLoaded) {
                for (String template : CLUSTERING_TEMPLATES) {
                    prepareTemplate(template);
                }
                clusteringLoaded = true;
            }
            matchedModels++;
            specificKieBase = "KiePMML-Cluster";
        }
        if (model instanceof SupportVectorMachineModel) {
            if (!svmLoaded) {
                for (String template : SVM_TEMPLATES) {
                    prepareTemplate(template);
                }
                svmLoaded = true;
            }
            matchedModels++;
            specificKieBase = "KiePMML-SVM";
        }
        if (model instanceof TreeModel) {
            if (!treeLoaded) {
                for (String template : TREE_TEMPLATES) {
                    prepareTemplate(template);
                }
                treeLoaded = true;
            }
            matchedModels++;
            specificKieBase = "KiePMML-Tree";
        }
        if (model instanceof RegressionModel) {
            if (!simpleRegLoaded) {
                for (String template : SIMPLEREG_TEMPLATES) {
                    prepareTemplate(template);
                }
                simpleRegLoaded = true;
            }
            matchedModels++;
            specificKieBase = "KiePMML-Regression";
        }
        if (model instanceof Scorecard) {
            if (!scorecardLoaded) {
                for (String template : SCORECARD_TEMPLATES) {
                    prepareTemplate(template);
                }
                scorecardLoaded = true;
            }
            matchedModels++;
            specificKieBase = "KiePMML-Scorecard";
        }
    }

    final String kieBaseName;
    if (matchedModels == 0) {
        kieBaseName = "KiePMML-Base";
    } else if (matchedModels == 1) {
        kieBaseName = specificKieBase;
    } else {
        kieBaseName = "KiePMML";
    }
    return container.getKieBase(kieBaseName);
}
Usage of `org.dmg.pmml.pmml_4_2.descr.Scorecard` in the kiegroup drools project.
Class `PMML4ModelFactory`, method `getModels`:
/**
 * Wraps every supported model found in the owner's raw PMML document.
 * <p>
 * Scorecards, regression models, tree models and mining models are each wrapped in
 * their model class, with a {@code null} parent and {@code owner} as the owning unit.
 * Any other element of the document is skipped.
 *
 * @param owner the PMML unit whose raw document is scanned
 * @return a new list of wrapped models, in document order
 */
public List<PMML4Model> getModels(PMML4Unit owner) {
    final List<PMML4Model> models = new ArrayList<>();
    for (Object element : owner.getRawPMML().getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        PMML4Model wrapped = null;
        if (element instanceof Scorecard) {
            final Scorecard scorecard = (Scorecard) element;
            wrapped = new ScorecardModel(scorecard.getModelName(), scorecard, null, owner);
        } else if (element instanceof RegressionModel) {
            final RegressionModel regression = (RegressionModel) element;
            wrapped = new Regression(regression.getModelName(), regression, null, owner);
        } else if (element instanceof TreeModel) {
            final TreeModel tree = (TreeModel) element;
            wrapped = new Treemodel(tree.getModelName(), tree, null, owner);
        } else if (element instanceof MiningModel) {
            final MiningModel mining = (MiningModel) element;
            wrapped = new Miningmodel(mining.getModelName(), mining, null, owner);
        }
        // Unsupported model types are ignored.
        if (wrapped != null) {
            models.add(wrapped);
        }
    }
    return models;
}
Usage of `org.dmg.pmml.pmml_4_2.descr.Scorecard` in the kiegroup drools project.
Class `GuidedScoreCardDRLPersistence`, method `createPMMLDocument`:
/**
 * Converts a guided score card model into a PMML document.
 * <p>
 * Model-level settings are stored as {@code Extension}s on the PMML scorecard:
 * <ul>
 *   <li>the external fact class</li>
 *   <li>the agenda group and ruleflow group, each only when non-empty</li>
 *   <li>the comma-terminated list of distinct imports</li>
 *   <li>the resultant score field</li>
 *   <li>the package name, or {@code null} when empty</li>
 * </ul>
 * Each guided characteristic becomes a PMML {@code Characteristic} plus an active
 * {@code MiningField}. Each of its attributes becomes a PMML {@code Attribute} carrying
 * a {@code predicateResolver} extension.
 *
 * @param model the guided score card to convert
 * @return the document produced by {@code ScorecardPMMLGenerator#generateDocument}
 */
private static PMML createPMMLDocument(final ScoreCardModel model) {
    final Scorecard pmmlScorecard = ScorecardPMMLUtils.createScorecard();
    final Output output = new Output();
    final Characteristics characteristics = new Characteristics();
    final MiningSchema miningSchema = new MiningSchema();

    Extension extension = new Extension();
    extension.setName(PMMLExtensionNames.EXTERNAL_CLASS);
    extension.setValue(model.getFactName());
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);

    String agendaGroup = model.getAgendaGroup();
    if (!StringUtils.isEmpty(agendaGroup)) {
        extension = new Extension();
        extension.setName(PMMLExtensionNames.AGENDA_GROUP);
        extension.setValue(agendaGroup);
        pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);
    }

    String ruleFlowGroup = model.getRuleFlowGroup();
    if (!StringUtils.isEmpty(ruleFlowGroup)) {
        extension = new Extension();
        extension.setName(PMMLExtensionNames.RULEFLOW_GROUP);
        // Bug fix: this previously wrote agendaGroup, which dropped the configured
        // ruleflow group and could store a null/empty agenda group instead.
        extension.setValue(ruleFlowGroup);
        pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);
    }

    // Distinct import types, each followed by a comma, in first-seen order.
    List<String> imports = new ArrayList<>();
    StringBuilder importBuilder = new StringBuilder();
    for (Import imp : model.getImports().getImports()) {
        if (!imports.contains(imp.getType())) {
            imports.add(imp.getType());
            importBuilder.append(imp.getType()).append(",");
        }
    }
    extension = new Extension();
    extension.setName(PMMLExtensionNames.MODEL_IMPORTS);
    extension.setValue(importBuilder.toString());
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);

    extension = new Extension();
    extension.setName(ScorecardPMMLExtensionNames.SCORECARD_RESULTANT_SCORE_FIELD);
    extension.setValue(model.getFieldName());
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);

    extension = new Extension();
    extension.setName(PMMLExtensionNames.MODEL_PACKAGE);
    String pkgName = model.getPackageName();
    extension.setValue(!(pkgName == null || pkgName.isEmpty()) ? pkgName : null);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);

    final String modelName = convertToJavaIdentifier(model.getName());
    pmmlScorecard.setModelName(modelName);
    pmmlScorecard.setInitialScore(model.getInitialScore());
    pmmlScorecard.setUseReasonCodes(model.isUseReasonCodes());
    if (model.isUseReasonCodes()) {
        pmmlScorecard.setBaselineScore(model.getBaselineScore());
        pmmlScorecard.setReasonCodeAlgorithm(model.getReasonCodesAlgorithm());
    }

    for (final org.drools.workbench.models.guided.scorecard.shared.Characteristic characteristic : model.getCharacteristics()) {
        final Characteristic _characteristic = new Characteristic();
        characteristics.getCharacteristics().add(_characteristic);

        extension = new Extension();
        extension.setName(PMMLExtensionNames.EXTERNAL_CLASS);
        extension.setValue(characteristic.getFact());
        _characteristic.getExtensions().add(extension);

        // Map the guided data type onto the XLS keyword vocabulary.
        extension = new Extension();
        extension.setName(ScorecardPMMLExtensionNames.CHARACTERTISTIC_DATATYPE);
        if ("string".equalsIgnoreCase(characteristic.getDataType())) {
            extension.setValue(XLSKeywords.DATATYPE_TEXT);
        } else if ("int".equalsIgnoreCase(characteristic.getDataType()) || "double".equalsIgnoreCase(characteristic.getDataType())) {
            extension.setValue(XLSKeywords.DATATYPE_NUMBER);
        } else if ("boolean".equalsIgnoreCase(characteristic.getDataType())) {
            extension.setValue(XLSKeywords.DATATYPE_BOOLEAN);
        } else {
            // NOTE(review): the extension is still added with no value set; consider a
            // logger instead of stdout.
            System.out.println(">>>> Found unknown data type :: " + characteristic.getDataType());
        }
        _characteristic.getExtensions().add(extension);

        _characteristic.setBaselineScore(characteristic.getBaselineScore());
        if (model.isUseReasonCodes()) {
            _characteristic.setReasonCode(characteristic.getReasonCode());
        }
        _characteristic.setName(characteristic.getName());

        final MiningField miningField = new MiningField();
        miningField.setName(characteristic.getField());
        miningField.setUsageType(FIELDUSAGETYPE.ACTIVE);
        miningField.setInvalidValueTreatment(INVALIDVALUETREATMENTMETHOD.RETURN_INVALID);
        miningSchema.getMiningFields().add(miningField);
        extension = new Extension();
        extension.setName(PMMLExtensionNames.EXTERNAL_CLASS);
        extension.setValue(characteristic.getFact());
        miningField.getExtensions().add(extension);

        for (final org.drools.workbench.models.guided.scorecard.shared.Attribute attribute : characteristic.getAttributes()) {
            final Attribute _attribute = new Attribute();
            _characteristic.getAttributes().add(_attribute);

            extension = new Extension();
            extension.setName(ScorecardPMMLExtensionNames.CHARACTERTISTIC_FIELD);
            extension.setValue(characteristic.getField());
            _attribute.getExtensions().add(extension);

            if (model.isUseReasonCodes()) {
                _attribute.setReasonCode(attribute.getReasonCode());
            }
            _attribute.setPartialScore(attribute.getPartialScore());

            // Encode the attribute's condition as a textual predicate:
            // - boolean: the upper-cased operator
            // - string: operator+value for "="-style operators, otherwise "value,"
            // - numeric: "op value" for NUMERIC_OPERATORS, otherwise the value with
            //   commas replaced by dashes (a range)
            final String operator = attribute.getOperator();
            final String dataType = characteristic.getDataType();
            String predicateResolver;
            if ("boolean".equalsIgnoreCase(dataType)) {
                predicateResolver = operator.toUpperCase();
            } else if ("String".equalsIgnoreCase(dataType)) {
                if (operator.contains("=")) {
                    predicateResolver = operator + attribute.getValue();
                } else {
                    predicateResolver = attribute.getValue() + ",";
                }
            } else {
                if (NUMERIC_OPERATORS.contains(operator)) {
                    predicateResolver = operator + " " + attribute.getValue();
                } else {
                    predicateResolver = attribute.getValue().replace(",", "-");
                }
            }
            extension = new Extension();
            extension.setName("predicateResolver");
            extension.setValue(predicateResolver);
            _attribute.getExtensions().add(extension);
        }
    }

    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(miningSchema);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(output);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(characteristics);
    return new ScorecardPMMLGenerator().generateDocument(pmmlScorecard);
}
Usage of `org.dmg.pmml.pmml_4_2.descr.Scorecard` in the kiegroup drools project.
Class `ExternalObjectModelTest`, method `testWithReasonCodes`:
/**
 * Compiles the "scorecards_reasoncode" worksheet of scoremodel_externalmodel.xls
 * against an external object model.
 * <p>
 * The generated DRL is built into a fresh KieContainer. Three Applicant scenarios are
 * then run in separate sessions, checking the total score (and, for the first scenario,
 * the reason codes and the internal ScoreCard fact).
 */
@Test
public void testWithReasonCodes() throws Exception {
ScorecardCompiler scorecardCompiler2 = new ScorecardCompiler(EXTERNAL_OBJECT_MODEL);
PMML pmmlDocument2 = null;
String drl2 = null;
if (scorecardCompiler2.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_externalmodel.xls"), "scorecards_reasoncode")) {
pmmlDocument2 = scorecardCompiler2.getPMMLDocument();
// Dumps the generated PMML to stdout for inspection.
PMML4Compiler.dumpModel(pmmlDocument2, System.out);
assertNotNull(pmmlDocument2);
drl2 = scorecardCompiler2.getDRL();
// System.out.println(drl2);
} else {
// Print every parse error before failing so the cause is visible in the log.
for (ScorecardError error : scorecardCompiler2.getScorecardParseErrors()) {
System.out.println(error.getErrorLocation() + ":" + error.getErrorMessage());
}
fail("failed to parse scoremodel Excel (scorecards_reasoncode).");
}
// NOTE(review): duplicates the assertNotNull inside the success branch above.
assertNotNull(pmmlDocument2);
assertTrue(drl2 != null && !drl2.isEmpty());
// Build the generated DRL into an in-memory KIE module.
KieServices ks = KieServices.Factory.get();
KieFileSystem kfs = ks.newKieFileSystem();
kfs.write(ks.getResources().newByteArrayResource(drl2.getBytes()).setSourcePath("test_scorecard_rules.drl").setResourceType(ResourceType.DRL));
KieBuilder kieBuilder = ks.newKieBuilder(kfs);
// NOTE(review): build results are never inspected; a DRL compile error would only
// surface later as a missing fact type or a wrong score.
Results res = kieBuilder.buildAll().getResults();
KieContainer kieContainer = ks.newKieContainer(kieBuilder.getKieModule().getReleaseId());
KieBase kbase = kieContainer.getKieBase();
KieSession session = kbase.newKieSession();
// Internal declared type holding the computed score and the reason-code ranking.
FactType scorecardInternalsType = kbase.getFactType(PMML4Helper.pmmlDefaultPackageName(), "ScoreCard");
// Scenario 1: age only.
Applicant applicant = new Applicant();
applicant.setAge(10);
session.insert(applicant);
// session.addEventListener(new DebugWorkingMemoryEventListener());
session.fireAllRules();
// occupation = 0, age = 30, validLicence -1, initialScore=100
assertEquals(129.0, applicant.getTotalScore(), 0.0);
assertEquals("VL0099", applicant.getReasonCodes());
Object scorecardInternals = session.getObjects(new ClassObjectFilter(scorecardInternalsType.getFactClass())).iterator().next();
Assert.assertEquals(129.0, scorecardInternalsType.get(scorecardInternals, "score"));
Map reasonCodesMap = (Map) scorecardInternalsType.get(scorecardInternals, "ranking");
Assert.assertNotNull(reasonCodesMap);
// NOTE(review): assumes the "ranking" map iterates keys in ranked order; confirm the
// generated type uses an ordered map.
Assert.assertEquals(Arrays.asList("VL0099", "AGE02"), new ArrayList(reasonCodesMap.keySet()));
session.dispose();
// Scenario 2: occupation and age, in a fresh session.
session = kbase.newKieSession();
applicant = new Applicant();
applicant.setOccupation("SKYDIVER");
applicant.setAge(0);
session.insert(applicant);
session.fireAllRules();
session.dispose();
// occupation = -10, age = +10, validLicense = -1, initialScore=100;
assertEquals(99.0, applicant.getTotalScore(), 0.0);
// Scenario 3: all characteristics set, in a fresh session.
session = kbase.newKieSession();
applicant = new Applicant();
applicant.setResidenceState("AP");
applicant.setOccupation("TEACHER");
applicant.setAge(20);
applicant.setValidLicense(true);
session.insert(applicant);
session.fireAllRules();
session.dispose();
// occupation = +10, age = +40, state = -10, validLicense = 1, initialScore=100
assertEquals(141.0, applicant.getTotalScore(), 0.0);
}
Usage of `org.dmg.pmml.pmml_4_2.descr.Scorecard` in the kiegroup drools project.
Class `ScorecardReasonCodeTest`, method `testReasonCodes`:
/**
 * Compiles scoremodel_reasoncodes.xls and checks the first scorecard's characteristics.
 * <p>
 * The first {@code Characteristics} element found in a scorecard must contain exactly
 * four characteristics, and every attribute in them must carry a reason code. Fails if
 * no scorecard has a {@code Characteristics} element.
 */
@Test
public void testReasonCodes() throws Exception {
    final ScorecardCompiler compiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    if (!compiler.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"))) {
        assertErrors(compiler);
    }
    final Characteristics found = findFirstScorecardCharacteristics(compiler.getPMMLDocument());
    if (found == null) {
        fail();
    } else {
        assertEquals(4, found.getCharacteristics().size());
        for (Characteristic current : found.getCharacteristics()) {
            for (Attribute attr : current.getAttributes()) {
                assertNotNull(attr.getReasonCode());
            }
        }
    }
}

/**
 * Returns the first {@code Characteristics} element of any scorecard in {@code pmml},
 * scanning models and their parts in document order, or {@code null} if there is none.
 */
private static Characteristics findFirstScorecardCharacteristics(final PMML pmml) {
    for (Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (!(model instanceof Scorecard)) {
            continue;
        }
        for (Object part : ((Scorecard) model).getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (part instanceof Characteristics) {
                return (Characteristics) part;
            }
        }
    }
    return null;
}
Aggregations