Search in sources :

Example 6 with Scorecard

Use of org.kie.dmg.pmml.pmml_4_2.descr.Scorecard in project drools by kiegroup.

the class PMML4Compiler method checkBuildingResources.

/**
 * Prepares the rule templates needed by the models contained in a PMML document and
 * resolves the KieBase that should be used to build it.
 *
 * <p>Each supported model type has its own static "loaded" flag; the templates of a type
 * are prepared only while that flag is still {@code false}. The KieBase name is the
 * type-specific one when exactly one supported model was encountered, the generic
 * {@code "KiePMML"} as soon as a second supported model is seen, and
 * {@code "KiePMML-Base"} when no supported model is present.
 *
 * @param pmml the parsed PMML document whose models are inspected
 * @return the KieBase looked up by name from the classpath container
 * @throws IOException declared by the signature; propagated from the helpers if they raise it
 */
private static KieBase checkBuildingResources(PMML pmml) throws IOException {
    KieServices kieServices = KieServices.Factory.get();
    KieContainer classpathContainer = kieServices.getKieClasspathContainer();
    if (registry == null) {
        initRegistry();
    }

    String kieBaseName = null;
    for (Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        // The checks are deliberately independent (no else-if), mirroring the order in
        // which the model types have always been tested.
        if (model instanceof NaiveBayesModel) {
            if (!naiveBayesLoaded) {
                for (String template : NAIVE_BAYES_TEMPLATES) {
                    prepareTemplate(template);
                }
                naiveBayesLoaded = true;
            }
            if (kieBaseName == null) {
                kieBaseName = "KiePMML-Bayes";
            } else {
                kieBaseName = "KiePMML";
            }
        }
        if (model instanceof NeuralNetwork) {
            if (!neuralLoaded) {
                for (String template : NEURAL_TEMPLATES) {
                    prepareTemplate(template);
                }
                neuralLoaded = true;
            }
            if (kieBaseName == null) {
                kieBaseName = "KiePMML-Neural";
            } else {
                kieBaseName = "KiePMML";
            }
        }
        if (model instanceof ClusteringModel) {
            if (!clusteringLoaded) {
                for (String template : CLUSTERING_TEMPLATES) {
                    prepareTemplate(template);
                }
                clusteringLoaded = true;
            }
            if (kieBaseName == null) {
                kieBaseName = "KiePMML-Cluster";
            } else {
                kieBaseName = "KiePMML";
            }
        }
        if (model instanceof SupportVectorMachineModel) {
            if (!svmLoaded) {
                for (String template : SVM_TEMPLATES) {
                    prepareTemplate(template);
                }
                svmLoaded = true;
            }
            if (kieBaseName == null) {
                kieBaseName = "KiePMML-SVM";
            } else {
                kieBaseName = "KiePMML";
            }
        }
        if (model instanceof TreeModel) {
            if (!treeLoaded) {
                for (String template : TREE_TEMPLATES) {
                    prepareTemplate(template);
                }
                treeLoaded = true;
            }
            if (kieBaseName == null) {
                kieBaseName = "KiePMML-Tree";
            } else {
                kieBaseName = "KiePMML";
            }
        }
        if (model instanceof RegressionModel) {
            if (!simpleRegLoaded) {
                for (String template : SIMPLEREG_TEMPLATES) {
                    prepareTemplate(template);
                }
                simpleRegLoaded = true;
            }
            if (kieBaseName == null) {
                kieBaseName = "KiePMML-Regression";
            } else {
                kieBaseName = "KiePMML";
            }
        }
        if (model instanceof Scorecard) {
            if (!scorecardLoaded) {
                for (String template : SCORECARD_TEMPLATES) {
                    prepareTemplate(template);
                }
                scorecardLoaded = true;
            }
            if (kieBaseName == null) {
                kieBaseName = "KiePMML-Scorecard";
            } else {
                kieBaseName = "KiePMML";
            }
        }
    }

    // No supported model type was found: fall back to the base KieBase.
    if (kieBaseName == null) {
        kieBaseName = "KiePMML-Base";
    }
    return classpathContainer.getKieBase(kieBaseName);
}
Also used : TreeModel(org.kie.dmg.pmml.pmml_4_2.descr.TreeModel) KieServices(org.kie.api.KieServices) NaiveBayesModel(org.kie.dmg.pmml.pmml_4_2.descr.NaiveBayesModel) NeuralNetwork(org.kie.dmg.pmml.pmml_4_2.descr.NeuralNetwork) SupportVectorMachineModel(org.kie.dmg.pmml.pmml_4_2.descr.SupportVectorMachineModel) Scorecard(org.kie.dmg.pmml.pmml_4_2.descr.Scorecard) KieContainer(org.kie.api.runtime.KieContainer) ClusteringModel(org.kie.dmg.pmml.pmml_4_2.descr.ClusteringModel) RegressionModel(org.kie.dmg.pmml.pmml_4_2.descr.RegressionModel)

