Usage of `org.kie.dmg.pmml.pmml_4_2.descr.PMML` in project drools by kiegroup.
Class `PMML4Compiler`, method `compile`:
/**
 * Parses a PMML document and translates it into a rule theory.
 *
 * <p>The result list is reset before loading. The supplied class loader is installed on the helper
 * after the model has been loaded.
 *
 * @param source      stream containing the PMML document
 * @param classLoader class loader handed to {@code helper.setResolver}
 * @return the generated theory, or {@code null} when {@link #getResults()} is non-empty after
 *         loading (presumably meaning errors were recorded — confirm against loadModel)
 */
public String compile(InputStream source, ClassLoader classLoader) {
    this.results = new ArrayList<KnowledgeBuilderResult>();
    final PMML model = loadModel(PMML, source);
    helper.setResolver(classLoader);
    // Only emit a theory when loading left nothing in the result list.
    return getResults().isEmpty() ? generateTheory(model) : null;
}
Usage of `org.kie.dmg.pmml.pmml_4_2.descr.PMML` in project drools by kiegroup.
Class `PMMLGenerationTest`, method `testNNGenration`:
/**
 * Round-trips a generated neural-network PMML model through serialization, schema validation and
 * JAXB unmarshalling. It then checks that the reloaded model matches the original.
 */
