Search in sources:

Example 11 with PMML

Use of org.kie.dmg.pmml.pmml_4_2.descr.PMML in the drools project by kiegroup.

Class PMML4Compiler, method compile.

/**
 * Compiles a PMML document read from {@code source} into generated theory text.
 *
 * <p>The result list is reset at the start of every call. The model is then loaded and the
 * helper's resolver is set to {@code classLoader}. Theory generation runs only if
 * {@link #getResults()} is still empty at that point; otherwise {@code null} is returned.
 *
 * @param source      stream containing the PMML document
 * @param classLoader class loader handed to the helper as its resolver
 * @return the generated theory, or {@code null} if any result has been recorded
 */
public String compile(InputStream source, ClassLoader classLoader) {
    // Fresh list so that results from an earlier compile() call are not carried over.
    this.results = new ArrayList<KnowledgeBuilderResult>();
    final PMML loaded = loadModel(PMML, source);
    helper.setResolver(classLoader);
    // NOTE(review): presumably loadModel records load problems in the results list — confirm.
    final boolean noResults = getResults().isEmpty();
    return noResults ? generateTheory(loaded) : null;
}
Also used : PMML(org.dmg.pmml.pmml_4_2.descr.PMML) KnowledgeBuilderResult(org.kie.internal.builder.KnowledgeBuilderResult)

Example 12 with PMML

Use of org.kie.dmg.pmml.pmml_4_2.descr.PMML in the drools project by kiegroup.

Class PMMLGenerationTest, method testNNGenration.

/**
 * Round-trips a generated neural-network PMML document through XML and checks that it survives.
 *
 * <p>Steps:
 * <ol>
 *   <li>The network built by {@code PMMLGeneratorUtils.generateSimpleNeuralNetwork} is streamed
 *       to bytes.</li>
 *   <li>The bytes are validated against the schema found at {@code compiler.SCHEMA_PATH} on the
 *       context class loader.</li>
 *   <li>The bytes are unmarshalled again with JAXB.</li>
 *   <li>The result is compared with the original: data-field counts, the number of elements in
 *       the first model, and one connection weight.</li>
 * </ol>
 */