Example 7 with Scorecard

Use of org.kie.dmg.pmml.pmml_4_2.descr.Scorecard in project drools by kiegroup.

the class PMML4ModelFactory method getModels.

/**
 * Wraps every supported raw model of the owner's PMML document in its
 * {@link PMML4Model} counterpart.
 *
 * <p>Scorecard, regression, tree and mining models are converted, in the order in which
 * they appear in the document; entries of any other type are skipped.
 *
 * @param owner the PMML unit that supplies the raw document and owns the created models
 * @return a new list with one wrapper per supported model, possibly empty
 */
public List<PMML4Model> getModels(PMML4Unit owner) {
    List<PMML4Model> wrapped = new ArrayList<>();
    for (Object candidate : owner.getRawPMML().getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (candidate instanceof Scorecard) {
            Scorecard scorecard = (Scorecard) candidate;
            wrapped.add(new ScorecardModel(scorecard.getModelName(), scorecard, null, owner));
            continue;
        }
        if (candidate instanceof RegressionModel) {
            RegressionModel regression = (RegressionModel) candidate;
            wrapped.add(new Regression(regression.getModelName(), regression, null, owner));
            continue;
        }
        if (candidate instanceof TreeModel) {
            TreeModel tree = (TreeModel) candidate;
            wrapped.add(new Treemodel(tree.getModelName(), tree, null, owner));
            continue;
        }
        if (candidate instanceof MiningModel) {
            MiningModel mining = (MiningModel) candidate;
            wrapped.add(new Miningmodel(mining.getModelName(), mining, null, owner));
        }
    }
    return wrapped;
}
Also used : TreeModel(org.kie.dmg.pmml.pmml_4_2.descr.TreeModel) MiningModel(org.kie.dmg.pmml.pmml_4_2.descr.MiningModel) PMML4Model(org.kie.pmml.pmml_4_2.PMML4Model) ArrayList(java.util.ArrayList) Scorecard(org.kie.dmg.pmml.pmml_4_2.descr.Scorecard) RegressionModel(org.kie.dmg.pmml.pmml_4_2.descr.RegressionModel)

Example 8 with Scorecard

use of org.dmg.pmml.pmml_4_2.descr.Scorecard in project drools by kiegroup.

the class GuidedScoreCardDRLPersistence method createPMMLDocument.

/**
 * Converts a guided score card model into a PMML document.
 *
 * <p>The score card level settings (fact class, agenda group, rule-flow group, imports,
 * resultant score field, package) are stored as PMML extensions. Each characteristic is
 * translated into a PMML characteristic plus a mining field, and each of its attributes
 * into a PMML attribute carrying a {@code predicateResolver} extension.
 *
 * @param model the guided score card model to convert
 * @return the PMML document produced by {@link ScorecardPMMLGenerator} for the assembled score card
 */
