Search in sources:

Example 1 with NeuralNetwork

Use of org.dmg.pmml.pmml_4_2.descr.NeuralNetwork in project drools by kiegroup.

In class PMML4Compiler, method checkBuildingResources.

/**
 * Prepares the rule templates required by the models contained in {@code pmml} and returns
 * the matching KieBase from the classpath KieContainer.
 *
 * <p>The first recognized model selects its model-specific base (e.g. "PMML-Tree"). Every
 * further recognized model, including another one of the same type, switches the selection
 * to the combined "PMML" base. When no recognized model is present, "PMML-Base" is used.
 * Each model family's templates are prepared only while its corresponding *Loaded flag is
 * still false. The flags are presumably static fields of the enclosing class; confirm.
 *
 * @param pmml the parsed PMML document whose models drive template preparation
 * @return the KieBase selected by the rules above
 * @throws IOException if preparing a template fails (propagated from prepareTemplate)
 */
private static KieBase checkBuildingResources(PMML pmml) throws IOException {
    KieContainer classpathContainer = KieServices.Factory.get().getKieClasspathContainer();
    if (registry == null) {
        initRegistry();
    }
    String kieBaseName = null;
    for (Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (model instanceof NaiveBayesModel) {
            if (!naiveBayesLoaded) {
                for (String template : NAIVE_BAYES_TEMPLATES) {
                    prepareTemplate(template);
                }
                naiveBayesLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "PMML" : "PMML-Bayes";
        }
        if (model instanceof NeuralNetwork) {
            if (!neuralLoaded) {
                for (String template : NEURAL_TEMPLATES) {
                    prepareTemplate(template);
                }
                neuralLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "PMML" : "PMML-Neural";
        }
        if (model instanceof ClusteringModel) {
            if (!clusteringLoaded) {
                for (String template : CLUSTERING_TEMPLATES) {
                    prepareTemplate(template);
                }
                clusteringLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "PMML" : "PMML-Cluster";
        }
        if (model instanceof SupportVectorMachineModel) {
            if (!svmLoaded) {
                for (String template : SVM_TEMPLATES) {
                    prepareTemplate(template);
                }
                svmLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "PMML" : "PMML-SVM";
        }
        if (model instanceof TreeModel) {
            if (!treeLoaded) {
                for (String template : TREE_TEMPLATES) {
                    prepareTemplate(template);
                }
                treeLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "PMML" : "PMML-Tree";
        }
        if (model instanceof RegressionModel) {
            if (!simpleRegLoaded) {
                for (String template : SIMPLEREG_TEMPLATES) {
                    prepareTemplate(template);
                }
                simpleRegLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "PMML" : "PMML-Regression";
        }
        if (model instanceof Scorecard) {
            if (!scorecardLoaded) {
                for (String template : SCORECARD_TEMPLATES) {
                    prepareTemplate(template);
                }
                scorecardLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "PMML" : "PMML-Scorecard";
        }
    }
    // No recognized model at all: fall back to the base KieBase.
    return classpathContainer.getKieBase(kieBaseName != null ? kieBaseName : "PMML-Base");
}
Also used : TreeModel(org.dmg.pmml.pmml_4_2.descr.TreeModel) KieServices(org.kie.api.KieServices) NaiveBayesModel(org.dmg.pmml.pmml_4_2.descr.NaiveBayesModel) NeuralNetwork(org.dmg.pmml.pmml_4_2.descr.NeuralNetwork) SupportVectorMachineModel(org.dmg.pmml.pmml_4_2.descr.SupportVectorMachineModel) Scorecard(org.dmg.pmml.pmml_4_2.descr.Scorecard) KieContainer(org.kie.api.runtime.KieContainer) ClusteringModel(org.dmg.pmml.pmml_4_2.descr.ClusteringModel) RegressionModel(org.dmg.pmml.pmml_4_2.descr.RegressionModel)

Example 2 with NeuralNetwork

Use of org.kie.dmg.pmml.pmml_4_2.descr.NeuralNetwork in project drools by kiegroup.

In class PMML4Compiler, method checkBuildingResources.

/**
 * Prepares the rule templates required by the models contained in {@code pmml} and returns
 * the matching KieBase from the classpath KieContainer.
 *
 * <p>The first recognized model selects its model-specific base (e.g. "KiePMML-Tree"). Every
 * further recognized model, including another one of the same type, switches the selection
 * to the combined "KiePMML" base. When no recognized model is present, "KiePMML-Base" is
 * used. Each model family's templates are prepared only while its corresponding *Loaded flag
 * is still false. The flags are presumably static fields of the enclosing class; confirm.
 *
 * @param pmml the parsed PMML document whose models drive template preparation
 * @return the KieBase selected by the rules above
 * @throws IOException if preparing a template fails (propagated from prepareTemplate)
 */
private static KieBase checkBuildingResources(PMML pmml) throws IOException {
    KieContainer classpathContainer = KieServices.Factory.get().getKieClasspathContainer();
    if (registry == null) {
        initRegistry();
    }
    String kieBaseName = null;
    for (Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (model instanceof NaiveBayesModel) {
            if (!naiveBayesLoaded) {
                for (String template : NAIVE_BAYES_TEMPLATES) {
                    prepareTemplate(template);
                }
                naiveBayesLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "KiePMML" : "KiePMML-Bayes";
        }
        if (model instanceof NeuralNetwork) {
            if (!neuralLoaded) {
                for (String template : NEURAL_TEMPLATES) {
                    prepareTemplate(template);
                }
                neuralLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "KiePMML" : "KiePMML-Neural";
        }
        if (model instanceof ClusteringModel) {
            if (!clusteringLoaded) {
                for (String template : CLUSTERING_TEMPLATES) {
                    prepareTemplate(template);
                }
                clusteringLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "KiePMML" : "KiePMML-Cluster";
        }
        if (model instanceof SupportVectorMachineModel) {
            if (!svmLoaded) {
                for (String template : SVM_TEMPLATES) {
                    prepareTemplate(template);
                }
                svmLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "KiePMML" : "KiePMML-SVM";
        }
        if (model instanceof TreeModel) {
            if (!treeLoaded) {
                for (String template : TREE_TEMPLATES) {
                    prepareTemplate(template);
                }
                treeLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "KiePMML" : "KiePMML-Tree";
        }
        if (model instanceof RegressionModel) {
            if (!simpleRegLoaded) {
                for (String template : SIMPLEREG_TEMPLATES) {
                    prepareTemplate(template);
                }
                simpleRegLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "KiePMML" : "KiePMML-Regression";
        }
        if (model instanceof Scorecard) {
            if (!scorecardLoaded) {
                for (String template : SCORECARD_TEMPLATES) {
                    prepareTemplate(template);
                }
                scorecardLoaded = true;
            }
            kieBaseName = (kieBaseName != null) ? "KiePMML" : "KiePMML-Scorecard";
        }
    }
    // No recognized model at all: fall back to the base KieBase.
    return classpathContainer.getKieBase(kieBaseName != null ? kieBaseName : "KiePMML-Base");
}
Also used : TreeModel(org.kie.dmg.pmml.pmml_4_2.descr.TreeModel) KieServices(org.kie.api.KieServices) NaiveBayesModel(org.kie.dmg.pmml.pmml_4_2.descr.NaiveBayesModel) NeuralNetwork(org.kie.dmg.pmml.pmml_4_2.descr.NeuralNetwork) SupportVectorMachineModel(org.kie.dmg.pmml.pmml_4_2.descr.SupportVectorMachineModel) Scorecard(org.kie.dmg.pmml.pmml_4_2.descr.Scorecard) KieContainer(org.kie.api.runtime.KieContainer) ClusteringModel(org.kie.dmg.pmml.pmml_4_2.descr.ClusteringModel) RegressionModel(org.kie.dmg.pmml.pmml_4_2.descr.RegressionModel)

Example 3 with NeuralNetwork

Use of org.dmg.pmml.pmml_4_2.descr.NeuralNetwork in project drools by kiegroup.

In class PMMLGenerationTest, method testNNGenration.

/**
 * Round-trips a generated neural-network PMML document and checks it survives intact.
 *
 * <p>The test generates the network, streams it to bytes and validates those bytes against
 * the PMML schema. It then unmarshals them back with JAXB and compares data-dictionary sizes,
 * network element counts and one connection weight against both the original object and the
 * source {@code weights} array.
 */
@Test
public void testNNGenration() {
    PMML net = PMMLGeneratorUtils.generateSimpleNeuralNetwork(modelName, inputfieldNames, outputfieldNames, inputMeans, inputStds, outputMeans, outputStds, hiddenSize, weights);
    assertNotNull(net);
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    assertTrue(PMMLGeneratorUtils.streamPMML(net, baos));
    ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
    PMML4Compiler compiler = new PMML4Compiler();
    SchemaFactory sf = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
    java.net.URL schemaUrl = Thread.currentThread().getContextClassLoader().getResource(compiler.SCHEMA_PATH);
    // getResource returns null for a missing resource; newSchema(null) would then throw an
    // uncaught NullPointerException, so fail with an explicit message instead.
    assertNotNull("PMML schema not found on classpath: " + compiler.SCHEMA_PATH, schemaUrl);
    try {
        Schema schema = sf.newSchema(schemaUrl);
        schema.newValidator().validate(new StreamSource(bais));
    } catch (SAXException e) {
        fail(e.getMessage());
    } catch (IOException e) {
        fail(e.getMessage());
    }
    PMML net2 = null;
    try {
        // Rewind to the start so the same bytes that were validated are unmarshalled.
        bais.reset();
        JAXBContext ctx = JAXBContext.newInstance(PMML.class.getPackage().getName());
        net2 = (PMML) ctx.createUnmarshaller().unmarshal(bais);
    } catch (JAXBException e) {
        // Report the unmarshalling cause directly rather than only printing it to stderr.
        fail("Could not unmarshal the generated PMML: " + e.getMessage());
    }
    assertNotNull(net2);
    assertEquals(inputfieldNames.length + outputfieldNames.length, net2.getDataDictionary().getDataFields().size());
    assertEquals(net.getDataDictionary().getDataFields().size(), net2.getDataDictionary().getDataFields().size());
    NeuralNetwork n1 = (NeuralNetwork) net.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    NeuralNetwork n2 = (NeuralNetwork) net2.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    assertEquals(n1.getExtensionsAndNeuralLayersAndNeuralInputs().size(), n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());
    assertEquals(6, n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());
    // Entry 3 is expected to be a NeuralLayer (the cast asserts it). NOTE(review): which
    // layer this is depends on PMMLGeneratorUtils' output ordering; confirm there.
    NeuralLayer l1 = (NeuralLayer) n1.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    NeuralLayer l2 = (NeuralLayer) n2.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    assertEquals(l1.getNeurons().get(4).getCons().get(2).getWeight(), l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
    // NOTE(review): the index presumably maps neuron 4 / connection 2 into the flat weights
    // array (rows of inputfieldNames.length + 1 entries); confirm against the generator.
    assertEquals(weights[(inputfieldNames.length + 1) * 4 + 3], l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
}
Also used : SchemaFactory(javax.xml.validation.SchemaFactory) Schema(javax.xml.validation.Schema) StreamSource(javax.xml.transform.stream.StreamSource) JAXBException(javax.xml.bind.JAXBException) JAXBContext(javax.xml.bind.JAXBContext) ByteArrayOutputStream(java.io.ByteArrayOutputStream) IOException(java.io.IOException) NeuralLayer(org.dmg.pmml.pmml_4_2.descr.NeuralLayer) NeuralNetwork(org.dmg.pmml.pmml_4_2.descr.NeuralNetwork) SAXException(org.xml.sax.SAXException) ByteArrayInputStream(java.io.ByteArrayInputStream) PMML(org.dmg.pmml.pmml_4_2.descr.PMML) Test(org.junit.Test)

Example 4 with NeuralNetwork

Use of org.kie.dmg.pmml.pmml_4_2.descr.NeuralNetwork in project drools by kiegroup.

In class PMMLGenerationTest, method testNNGenration.

/**
 * Round-trips a generated neural-network PMML document and checks it survives intact.
 *
 * <p>The test generates the network, streams it to bytes and validates those bytes against
 * the PMML schema. It then unmarshals them back with JAXB and compares data-dictionary sizes,
 * network element counts and one connection weight against both the original object and the
 * source {@code weights} array.
 */
@Test
public void testNNGenration() {
    PMML net = PMMLGeneratorUtils.generateSimpleNeuralNetwork(modelName, inputfieldNames, outputfieldNames, inputMeans, inputStds, outputMeans, outputStds, hiddenSize, weights);
    assertNotNull(net);
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    assertTrue(PMMLGeneratorUtils.streamPMML(net, baos));
    ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
    PMML4Compiler compiler = new PMML4Compiler();
    SchemaFactory sf = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
    java.net.URL schemaUrl = Thread.currentThread().getContextClassLoader().getResource(compiler.SCHEMA_PATH);
    // getResource returns null for a missing resource; newSchema(null) would then throw an
    // uncaught NullPointerException, so fail with an explicit message instead.
    assertNotNull("PMML schema not found on classpath: " + compiler.SCHEMA_PATH, schemaUrl);
    try {
        Schema schema = sf.newSchema(schemaUrl);
        schema.newValidator().validate(new StreamSource(bais));
    } catch (SAXException e) {
        fail(e.getMessage());
    } catch (IOException e) {
        fail(e.getMessage());
    }
    PMML net2 = null;
    try {
        // Rewind to the start so the same bytes that were validated are unmarshalled.
        bais.reset();
        JAXBContext ctx = JAXBContext.newInstance(PMML.class.getPackage().getName());
        net2 = (PMML) ctx.createUnmarshaller().unmarshal(bais);
    } catch (JAXBException e) {
        // Report the unmarshalling cause directly rather than only printing it to stderr.
        fail("Could not unmarshal the generated PMML: " + e.getMessage());
    }
    assertNotNull(net2);
    assertEquals(inputfieldNames.length + outputfieldNames.length, net2.getDataDictionary().getDataFields().size());
    assertEquals(net.getDataDictionary().getDataFields().size(), net2.getDataDictionary().getDataFields().size());
    NeuralNetwork n1 = (NeuralNetwork) net.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    NeuralNetwork n2 = (NeuralNetwork) net2.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    assertEquals(n1.getExtensionsAndNeuralLayersAndNeuralInputs().size(), n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());
    assertEquals(6, n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());
    // Entry 3 is expected to be a NeuralLayer (the cast asserts it). NOTE(review): which
    // layer this is depends on PMMLGeneratorUtils' output ordering; confirm there.
    NeuralLayer l1 = (NeuralLayer) n1.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    NeuralLayer l2 = (NeuralLayer) n2.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    assertEquals(l1.getNeurons().get(4).getCons().get(2).getWeight(), l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
    // NOTE(review): the index presumably maps neuron 4 / connection 2 into the flat weights
    // array (rows of inputfieldNames.length + 1 entries); confirm against the generator.
    assertEquals(weights[(inputfieldNames.length + 1) * 4 + 3], l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
}
Also used : SchemaFactory(javax.xml.validation.SchemaFactory) Schema(javax.xml.validation.Schema) StreamSource(javax.xml.transform.stream.StreamSource) JAXBException(javax.xml.bind.JAXBException) JAXBContext(javax.xml.bind.JAXBContext) ByteArrayOutputStream(java.io.ByteArrayOutputStream) IOException(java.io.IOException) NeuralLayer(org.kie.dmg.pmml.pmml_4_2.descr.NeuralLayer) NeuralNetwork(org.kie.dmg.pmml.pmml_4_2.descr.NeuralNetwork) SAXException(org.xml.sax.SAXException) ByteArrayInputStream(java.io.ByteArrayInputStream) PMML(org.kie.dmg.pmml.pmml_4_2.descr.PMML) Test(org.junit.Test)

Aggregations

ByteArrayInputStream (java.io.ByteArrayInputStream)2 ByteArrayOutputStream (java.io.ByteArrayOutputStream)2 IOException (java.io.IOException)2 JAXBContext (javax.xml.bind.JAXBContext)2 JAXBException (javax.xml.bind.JAXBException)2 StreamSource (javax.xml.transform.stream.StreamSource)2 Schema (javax.xml.validation.Schema)2 SchemaFactory (javax.xml.validation.SchemaFactory)2 NeuralNetwork (org.dmg.pmml.pmml_4_2.descr.NeuralNetwork)2 Test (org.junit.Test)2 KieServices (org.kie.api.KieServices)2 KieContainer (org.kie.api.runtime.KieContainer)2 NeuralNetwork (org.kie.dmg.pmml.pmml_4_2.descr.NeuralNetwork)2 SAXException (org.xml.sax.SAXException)2 ClusteringModel (org.dmg.pmml.pmml_4_2.descr.ClusteringModel)1 NaiveBayesModel (org.dmg.pmml.pmml_4_2.descr.NaiveBayesModel)1 NeuralLayer (org.dmg.pmml.pmml_4_2.descr.NeuralLayer)1 PMML (org.dmg.pmml.pmml_4_2.descr.PMML)1 RegressionModel (org.dmg.pmml.pmml_4_2.descr.RegressionModel)1 Scorecard (org.dmg.pmml.pmml_4_2.descr.Scorecard)1