Search in sources:

Example 21 with PMML

use of org.dmg.pmml.pmml_4_2.descr.PMML in project drools by kiegroup.

the class ScorecardReasonCodeTest method testReasonCodes.

/**
 * Compiles the "/scoremodel_reasoncodes.xls" spreadsheet and inspects the generated PMML document.
 * The first Characteristics element found inside a Scorecard must hold exactly four
 * characteristics, and every attribute of each characteristic must expose a non-null reason code.
 * The test fails if no such element is present.
 */
@Test
public void testReasonCodes() throws Exception {
    final ScorecardCompiler compiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    final boolean compiled = compiler.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"));
    if (!compiled) {
        assertErrors(compiler);
    }
    final PMML pmml = compiler.getPMMLDocument();
    for (Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        // Only scorecard models are of interest here.
        if (!(model instanceof Scorecard)) {
            continue;
        }
        final Scorecard scorecard = (Scorecard) model;
        for (Object element : scorecard.getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (!(element instanceof Characteristics)) {
                continue;
            }
            final Characteristics group = (Characteristics) element;
            assertEquals(4, group.getCharacteristics().size());
            for (Characteristic entry : group.getCharacteristics()) {
                for (Attribute candidate : entry.getAttributes()) {
                    assertNotNull(candidate.getReasonCode());
                }
            }
            // The first Characteristics element has been verified; nothing more to check.
            return;
        }
    }
    // Reaching this point means no Characteristics element was found in any Scorecard.
    fail();
}
Also used : Characteristics(org.dmg.pmml.pmml_4_2.descr.Characteristics) Attribute(org.dmg.pmml.pmml_4_2.descr.Attribute) Characteristic(org.dmg.pmml.pmml_4_2.descr.Characteristic) PMML(org.dmg.pmml.pmml_4_2.descr.PMML) Scorecard(org.dmg.pmml.pmml_4_2.descr.Scorecard) Test(org.junit.Test)

Example 22 with PMML

use of org.dmg.pmml.pmml_4_2.descr.PMML in project drools by kiegroup.

the class ScoringStrategiesTest method testScoringExtension.

/**
 * Compiles the "/scoremodel_scoring_strategies.xls" spreadsheet and verifies that the first
 * Scorecard in the generated PMML document is named "Sample Score" and carries a scoring-strategy
 * extension whose value equals {@code AggregationStrategy.AGGREGATE_SCORE.toString()}.
 * The test fails if compilation does not succeed or no Scorecard is found.
 */
@Test
public void testScoringExtension() throws Exception {
    PMML pmmlDocument;
    ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    if (scorecardCompiler.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_scoring_strategies.xls"))) {
        pmmlDocument = scorecardCompiler.getPMMLDocument();
        assertNotNull(pmmlDocument);
        String drl = scorecardCompiler.getDRL();
        assertNotNull(drl);
        for (Object serializable : pmmlDocument.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
            if (serializable instanceof Scorecard) {
                Scorecard scorecard = (Scorecard) serializable;
                assertEquals("Sample Score", scorecard.getModelName());
                Extension extension = ScorecardPMMLUtils.getExtension(scorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), ScorecardPMMLExtensionNames.SCORECARD_SCORING_STRATEGY);
                assertNotNull(extension);
                // Expected value goes first so that a failure message reports expected/actual correctly.
                assertEquals(AggregationStrategy.AGGREGATE_SCORE.toString(), extension.getValue());
                return;
            }
        }
    }
    // Either compilation failed or the document contained no Scorecard.
    fail();
}
Also used : Extension(org.dmg.pmml.pmml_4_2.descr.Extension) PMML(org.dmg.pmml.pmml_4_2.descr.PMML) Scorecard(org.dmg.pmml.pmml_4_2.descr.Scorecard) Test(org.junit.Test)

Example 23 with PMML

use of org.dmg.pmml.pmml_4_2.descr.PMML in project drools by kiegroup.

the class GuidedScoreCardDRLPersistence method marshal.