private static PMML createPMMLDocument(final ScoreCardModel model) {
    final Scorecard pmmlScorecard = ScorecardPMMLUtils.createScorecard();
    final Output output = new Output();
    final Characteristics characteristics = new Characteristics();
    final MiningSchema miningSchema = new MiningSchema();

    // --- score card level extensions ---------------------------------------------------
    Extension extension = new Extension();
    extension.setName(PMMLExtensionNames.EXTERNAL_CLASS);
    extension.setValue(model.getFactName());
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);

    String agendaGroup = model.getAgendaGroup();
    if (!StringUtils.isEmpty(agendaGroup)) {
        extension = new Extension();
        extension.setName(PMMLExtensionNames.AGENDA_GROUP);
        extension.setValue(agendaGroup);
        pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);
    }

    String ruleFlowGroup = model.getRuleFlowGroup();
    if (!StringUtils.isEmpty(ruleFlowGroup)) {
        extension = new Extension();
        extension.setName(PMMLExtensionNames.RULEFLOW_GROUP);
        // Fix: this extension previously received the agenda group, so the configured
        // rule-flow group was lost.
        extension.setValue(ruleFlowGroup);
        pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);
    }

    // Distinct import types, in first-seen order, each followed by a comma
    // (the trailing comma is part of the emitted value).
    extension = new Extension();
    extension.setName(PMMLExtensionNames.MODEL_IMPORTS);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);
    List<String> imports = new ArrayList<String>();
    StringBuilder importBuilder = new StringBuilder();
    for (Import imp : model.getImports().getImports()) {
        if (!imports.contains(imp.getType())) {
            imports.add(imp.getType());
            importBuilder.append(imp.getType()).append(",");
        }
    }
    extension.setValue(importBuilder.toString());

    extension = new Extension();
    extension.setName(ScorecardPMMLExtensionNames.SCORECARD_RESULTANT_SCORE_FIELD);
    extension.setValue(model.getFieldName());
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);

    // An absent or empty package name is stored as a null extension value.
    extension = new Extension();
    extension.setName(PMMLExtensionNames.MODEL_PACKAGE);
    String pkgName = model.getPackageName();
    extension.setValue(!(pkgName == null || pkgName.isEmpty()) ? pkgName : null);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension);

    // --- score card attributes ---------------------------------------------------------
    final String modelName = convertToJavaIdentifier(model.getName());
    pmmlScorecard.setModelName(modelName);
    pmmlScorecard.setInitialScore(model.getInitialScore());
    pmmlScorecard.setUseReasonCodes(model.isUseReasonCodes());
    if (model.isUseReasonCodes()) {
        pmmlScorecard.setBaselineScore(model.getBaselineScore());
        pmmlScorecard.setReasonCodeAlgorithm(model.getReasonCodesAlgorithm());
    }

    // --- characteristics, mining fields and attributes ---------------------------------
    for (final org.drools.workbench.models.guided.scorecard.shared.Characteristic characteristic : model.getCharacteristics()) {
        final Characteristic _characteristic = new Characteristic();
        characteristics.getCharacteristics().add(_characteristic);

        extension = new Extension();
        extension.setName(PMMLExtensionNames.EXTERNAL_CLASS);
        extension.setValue(characteristic.getFact());
        _characteristic.getExtensions().add(extension);

        // Map the guided-editor data type onto the spreadsheet data type keyword. For an
        // unknown type the extension is still added, but without a value.
        extension = new Extension();
        extension.setName(ScorecardPMMLExtensionNames.CHARACTERTISTIC_DATATYPE);
        if ("string".equalsIgnoreCase(characteristic.getDataType())) {
            extension.setValue(XLSKeywords.DATATYPE_TEXT);
        } else if ("int".equalsIgnoreCase(characteristic.getDataType()) || "double".equalsIgnoreCase(characteristic.getDataType())) {
            extension.setValue(XLSKeywords.DATATYPE_NUMBER);
        } else if ("boolean".equalsIgnoreCase(characteristic.getDataType())) {
            extension.setValue(XLSKeywords.DATATYPE_BOOLEAN);
        } else {
            System.out.println(">>>> Found unknown data type :: " + characteristic.getDataType());
        }
        _characteristic.getExtensions().add(extension);

        _characteristic.setBaselineScore(characteristic.getBaselineScore());
        if (model.isUseReasonCodes()) {
            _characteristic.setReasonCode(characteristic.getReasonCode());
        }
        _characteristic.setName(characteristic.getName());

        final MiningField miningField = new MiningField();
        miningField.setName(characteristic.getField());
        miningField.setUsageType(FIELDUSAGETYPE.ACTIVE);
        miningField.setInvalidValueTreatment(INVALIDVALUETREATMENTMETHOD.RETURN_INVALID);
        miningSchema.getMiningFields().add(miningField);
        extension = new Extension();
        extension.setName(PMMLExtensionNames.EXTERNAL_CLASS);
        extension.setValue(characteristic.getFact());
        miningField.getExtensions().add(extension);

        for (final org.drools.workbench.models.guided.scorecard.shared.Attribute attribute : characteristic.getAttributes()) {
            final Attribute _attribute = new Attribute();
            _characteristic.getAttributes().add(_attribute);

            extension = new Extension();
            extension.setName(ScorecardPMMLExtensionNames.CHARACTERTISTIC_FIELD);
            extension.setValue(characteristic.getField());
            _attribute.getExtensions().add(extension);

            if (model.isUseReasonCodes()) {
                _attribute.setReasonCode(attribute.getReasonCode());
            }
            _attribute.setPartialScore(attribute.getPartialScore());

            // Build the textual predicate; its format depends on the characteristic's
            // data type and on the operator.
            final String operator = attribute.getOperator();
            final String dataType = characteristic.getDataType();
            String predicateResolver;
            if ("boolean".equalsIgnoreCase(dataType)) {
                predicateResolver = operator.toUpperCase();
            } else if ("String".equalsIgnoreCase(dataType)) {
                if (operator.contains("=")) {
                    predicateResolver = operator + attribute.getValue();
                } else {
                    predicateResolver = attribute.getValue() + ",";
                }
            } else {
                if (NUMERIC_OPERATORS.contains(operator)) {
                    predicateResolver = operator + " " + attribute.getValue();
                } else {
                    // Non-comparison operator on a numeric value: "a,b" becomes "a-b".
                    predicateResolver = attribute.getValue().replace(",", "-");
                }
            }
            extension = new Extension();
            extension.setName("predicateResolver");
            extension.setValue(predicateResolver);
            _attribute.getExtensions().add(extension);
        }
    }

    // The containers are appended after the extensions, in this fixed order.
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(miningSchema);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(output);
    pmmlScorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(characteristics);
    return new ScorecardPMMLGenerator().generateDocument(pmmlScorecard);
}
Also used : MiningField(org.dmg.pmml.pmml_4_2.descr.MiningField) Import(org.kie.soup.project.datamodel.imports.Import) Attribute(org.dmg.pmml.pmml_4_2.descr.Attribute) Characteristic(org.dmg.pmml.pmml_4_2.descr.Characteristic) ArrayList(java.util.ArrayList) Extension(org.dmg.pmml.pmml_4_2.descr.Extension) MiningSchema(org.dmg.pmml.pmml_4_2.descr.MiningSchema) Characteristics(org.dmg.pmml.pmml_4_2.descr.Characteristics) Output(org.dmg.pmml.pmml_4_2.descr.Output) ScorecardPMMLGenerator(org.drools.scorecards.pmml.ScorecardPMMLGenerator) Scorecard(org.dmg.pmml.pmml_4_2.descr.Scorecard)