@Test
public void testNNGenration() {
    // Build a small neural network and serialize it to an in-memory buffer.
    PMML net = PMMLGeneratorUtils.generateSimpleNeuralNetwork(modelName, inputfieldNames, outputfieldNames, inputMeans, inputStds, outputMeans, outputStds, hiddenSize, weights);
    assertNotNull(net);
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    assertTrue(PMMLGeneratorUtils.streamPMML(net, baos));
    ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());

    // Validate the serialized document against the schema found on the context class loader.
    PMML4Compiler compiler = new PMML4Compiler();
    SchemaFactory sf = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
    java.net.URL schemaUrl = Thread.currentThread().getContextClassLoader().getResource(compiler.SCHEMA_PATH);
    // Fail with a clear message instead of an obscure error when the schema is missing.
    assertNotNull("PMML schema not found on classpath: " + compiler.SCHEMA_PATH, schemaUrl);
    try {
        Schema schema = sf.newSchema(schemaUrl);
        schema.newValidator().validate(new StreamSource(bais));
    } catch (SAXException | IOException e) {
        fail("Generated PMML failed schema validation: " + e);
    }

    // Read the document back via JAXB. Rewind first, because validation consumed the buffer.
    PMML net2 = null;
    try {
        bais.reset();
        JAXBContext ctx = JAXBContext.newInstance(PMML.class.getPackage().getName());
        net2 = (PMML) ctx.createUnmarshaller().unmarshal(bais);
    } catch (JAXBException e) {
        // Report the cause rather than swallowing it and failing later on a bare null check.
        fail("Unable to unmarshal generated PMML: " + e);
    }
    assertNotNull(net2);

    // One data field per input and output, and the same count as the original model.
    assertEquals(inputfieldNames.length + outputfieldNames.length, net2.getDataDictionary().getDataFields().size());
    assertEquals(net.getDataDictionary().getDataFields().size(), net2.getDataDictionary().getDataFields().size());

    NeuralNetwork n1 = (NeuralNetwork) net.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    NeuralNetwork n2 = (NeuralNetwork) net2.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    assertEquals(n1.getExtensionsAndNeuralLayersAndNeuralInputs().size(), n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());
    assertEquals(6, n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());

    // Element 3 is expected to be a NeuralLayer (the casts below assume it).
    NeuralLayer l1 = (NeuralLayer) n1.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    NeuralLayer l2 = (NeuralLayer) n2.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    assertEquals(l1.getNeurons().get(4).getCons().get(2).getWeight(), l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
    // NOTE(review): the index formula assumes the generator's flat layout of the weights array — confirm.
    assertEquals(weights[(inputfieldNames.length + 1) * 4 + 3], l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
}
Usage of `org.kie.dmg.pmml.pmml_4_2.descr.PMML` in project drools by kiegroup.
Class `DecisionTreeTest`, method `testMissingTreeDefault`:
/**
 * Loads the tree model from {@code source2} and forces the DEFAULT_CHILD missing-value strategy on
 * every TreeModel. It then feeds inputs where Fld3 is "miss" and checks the resulting TreeToken and
 * target field status.
 */
@Test
public void testMissingTreeDefault() throws Exception {
    PMML4Compiler compiler = new PMML4Compiler();
    PMML pmml;
    // Close the classpath resource stream once the model has been parsed.
    try (java.io.InputStream modelStream = ResourceFactory.newClassPathResource(source2).getInputStream()) {
        pmml = compiler.loadModel(PMML, modelStream);
    }
    for (Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (o instanceof TreeModel) {
            TreeModel tree = (TreeModel) o;
            tree.setMissingValueStrategy(MISSINGVALUESTRATEGY.DEFAULT_CHILD);
        }
    }
    String theory = compiler.generateTheory(pmml);
    // Match the sibling tests: dump the generated rules when running verbosely.
    if (VERBOSE) {
        System.out.println(theory);
    }
    KieSession kSession = getSession(theory);
    setKSession(kSession);
    setKbase(getKSession().getKieBase());
    // init model
    kSession.fireAllRules();
    FactType tgt = kSession.getKieBase().getFactType(packageName, "Fld9");
    FactType tok = kSession.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
    kSession.getEntryPoint("in_Fld1").insert(70.0);
    kSession.getEntryPoint("in_Fld2").insert(40.0);
    kSession.getEntryPoint("in_Fld3").insert("miss");
    kSession.fireAllRules();
    Object token = getToken(kSession);
    assertEquals(0.72, (Double) tok.get(token, "confidence"), 1e-6);
    assertEquals("null", tok.get(token, "current"));
    assertEquals(40.0, tok.get(token, "totalCount"));
    checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX");
    checkGeneratedRules();
}
Usage of `org.kie.dmg.pmml.pmml_4_2.descr.PMML` in project drools by kiegroup.
Class `DecisionTreeTest`, method `testMissingTreeAllMissingDefault`:
/**
 * Loads the tree model from {@code source2} and forces the DEFAULT_CHILD missing-value strategy on
 * every TreeModel. It then feeds inputs -1.0, -1.0 and "miss" and checks the resulting TreeToken.
 */
@Test
public void testMissingTreeAllMissingDefault() throws Exception {
    PMML4Compiler compiler = new PMML4Compiler();
    PMML pmml;
    // Close the classpath resource stream once the model has been parsed.
    try (java.io.InputStream modelStream = ResourceFactory.newClassPathResource(source2).getInputStream()) {
        pmml = compiler.loadModel(PMML, modelStream);
    }
    for (Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (o instanceof TreeModel) {
            TreeModel tree = (TreeModel) o;
            tree.setMissingValueStrategy(MISSINGVALUESTRATEGY.DEFAULT_CHILD);
        }
    }
    String theory = compiler.generateTheory(pmml);
    if (VERBOSE) {
        System.out.println(theory);
    }
    KieSession kSession = getSession(theory);
    setKSession(kSession);
    setKbase(getKSession().getKieBase());
    // init model
    kSession.fireAllRules();
    // NOTE(review): tgt is only used by the disabled status check below; kept so that check can be re-enabled.
    FactType tgt = kSession.getKieBase().getFactType(packageName, "Fld9");
    FactType tok = kSession.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
    kSession.getEntryPoint("in_Fld1").insert(-1.0);
    kSession.getEntryPoint("in_Fld2").insert(-1.0);
    kSession.getEntryPoint("in_Fld3").insert("miss");
    kSession.fireAllRules();
    Object token = getToken(kSession);
    assertEquals(1.0, (Double) tok.get(token, "confidence"), 1e-6);
    assertEquals("null", tok.get(token, "current"));
    assertEquals(0.0, tok.get(token, "totalCount"));
    // TODO(review): target-status check disabled in the original; confirm whether it should hold here.
    // checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX" );
    checkGeneratedRules();
}
Usage of `org.kie.dmg.pmml.pmml_4_2.descr.PMML` in project drools by kiegroup.
Class `DecisionTreeTest`, method `testMissingTreeNone`:
/**
 * Loads the tree model from {@code source2} and forces the NONE missing-value strategy on every
 * TreeModel. It then feeds inputs -1.0, -1.0 and "miss" and checks the resulting TreeToken and
 * target field status.
 */
@Test
public void testMissingTreeNone() throws Exception {
    PMML4Compiler compiler = new PMML4Compiler();
    PMML pmml;
    // Close the classpath resource stream once the model has been parsed.
    try (java.io.InputStream modelStream = ResourceFactory.newClassPathResource(source2).getInputStream()) {
        pmml = compiler.loadModel(PMML, modelStream);
    }
    for (Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (o instanceof TreeModel) {
            TreeModel tree = (TreeModel) o;
            tree.setMissingValueStrategy(MISSINGVALUESTRATEGY.NONE);
        }
    }
    String theory = compiler.generateTheory(pmml);
    if (VERBOSE) {
        System.out.println(theory);
    }
    KieSession kSession = getSession(theory);
    setKSession(kSession);
    setKbase(getKSession().getKieBase());
    // init model
    kSession.fireAllRules();
    FactType tgt = kSession.getKieBase().getFactType(packageName, "Fld9");
    FactType tok = kSession.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
    kSession.getEntryPoint("in_Fld1").insert(-1.0);
    kSession.getEntryPoint("in_Fld2").insert(-1.0);
    kSession.getEntryPoint("in_Fld3").insert("miss");
    kSession.fireAllRules();
    Object token = getToken(kSession);
    assertEquals(0.6, (Double) tok.get(token, "confidence"), 1e-6);
    assertEquals("null", tok.get(token, "current"));
    assertEquals(100.0, tok.get(token, "totalCount"));
    checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX");
    checkGeneratedRules();
}
Aggregations