/**
 * Converts the given score card model to DRL text.
 * A PMML document is built from the model via {@code createPMMLDocument} and then translated
 * with {@code ScorecardCompiler.convertToDRL} using the external object model DRL type.
 *
 * @param model the score card model to marshal
 * @return the text produced by appending the conversion result to an empty StringBuilder
 */
public static String marshal(final ScoreCardModel model) {
    final PMML document = createPMMLDocument(model);
    // Package statement and Imports are appended by org.drools.scorecards.drl.AbstractDRLEmitter
    // Build rules
    return new StringBuilder()
            .append(ScorecardCompiler.convertToDRL(document, ScorecardCompiler.DrlType.EXTERNAL_OBJECT_MODEL))
            .toString();
}
Also used : PMML(org.dmg.pmml.pmml_4_2.descr.PMML)

Example 24 with PMML

use of org.kie.dmg.pmml.pmml_4_2.descr.PMML in project drools by kiegroup.

the class PMMLGenerationTest method testNNGenration.

/**
 * Generates a simple neural network PMML document, streams it to a byte array, validates the
 * bytes against the schema located at {@code compiler.SCHEMA_PATH}, unmarshals them again and
 * compares the round-tripped document with the generated one (data dictionary size, number of
 * network elements, and one sample connection weight).
 */