Example 9 with Scorecard

use of org.dmg.pmml.pmml_4_2.descr.Scorecard in project drools by kiegroup.

the class ExternalObjectModelTest method testWithReasonCodes.

/**
 * Compiles the reason-code score card from the external-object-model spreadsheet,
 * builds the generated DRL and checks total score and reason codes for three
 * different {@code Applicant} facts, each evaluated in its own session.
 */
@Test
public void testWithReasonCodes() throws Exception {
    ScorecardCompiler scorecardCompiler2 = new ScorecardCompiler(EXTERNAL_OBJECT_MODEL);
    PMML pmmlDocument2 = null;
    String drl2 = null;
    // The second argument presumably selects the worksheet inside the workbook - TODO confirm.
    if (scorecardCompiler2.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_externalmodel.xls"), "scorecards_reasoncode")) {
        pmmlDocument2 = scorecardCompiler2.getPMMLDocument();
        PMML4Compiler.dumpModel(pmmlDocument2, System.out);
        assertNotNull(pmmlDocument2);
        drl2 = scorecardCompiler2.getDRL();
    // System.out.println(drl2);
    } else {
        // Print every parse error before failing so the cause is visible in the test output.
        for (ScorecardError error : scorecardCompiler2.getScorecardParseErrors()) {
            System.out.println(error.getErrorLocation() + ":" + error.getErrorMessage());
        }
        fail("failed to parse scoremodel Excel (scorecards_reasoncode).");
    }
    assertNotNull(pmmlDocument2);
    assertTrue(drl2 != null && !drl2.isEmpty());
    // Build a KieBase from the generated DRL using an in-memory KIE file system.
    KieServices ks = KieServices.Factory.get();
    KieFileSystem kfs = ks.newKieFileSystem();
    kfs.write(ks.getResources().newByteArrayResource(drl2.getBytes()).setSourcePath("test_scorecard_rules.drl").setResourceType(ResourceType.DRL));
    KieBuilder kieBuilder = ks.newKieBuilder(kfs);
    // NOTE(review): the build results are captured but never inspected in this test.
    Results res = kieBuilder.buildAll().getResults();
    KieContainer kieContainer = ks.newKieContainer(kieBuilder.getKieModule().getReleaseId());
    KieBase kbase = kieContainer.getKieBase();
    KieSession session = kbase.newKieSession();
    // Declared type used below to read the score card's internal score and ranking.
    FactType scorecardInternalsType = kbase.getFactType(PMML4Helper.pmmlDefaultPackageName(), "ScoreCard");
    // Scenario 1: only the age is set.
    Applicant applicant = new Applicant();
    applicant.setAge(10);
    session.insert(applicant);
    // session.addEventListener(new DebugWorkingMemoryEventListener());
    session.fireAllRules();
    // occupation = 0, age = 30, validLicence -1, initialScore=100
    assertEquals(129.0, applicant.getTotalScore(), 0.0);
    assertEquals("VL0099", applicant.getReasonCodes());
    // The internal score card fact must agree with the values written to the applicant.
    Object scorecardInternals = session.getObjects(new ClassObjectFilter(scorecardInternalsType.getFactClass())).iterator().next();
    Assert.assertEquals(129.0, scorecardInternalsType.get(scorecardInternals, "score"));
    Map reasonCodesMap = (Map) scorecardInternalsType.get(scorecardInternals, "ranking");
    Assert.assertNotNull(reasonCodesMap);
    // The key order of the ranking map is asserted as well, not just its contents.
    Assert.assertEquals(Arrays.asList("VL0099", "AGE02"), new ArrayList(reasonCodesMap.keySet()));
    session.dispose();
    // Scenario 2: fresh session, occupation and age set.
    session = kbase.newKieSession();
    applicant = new Applicant();
    applicant.setOccupation("SKYDIVER");
    applicant.setAge(0);
    session.insert(applicant);
    session.fireAllRules();
    session.dispose();
    // occupation = -10, age = +10, validLicense = -1, initialScore=100;
    assertEquals(99.0, applicant.getTotalScore(), 0.0);
    // Scenario 3: fresh session, all four inputs set.
    session = kbase.newKieSession();
    applicant = new Applicant();
    applicant.setResidenceState("AP");
    applicant.setOccupation("TEACHER");
    applicant.setAge(20);
    applicant.setValidLicense(true);
    session.insert(applicant);
    session.fireAllRules();
    session.dispose();
    // occupation = +10, age = +40, state = -10, validLicense = 1, initialScore=100
    assertEquals(141.0, applicant.getTotalScore(), 0.0);
}
Also used : KieFileSystem(org.kie.api.builder.KieFileSystem) ArrayList(java.util.ArrayList) KieServices(org.kie.api.KieServices) FactType(org.kie.api.definition.type.FactType) ClassObjectFilter(org.kie.api.runtime.ClassObjectFilter) Results(org.kie.api.builder.Results) KieBase(org.kie.api.KieBase) PMML(org.dmg.pmml.pmml_4_2.descr.PMML) KieSession(org.kie.api.runtime.KieSession) KieBuilder(org.kie.api.builder.KieBuilder) Applicant(org.drools.scorecards.example.Applicant) Map(java.util.Map) KieContainer(org.kie.api.runtime.KieContainer) Test(org.junit.Test)