@Test
public void testNNGenration() {
    PMML net = PMMLGeneratorUtils.generateSimpleNeuralNetwork(modelName, inputfieldNames, outputfieldNames, inputMeans, inputStds, outputMeans, outputStds, hiddenSize, weights);
    assertNotNull(net);
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    assertTrue(PMMLGeneratorUtils.streamPMML(net, baos));
    ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());

    // Validate the serialized document against the PMML schema.
    PMML4Compiler compiler = new PMML4Compiler();
    SchemaFactory sf = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
    java.net.URL schemaUrl = Thread.currentThread().getContextClassLoader().getResource(compiler.SCHEMA_PATH);
    // getResource returns null when the XSD is not on the classpath; newSchema would then fail
    // with an opaque NullPointerException, so report it explicitly.
    assertNotNull("PMML schema not found on classpath: " + compiler.SCHEMA_PATH, schemaUrl);
    try {
        Schema schema = sf.newSchema(schemaUrl);
        schema.newValidator().validate(new StreamSource(bais));
    } catch (SAXException e) {
        // toString() keeps the exception type; getMessage() alone may be null.
        fail("Generated PMML failed schema validation: " + e);
    } catch (IOException e) {
        fail("Could not read the PMML schema or document: " + e);
    }

    // Read the same bytes back with JAXB. No mark was set and no offset was given to the
    // ByteArrayInputStream constructor, so reset() rewinds to the first byte.
    PMML net2 = null;
    try {
        bais.reset();
        JAXBContext ctx = JAXBContext.newInstance(PMML.class.getPackage().getName());
        net2 = (PMML) ctx.createUnmarshaller().unmarshal(bais);
    } catch (JAXBException e) {
        // Fail with the cause instead of printing it and failing later on a null net2.
        fail("Could not unmarshal the generated PMML: " + e);
    }
    assertNotNull(net2);

    // The data dictionary must contain one field per input and output, on both sides.
    assertEquals(inputfieldNames.length + outputfieldNames.length, net2.getDataDictionary().getDataFields().size());
    assertEquals(net.getDataDictionary().getDataFields().size(), net2.getDataDictionary().getDataFields().size());

    // The first model of each document is expected to be the NeuralNetwork.
    NeuralNetwork n1 = (NeuralNetwork) net.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    NeuralNetwork n2 = (NeuralNetwork) net2.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    assertEquals(n1.getExtensionsAndNeuralLayersAndNeuralInputs().size(), n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());
    assertEquals(6, n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());

    // The element at index 3 is cast to NeuralLayer. Compare one connection weight with the
    // original network and with its source entry in the flat weights array.
    NeuralLayer l1 = (NeuralLayer) n1.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    NeuralLayer l2 = (NeuralLayer) n2.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    assertEquals(l1.getNeurons().get(4).getCons().get(2).getWeight(), l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
    assertEquals(weights[(inputfieldNames.length + 1) * 4 + 3], l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
}
Also used : SchemaFactory(javax.xml.validation.SchemaFactory) Schema(javax.xml.validation.Schema) StreamSource(javax.xml.transform.stream.StreamSource) JAXBException(javax.xml.bind.JAXBException) JAXBContext(javax.xml.bind.JAXBContext) ByteArrayOutputStream(java.io.ByteArrayOutputStream) IOException(java.io.IOException) NeuralLayer(org.dmg.pmml.pmml_4_2.descr.NeuralLayer) NeuralNetwork(org.dmg.pmml.pmml_4_2.descr.NeuralNetwork) SAXException(org.xml.sax.SAXException) ByteArrayInputStream(java.io.ByteArrayInputStream) PMML(org.dmg.pmml.pmml_4_2.descr.PMML) Test(org.junit.Test)

Example 13 with PMML

Use of org.kie.dmg.pmml.pmml_4_2.descr.PMML in the drools project by kiegroup.

Class DecisionTreeTest, method testMissingTreeDefault.

/**
 * Runs the tree model loaded from {@code source2} with every TreeModel switched to the
 * {@code DEFAULT_CHILD} missing-value strategy.
 *
 * <p>Inputs: 70.0 on {@code in_Fld1}, 40.0 on {@code in_Fld2}, and "miss" on {@code in_Fld3}.
 *
 * <p>Expected TreeToken: confidence 0.72, current {@code "null"}, totalCount 40.0. The first
 * {@code Fld9} data field must also pass the {@code "Missing"}/{@code "tgtX"} status check.
 */
@Test
public void testMissingTreeDefault() throws Exception {
    final PMML4Compiler pmmlCompiler = new PMML4Compiler();
    final PMML model = pmmlCompiler.loadModel(PMML,
            ResourceFactory.newClassPathResource(source2).getInputStream());
    // Switch every tree model in the document to DEFAULT_CHILD; other model kinds are skipped.
    for (Object candidate : model.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (!(candidate instanceof TreeModel)) {
            continue;
        }
        ((TreeModel) candidate).setMissingValueStrategy(MISSINGVALUESTRATEGY.DEFAULT_CHILD);
    }
    final String generatedTheory = pmmlCompiler.generateTheory(model);
    final KieSession session = getSession(generatedTheory);
    setKSession(session);
    setKbase(getKSession().getKieBase());

    // Initial firing sets up the model before any input is inserted.
    session.fireAllRules();

    final FactType targetType = session.getKieBase().getFactType(packageName, "Fld9");
    final FactType tokenType = session.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");

    session.getEntryPoint("in_Fld1").insert(70.0);
    session.getEntryPoint("in_Fld2").insert(40.0);
    session.getEntryPoint("in_Fld3").insert("miss");
    session.fireAllRules();

    final Object treeToken = getToken(session);
    assertEquals(0.72, (Double) tokenType.get(treeToken, "confidence"), 1e-6);
    assertEquals("null", tokenType.get(treeToken, "current"));
    assertEquals(40.0, tokenType.get(treeToken, "totalCount"));
    checkFirstDataFieldOfTypeStatus(targetType, true, false, "Missing", "tgtX");
    checkGeneratedRules();
}
Also used : TreeModel(org.dmg.pmml.pmml_4_2.descr.TreeModel) PMML(org.dmg.pmml.pmml_4_2.descr.PMML) KieSession(org.kie.api.runtime.KieSession) PMML4Compiler(org.drools.pmml.pmml_4_2.PMML4Compiler) FactType(org.kie.api.definition.type.FactType) DroolsAbstractPMMLTest(org.drools.pmml.pmml_4_2.DroolsAbstractPMMLTest) Test(org.junit.Test)

Example 14 with PMML

Use of org.kie.dmg.pmml.pmml_4_2.descr.PMML in the drools project by kiegroup.

Class DecisionTreeTest, method testMissingTreeAllMissingDefault.

/**
 * Runs the tree model loaded from {@code source2} with every TreeModel switched to the
 * {@code DEFAULT_CHILD} missing-value strategy.
 *
 * <p>Inputs: -1.0 on {@code in_Fld1}, -1.0 on {@code in_Fld2}, and "miss" on {@code in_Fld3}.
 * The test name suggests these are all treated as missing.
 *
 * <p>Expected TreeToken: confidence 1.0, current {@code "null"}, totalCount 0.0. The generated
 * theory is printed when {@code VERBOSE} is set.
 */
@Test
public void testMissingTreeAllMissingDefault() throws Exception {
    PMML4Compiler compiler = new PMML4Compiler();
    PMML pmml = compiler.loadModel(PMML, ResourceFactory.newClassPathResource(source2).getInputStream());
    // Switch every tree model in the document to DEFAULT_CHILD; other model kinds are left alone.
    for (Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (o instanceof TreeModel) {
            TreeModel tree = (TreeModel) o;
            tree.setMissingValueStrategy(MISSINGVALUESTRATEGY.DEFAULT_CHILD);
        }
    }
    String theory = compiler.generateTheory(pmml);
    if (VERBOSE) {
        System.out.println(theory);
    }
    KieSession kSession = getSession(theory);
    setKSession(kSession);
    setKbase(getKSession().getKieBase());
    // init model
    kSession.fireAllRules();
    FactType tok = kSession.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
    kSession.getEntryPoint("in_Fld1").insert(-1.0);
    kSession.getEntryPoint("in_Fld2").insert(-1.0);
    kSession.getEntryPoint("in_Fld3").insert("miss");
    kSession.fireAllRules();
    Object token = getToken(kSession);
    assertEquals(1.0, (Double) tok.get(token, "confidence"), 1e-6);
    assertEquals("null", tok.get(token, "current"));
    assertEquals(0.0, tok.get(token, "totalCount"));
    // NOTE(review): testMissingTreeDefault and testMissingTreeNone also check the status of the
    // first Fld9 data field via checkFirstDataFieldOfTypeStatus. That check is disabled for this
    // scenario, and the reason is not recorded. Confirm whether it should be restored.
    checkGeneratedRules();
}
Also used : TreeModel(org.dmg.pmml.pmml_4_2.descr.TreeModel) PMML(org.dmg.pmml.pmml_4_2.descr.PMML) KieSession(org.kie.api.runtime.KieSession) PMML4Compiler(org.drools.pmml.pmml_4_2.PMML4Compiler) FactType(org.kie.api.definition.type.FactType) DroolsAbstractPMMLTest(org.drools.pmml.pmml_4_2.DroolsAbstractPMMLTest) Test(org.junit.Test)

Example 15 with PMML

Use of org.kie.dmg.pmml.pmml_4_2.descr.PMML in the drools project by kiegroup.

Class DecisionTreeTest, method testMissingTreeNone.

/**
 * Runs the tree model loaded from {@code source2} with every TreeModel switched to the
 * {@code NONE} missing-value strategy.
 *
 * <p>Inputs: -1.0 on {@code in_Fld1}, -1.0 on {@code in_Fld2}, and "miss" on {@code in_Fld3}.
 * The generated theory is printed when {@code VERBOSE} is set.
 *
 * <p>Expected TreeToken: confidence 0.6, current {@code "null"}, totalCount 100.0. The first
 * {@code Fld9} data field must also pass the {@code "Missing"}/{@code "tgtX"} status check.
 */
@Test
public void testMissingTreeNone() throws Exception {
    final PMML4Compiler pmmlCompiler = new PMML4Compiler();
    final PMML model = pmmlCompiler.loadModel(PMML,
            ResourceFactory.newClassPathResource(source2).getInputStream());
    // Apply the NONE strategy to tree models only.
    for (Object candidate : model.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (candidate instanceof TreeModel) {
            ((TreeModel) candidate).setMissingValueStrategy(MISSINGVALUESTRATEGY.NONE);
        }
    }

    final String generatedTheory = pmmlCompiler.generateTheory(model);
    if (VERBOSE) {
        System.out.println(generatedTheory);
    }

    final KieSession session = getSession(generatedTheory);
    setKSession(session);
    setKbase(getKSession().getKieBase());

    // Initial firing sets up the model before any input is inserted.
    session.fireAllRules();

    final FactType targetType = session.getKieBase().getFactType(packageName, "Fld9");
    final FactType tokenType = session.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");

    session.getEntryPoint("in_Fld1").insert(-1.0);
    session.getEntryPoint("in_Fld2").insert(-1.0);
    session.getEntryPoint("in_Fld3").insert("miss");
    session.fireAllRules();

    final Object treeToken = getToken(session);
    assertEquals(0.6, (Double) tokenType.get(treeToken, "confidence"), 1e-6);
    assertEquals("null", tokenType.get(treeToken, "current"));
    assertEquals(100.0, tokenType.get(treeToken, "totalCount"));
    checkFirstDataFieldOfTypeStatus(targetType, true, false, "Missing", "tgtX");
    checkGeneratedRules();
}
Also used : TreeModel(org.dmg.pmml.pmml_4_2.descr.TreeModel) PMML(org.dmg.pmml.pmml_4_2.descr.PMML) KieSession(org.kie.api.runtime.KieSession) PMML4Compiler(org.drools.pmml.pmml_4_2.PMML4Compiler) FactType(org.kie.api.definition.type.FactType) DroolsAbstractPMMLTest(org.drools.pmml.pmml_4_2.DroolsAbstractPMMLTest) Test(org.junit.Test)

Aggregations

PMML (org.dmg.pmml.pmml_4_2.descr.PMML)20 Test (org.junit.Test)17 KieSession (org.kie.api.runtime.KieSession)10 PMML (org.kie.dmg.pmml.pmml_4_2.descr.PMML)8 FactType (org.kie.api.definition.type.FactType)7 Scorecard (org.dmg.pmml.pmml_4_2.descr.Scorecard)6 TreeModel (org.dmg.pmml.pmml_4_2.descr.TreeModel)6 DroolsAbstractPMMLTest (org.drools.pmml.pmml_4_2.DroolsAbstractPMMLTest)6 PMML4Compiler (org.drools.pmml.pmml_4_2.PMML4Compiler)6 KieServices (org.kie.api.KieServices)6 JAXBContext (javax.xml.bind.JAXBContext)5 JAXBException (javax.xml.bind.JAXBException)5 KieContainer (org.kie.api.runtime.KieContainer)5 KieBase (org.kie.api.KieBase)4 KnowledgeBuilderResult (org.kie.internal.builder.KnowledgeBuilderResult)4 ByteArrayOutputStream (java.io.ByteArrayOutputStream)3 IOException (java.io.IOException)3 ArrayList (java.util.ArrayList)3 HashMap (java.util.HashMap)3 Map (java.util.Map)3