@Test
public void testNNGenration() {
    PMML net = PMMLGeneratorUtils.generateSimpleNeuralNetwork(modelName, inputfieldNames, outputfieldNames, inputMeans, inputStds, outputMeans, outputStds, hiddenSize, weights);
    assertNotNull(net);
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    assertTrue(PMMLGeneratorUtils.streamPMML(net, baos));
    ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
    PMML4Compiler compiler = new PMML4Compiler();
    SchemaFactory sf = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
    try {
        Schema schema = sf.newSchema(Thread.currentThread().getContextClassLoader().getResource(compiler.SCHEMA_PATH));
        schema.newValidator().validate(new StreamSource(bais));
    } catch (SAXException e) {
        fail(e.getMessage());
    } catch (IOException e) {
        fail(e.getMessage());
    }
    PMML net2 = null;
    try {
        // Validation consumed the stream; rewind it before unmarshalling the same bytes.
        bais.reset();
        JAXBContext ctx = JAXBContext.newInstance(PMML.class.getPackage().getName());
        net2 = (PMML) ctx.createUnmarshaller().unmarshal(bais);
    } catch (JAXBException e) {
        // Fail with the unmarshalling error instead of printing it and failing later without a cause.
        fail(e.getMessage());
    }
    assertNotNull(net2);
    assertEquals(inputfieldNames.length + outputfieldNames.length, net2.getDataDictionary().getDataFields().size());
    assertEquals(net.getDataDictionary().getDataFields().size(), net2.getDataDictionary().getDataFields().size());
    NeuralNetwork n1 = (NeuralNetwork) net.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    NeuralNetwork n2 = (NeuralNetwork) net2.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    assertEquals(n1.getExtensionsAndNeuralLayersAndNeuralInputs().size(), n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());
    assertEquals(6, n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());
    NeuralLayer l1 = (NeuralLayer) n1.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    NeuralLayer l2 = (NeuralLayer) n2.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    // Spot-check a single connection weight: original vs. round-tripped, then against the input array.
    assertEquals(l1.getNeurons().get(4).getCons().get(2).getWeight(), l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
    assertEquals(weights[(inputfieldNames.length + 1) * 4 + 3], l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
}
Also used : SchemaFactory(javax.xml.validation.SchemaFactory) Schema(javax.xml.validation.Schema) StreamSource(javax.xml.transform.stream.StreamSource) JAXBException(javax.xml.bind.JAXBException) JAXBContext(javax.xml.bind.JAXBContext) ByteArrayOutputStream(java.io.ByteArrayOutputStream) IOException(java.io.IOException) NeuralLayer(org.kie.dmg.pmml.pmml_4_2.descr.NeuralLayer) NeuralNetwork(org.kie.dmg.pmml.pmml_4_2.descr.NeuralNetwork) SAXException(org.xml.sax.SAXException) ByteArrayInputStream(java.io.ByteArrayInputStream) PMML(org.kie.dmg.pmml.pmml_4_2.descr.PMML) Test(org.junit.Test)

Example 25 with PMML

use of org.kie.dmg.pmml.pmml_4_2.descr.PMML in project drools by kiegroup.

the class PMML4Compiler method addMissingFieldDefinition.

/**
 * For every mining field of the segment's model that is not in the dictionary, searches the
 * models of other segments with a lower segment index for an output field of the same name.
 * The first usable raw data field found is renamed to the mining field's name, added to the
 * PMML data dictionary, and the dictionary's field count is incremented by one.
 *
 * @param pmml the PMML document whose data dictionary is extended
 * @param msm  the segmentation supplying the candidate segments
 * @param seg  the segment whose model's mining fields are checked
 */
private void addMissingFieldDefinition(PMML pmml, MiningSegmentation msm, MiningSegment seg) {
    // Candidate models: those of other segments whose index is lower than this segment's index.
    List<PMML4Model> candidates = msm.getMiningSegments().stream()
            .filter(other -> other != seg && other.getSegmentIndex() < seg.getSegmentIndex())
            .map(other -> other.getModel())
            .collect(Collectors.toList());
    seg.getModel().getMiningFields().stream()
            .filter(miningField -> !miningField.isInDictionary())
            .forEach(miningField -> {
                String fieldName = miningField.getName();
                for (PMML4Model candidate : candidates) {
                    PMMLOutputField output = candidate.findOutputField(fieldName);
                    if (output == null) {
                        continue;
                    }
                    // Resolve the output's target mining field, when one is named.
                    PMMLMiningField target = null;
                    if (output.getTargetField() != null) {
                        target = candidate.findMiningField(output.getTargetField());
                    }
                    // Prefer the output's own raw field when it has a data type; else use the target's.
                    DataField definition = null;
                    if (output.getRawDataField() != null && output.getRawDataField().getDataType() != null) {
                        definition = output.getRawDataField();
                    } else if (target != null) {
                        definition = target.getRawDataField();
                    }
                    if (definition == null) {
                        continue;
                    }
                    definition.setName(fieldName);
                    pmml.getDataDictionary().getDataFields().add(definition);
                    BigInteger fieldCount = pmml.getDataDictionary().getNumberOfFields();
                    pmml.getDataDictionary().setNumberOfFields(fieldCount.add(BigInteger.ONE));
                    // One definition per missing field is enough; stop scanning candidates.
                    break;
                }
            });
}
Also used : ResourceFactory(org.kie.internal.io.ResourceFactory) SupportVectorMachineModel(org.kie.dmg.pmml.pmml_4_2.descr.SupportVectorMachineModel) Miningmodel(org.kie.pmml.pmml_4_2.model.Miningmodel) TemplateRegistry(org.mvel2.templates.TemplateRegistry) Map(java.util.Map) BigInteger(java.math.BigInteger) Scorecard(org.kie.dmg.pmml.pmml_4_2.descr.Scorecard) KieSession(org.kie.api.runtime.KieSession) TemplateCompiler(org.mvel2.templates.TemplateCompiler) EventProcessingOption(org.kie.api.conf.EventProcessingOption) PMML4UnitImpl(org.kie.pmml.pmml_4_2.model.PMML4UnitImpl) PMMLMiningField(org.kie.pmml.pmml_4_2.model.PMMLMiningField) TreeModel(org.kie.dmg.pmml.pmml_4_2.descr.TreeModel) KnowledgeBuilderResult(org.kie.internal.builder.KnowledgeBuilderResult) KieBaseModel(org.kie.api.builder.model.KieBaseModel) Collectors(java.util.stream.Collectors) JAXBException(javax.xml.bind.JAXBException) Resource(org.kie.api.io.Resource) List(java.util.List) PMMLCompiler(org.drools.compiler.compiler.PMMLCompiler) KieSessionModel(org.kie.api.builder.model.KieSessionModel) SAXException(org.xml.sax.SAXException) Writer(java.io.Writer) NaiveBayesModel(org.kie.dmg.pmml.pmml_4_2.descr.NaiveBayesModel) UnsupportedEncodingException(java.io.UnsupportedEncodingException) Marshaller(javax.xml.bind.Marshaller) HashMap(java.util.HashMap) ResourceType(org.kie.api.io.ResourceType) Schema(javax.xml.validation.Schema) ArrayList(java.util.ArrayList) MiningSegmentation(org.kie.pmml.pmml_4_2.model.mining.MiningSegmentation) RegressionModel(org.kie.dmg.pmml.pmml_4_2.descr.RegressionModel) ClusteringModel(org.kie.dmg.pmml.pmml_4_2.descr.ClusteringModel) IoUtils(org.drools.core.util.IoUtils) PMMLResource(org.drools.compiler.compiler.PMMLResource) KieServices(org.kie.api.KieServices) OutputStreamWriter(java.io.OutputStreamWriter) ByteArrayResource(org.drools.core.io.impl.ByteArrayResource) XMLConstants(javax.xml.XMLConstants) KieBase(org.kie.api.KieBase) 
SimpleTemplateRegistry(org.mvel2.templates.SimpleTemplateRegistry) JAXBContext(javax.xml.bind.JAXBContext) MiningSegment(org.kie.pmml.pmml_4_2.model.mining.MiningSegment) OutputStream(java.io.OutputStream) Unmarshaller(javax.xml.bind.Unmarshaller) SchemaFactory(javax.xml.validation.SchemaFactory) Iterator(java.util.Iterator) DataField(org.kie.dmg.pmml.pmml_4_2.descr.DataField) KieContainer(org.kie.api.runtime.KieContainer) IOException(java.io.IOException) DataDictionary(org.kie.dmg.pmml.pmml_4_2.descr.DataDictionary) File(java.io.File) PMMLOutputField(org.kie.pmml.pmml_4_2.model.PMMLOutputField) ClassPathResource(org.drools.core.io.impl.ClassPathResource) KieModuleModel(org.kie.api.builder.model.KieModuleModel) PMML(org.kie.dmg.pmml.pmml_4_2.descr.PMML) Collections(java.util.Collections) InputStream(java.io.InputStream) NeuralNetwork(org.kie.dmg.pmml.pmml_4_2.descr.NeuralNetwork) PMMLOutputField(org.kie.pmml.pmml_4_2.model.PMMLOutputField) DataField(org.kie.dmg.pmml.pmml_4_2.descr.DataField) BigInteger(java.math.BigInteger) PMMLMiningField(org.kie.pmml.pmml_4_2.model.PMMLMiningField)

Aggregations

PMML (org.dmg.pmml.pmml_4_2.descr.PMML)20 Test (org.junit.Test)17 KieSession (org.kie.api.runtime.KieSession)10 Scorecard (org.dmg.pmml.pmml_4_2.descr.Scorecard)8 PMML (org.kie.dmg.pmml.pmml_4_2.descr.PMML)8 TreeModel (org.dmg.pmml.pmml_4_2.descr.TreeModel)7 FactType (org.kie.api.definition.type.FactType)7 DroolsAbstractPMMLTest (org.drools.pmml.pmml_4_2.DroolsAbstractPMMLTest)6 PMML4Compiler (org.drools.pmml.pmml_4_2.PMML4Compiler)6 KieServices (org.kie.api.KieServices)6 JAXBContext (javax.xml.bind.JAXBContext)5 JAXBException (javax.xml.bind.JAXBException)5 KieContainer (org.kie.api.runtime.KieContainer)5 ArrayList (java.util.ArrayList)4 KieBase (org.kie.api.KieBase)4 KnowledgeBuilderResult (org.kie.internal.builder.KnowledgeBuilderResult)4 ByteArrayOutputStream (java.io.ByteArrayOutputStream)3 IOException (java.io.IOException)3 HashMap (java.util.HashMap)3 Map (java.util.Map)3