Example 10 with Scorecard

use of org.dmg.pmml.pmml_4_2.descr.Scorecard in project drools by kiegroup.

the class ScorecardReasonCodeTest method testReasonCodes.

/**
 * Verifies that the score card compiled from the reason-codes spreadsheet contains
 * four characteristics and that every attribute of every characteristic carries a
 * reason code. The test fails if no score card with a characteristics section is found.
 */
@Test
public void testReasonCodes() throws Exception {
    final ScorecardCompiler compiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    final boolean compiled = compiler.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"));
    if (!compiled) {
        assertErrors(compiler);
    }

    final PMML document = compiler.getPMMLDocument();
    for (Object model : document.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (!(model instanceof Scorecard)) {
            continue;
        }
        final Scorecard scorecard = (Scorecard) model;
        for (Object section : scorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (!(section instanceof Characteristics)) {
                continue;
            }
            final Characteristics found = (Characteristics) section;
            assertEquals(4, found.getCharacteristics().size());
            for (Characteristic each : found.getCharacteristics()) {
                for (Attribute attr : each.getAttributes()) {
                    assertNotNull(attr.getReasonCode());
                }
            }
            // Only the first characteristics section encountered is checked.
            return;
        }
    }
    // Reaching this point means no characteristics section was found at all.
    fail();
}
Also used : Characteristics(org.dmg.pmml.pmml_4_2.descr.Characteristics) Attribute(org.dmg.pmml.pmml_4_2.descr.Attribute) Characteristic(org.dmg.pmml.pmml_4_2.descr.Characteristic) PMML(org.dmg.pmml.pmml_4_2.descr.PMML) Scorecard(org.dmg.pmml.pmml_4_2.descr.Scorecard) Test(org.junit.Test)

Aggregations

Scorecard (org.dmg.pmml.pmml_4_2.descr.Scorecard): 8; PMML (org.dmg.pmml.pmml_4_2.descr.PMML): 7; Test (org.junit.Test): 7; ArrayList (java.util.ArrayList): 3; Characteristics (org.dmg.pmml.pmml_4_2.descr.Characteristics): 3; Extension (org.dmg.pmml.pmml_4_2.descr.Extension): 3; KieServices (org.kie.api.KieServices): 3; KieContainer (org.kie.api.runtime.KieContainer): 3; RegressionModel (org.kie.dmg.pmml.pmml_4_2.descr.RegressionModel): 3; Scorecard (org.kie.dmg.pmml.pmml_4_2.descr.Scorecard): 3; TreeModel (org.kie.dmg.pmml.pmml_4_2.descr.TreeModel): 3; Attribute (org.dmg.pmml.pmml_4_2.descr.Attribute): 2; Characteristic (org.dmg.pmml.pmml_4_2.descr.Characteristic): 2; Output (org.dmg.pmml.pmml_4_2.descr.Output): 2; MiningModel (org.kie.dmg.pmml.pmml_4_2.descr.MiningModel): 2; PMML4Model (org.kie.pmml.pmml_4_2.PMML4Model): 2; Map (java.util.Map): 1; ClusteringModel (org.dmg.pmml.pmml_4_2.descr.ClusteringModel): 1; MiningField (org.dmg.pmml.pmml_4_2.descr.MiningField): 1; MiningSchema (org.dmg.pmml.pmml_4_2.descr.MiningSchema): 1