Usage of `org.drools.pmml.pmml_4_2.PMML4Compiler` in the drools project by kiegroup.
Class `DecisionTreeTest`, method `testMissingTreeLastChoice`.
/**
 * Loads the {@code source2} PMML document, switches every tree model in it to the
 * {@code LAST_PREDICTION} missing-value strategy, scores the inputs (-1.0, -1.0, "optA")
 * and checks the resulting {@code TreeToken} and the status of the {@code Fld9} field.
 */
@Test
public void testMissingTreeLastChoice() throws Exception {
    PMML4Compiler pmmlCompiler = new PMML4Compiler();
    PMML document = pmmlCompiler.loadModel(PMML, ResourceFactory.newClassPathResource(source2).getInputStream());

    // Only tree models carry a missing-value strategy; skip every other model kind.
    for (Object candidate : document.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (!(candidate instanceof TreeModel)) {
            continue;
        }
        ((TreeModel) candidate).setMissingValueStrategy(MISSINGVALUESTRATEGY.LAST_PREDICTION);
    }

    String drl = pmmlCompiler.generateTheory(document);
    if (VERBOSE) {
        System.out.println(drl);
    }

    KieSession session = getSession(drl);
    setKSession(session);
    setKbase(getKSession().getKieBase());

    // Fire once before any input is inserted, to initialise the model.
    session.fireAllRules();

    FactType targetType = session.getKieBase().getFactType(packageName, "Fld9");
    FactType tokenType = session.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");

    session.getEntryPoint("in_Fld1").insert(-1.0);
    session.getEntryPoint("in_Fld2").insert(-1.0);
    session.getEntryPoint("in_Fld3").insert("optA");
    session.fireAllRules();

    Object treeToken = getToken(session);
    assertEquals(0.8, (Double) tokenType.get(treeToken, "confidence"), 1e-6);
    assertEquals("null", tokenType.get(treeToken, "current"));
    assertEquals(50.0, tokenType.get(treeToken, "totalCount"));
    checkFirstDataFieldOfTypeStatus(targetType, true, false, "Missing", "tgtX");
    checkGeneratedRules();
}
Usage of `org.drools.pmml.pmml_4_2.PMML4Compiler` in the drools project by kiegroup.
Class `HeaderTest`, method `testPMMLHeader`.
/**
 * Compiles the {@code test_header.xml} fixture into a DRL theory and checks the header comment
 * lines of the generated theory. It then builds a session from the theory and verifies that
 * the package configured on the compiler is present in the resulting KieBase.
 */
@Test
public void testPMMLHeader() {
    // Package set on the compiler and later looked up in the KieBase; a single constant keeps them in sync.
    final String testPackage = "org.kie.pmml.pmml_4_2.test";
    // NOTE(review): File.separator is used to build this path; if compile() resolves it as a
    // classpath resource, '/' would be the portable separator - confirm against PMML4Compiler.
    String source = PMML4Helper.pmmlDefaultPackageName().replace(".", File.separator) + File.separator + "test_header.xml";
    boolean header = false;
    boolean timestamp = false;
    boolean appl = false;
    boolean descr = false;
    boolean copyright = false;
    boolean annotation = false;
    PMML4Compiler compiler = new PMML4Compiler();
    compiler.getHelper().setPack(testPackage);
    String theory = compiler.compile(source, null);

    // Scan the generated theory; each flag records whether its header line was found with the
    // expected content (if a prefix occurs more than once, the last occurrence wins).
    try (BufferedReader reader = new BufferedReader(new StringReader(theory))) {
        String line;
        while ((line = reader.readLine()) != null) {
            line = line.trim();
            if (line.startsWith("// Imported PMML Model Theory")) {
                header = true;
            } else if (line.startsWith("// Creation timestamp :")) {
                timestamp = line.contains("now");
            } else if (line.startsWith("// Description :")) {
                descr = line.contains("test");
            } else if (line.startsWith("// Copyright :")) {
                copyright = line.contains("opensource");
            } else if (line.startsWith("// Annotation :")) {
                annotation = line.contains("notes here");
            } else if (line.startsWith("// Trained with :")) {
                appl = line.contains("handmade");
            }
        }
    } catch (IOException ioe) {
        // Surface the cause in the failure message instead of printing it to stderr.
        fail("Unable to read the generated theory: " + ioe);
    }
    assertTrue("'// Imported PMML Model Theory' header line not found", header);
    assertTrue("'// Creation timestamp :' line missing or does not contain 'now'", timestamp);
    assertTrue("'// Description :' line missing or does not contain 'test'", descr);
    assertTrue("'// Copyright :' line missing or does not contain 'opensource'", copyright);
    assertTrue("'// Annotation :' line missing or does not contain 'notes here'", annotation);
    assertTrue("'// Trained with :' line missing or does not contain 'handmade'", appl);

    KieSession ksession = getSession(theory);
    try {
        KiePackage pack = ksession.getKieBase().getKiePackage(testPackage);
        assertNotNull("package " + testPackage + " not found in the KieBase", pack);
    } finally {
        // Dispose even when the assertion above fails, so the session is not leaked.
        ksession.dispose();
    }
}
Usage of `org.drools.pmml.pmml_4_2.PMML4Compiler` in the drools project by kiegroup.
Class `HeaderTest`, method `testPMMLHeader`.
/**
 * Compiles the {@code test_header.xml} fixture into a DRL theory and checks the header comment
 * lines of the generated theory. It then builds a session from the theory and verifies that
 * the package configured on the compiler is present in the resulting KieBase.
 */
@Test
public void testPMMLHeader() {
    // Package set on the compiler and later looked up in the KieBase; a single constant keeps them in sync.
    final String testPackage = "org.drools.pmml.pmml_4_2.test";
    // NOTE(review): File.separator is used to build this path; if compile() resolves it as a
    // classpath resource, '/' would be the portable separator - confirm against PMML4Compiler.
    String source = PMML4Helper.pmmlDefaultPackageName().replace(".", File.separator) + File.separator + "test_header.xml";
    boolean header = false;
    boolean timestamp = false;
    boolean appl = false;
    boolean descr = false;
    boolean copyright = false;
    boolean annotation = false;
    PMML4Compiler compiler = new PMML4Compiler();
    compiler.getHelper().setPack(testPackage);
    String theory = compiler.compile(source, null);

    // Scan the generated theory; each flag records whether its header line was found with the
    // expected content (if a prefix occurs more than once, the last occurrence wins).
    try (BufferedReader reader = new BufferedReader(new StringReader(theory))) {
        String line;
        while ((line = reader.readLine()) != null) {
            line = line.trim();
            if (line.startsWith("// Imported PMML Model Theory")) {
                header = true;
            } else if (line.startsWith("// Creation timestamp :")) {
                timestamp = line.contains("now");
            } else if (line.startsWith("// Description :")) {
                descr = line.contains("test");
            } else if (line.startsWith("// Copyright :")) {
                copyright = line.contains("opensource");
            } else if (line.startsWith("// Annotation :")) {
                annotation = line.contains("notes here");
            } else if (line.startsWith("// Trained with :")) {
                appl = line.contains("handmade");
            }
        }
    } catch (IOException ioe) {
        // Surface the cause in the failure message instead of printing it to stderr.
        fail("Unable to read the generated theory: " + ioe);
    }
    assertTrue("'// Imported PMML Model Theory' header line not found", header);
    assertTrue("'// Creation timestamp :' line missing or does not contain 'now'", timestamp);
    assertTrue("'// Description :' line missing or does not contain 'test'", descr);
    assertTrue("'// Copyright :' line missing or does not contain 'opensource'", copyright);
    assertTrue("'// Annotation :' line missing or does not contain 'notes here'", annotation);
    assertTrue("'// Trained with :' line missing or does not contain 'handmade'", appl);

    KieSession ksession = getSession(theory);
    try {
        KiePackage pack = ksession.getKieBase().getKiePackage(testPackage);
        assertNotNull("package " + testPackage + " not found in the KieBase", pack);
    } finally {
        // Dispose even when the assertion above fails, so the session is not leaked.
        ksession.dispose();
    }
}
Usage of `org.drools.pmml.pmml_4_2.PMML4Compiler` in the drools project by kiegroup.
Class `DecisionTreeTest`, method `testMissingTreeDefault`.
/**
 * Loads the {@code source2} PMML document, switches every tree model in it to the
 * {@code DEFAULT_CHILD} missing-value strategy, scores the inputs (70.0, 40.0, "miss") and
 * checks the resulting {@code TreeToken} and the status of the {@code Fld9} field.
 */
@Test
public void testMissingTreeDefault() throws Exception {
    PMML4Compiler compiler = new PMML4Compiler();
    PMML pmml = compiler.loadModel(PMML, ResourceFactory.newClassPathResource(source2).getInputStream());
    for (Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (o instanceof TreeModel) {
            TreeModel tree = (TreeModel) o;
            tree.setMissingValueStrategy(MISSINGVALUESTRATEGY.DEFAULT_CHILD);
        }
    }
    // Keep the generated theory in a local so it can be dumped when VERBOSE, as the other
    // missing-value tests in this class do; this helps when diagnosing a failing assertion.
    String theory = compiler.generateTheory(pmml);
    if (VERBOSE) {
        System.out.println(theory);
    }
    KieSession kSession = getSession(theory);
    setKSession(kSession);
    setKbase(getKSession().getKieBase());
    // init model: fire once before any input is inserted
    kSession.fireAllRules();
    FactType tgt = kSession.getKieBase().getFactType(packageName, "Fld9");
    FactType tok = kSession.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
    kSession.getEntryPoint("in_Fld1").insert(70.0);
    kSession.getEntryPoint("in_Fld2").insert(40.0);
    kSession.getEntryPoint("in_Fld3").insert("miss");
    kSession.fireAllRules();
    Object token = getToken(kSession);
    assertEquals(0.72, (Double) tok.get(token, "confidence"), 1e-6);
    assertEquals("null", tok.get(token, "current"));
    assertEquals(40.0, tok.get(token, "totalCount"));
    checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX");
    checkGeneratedRules();
}
Usage of `org.drools.pmml.pmml_4_2.PMML4Compiler` in the drools project by kiegroup.
Class `DecisionTreeTest`, method `testMissingTreeAllMissingDefault`.
/**
 * Loads the {@code source2} PMML document, switches every tree model in it to the
 * {@code DEFAULT_CHILD} missing-value strategy, scores the inputs (-1.0, -1.0, "miss") and
 * checks the resulting {@code TreeToken}.
 */
@Test
public void testMissingTreeAllMissingDefault() throws Exception {
    PMML4Compiler compiler = new PMML4Compiler();
    PMML pmml = compiler.loadModel(PMML, ResourceFactory.newClassPathResource(source2).getInputStream());
    for (Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (o instanceof TreeModel) {
            TreeModel tree = (TreeModel) o;
            tree.setMissingValueStrategy(MISSINGVALUESTRATEGY.DEFAULT_CHILD);
        }
    }
    String theory = compiler.generateTheory(pmml);
    if (VERBOSE) {
        System.out.println(theory);
    }
    KieSession kSession = getSession(theory);
    setKSession(kSession);
    setKbase(getKSession().getKieBase());
    // init model: fire once before any input is inserted
    kSession.fireAllRules();
    FactType tok = kSession.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
    kSession.getEntryPoint("in_Fld1").insert(-1.0);
    kSession.getEntryPoint("in_Fld2").insert(-1.0);
    kSession.getEntryPoint("in_Fld3").insert("miss");
    kSession.fireAllRules();
    Object token = getToken(kSession);
    assertEquals(1.0, (Double) tok.get(token, "confidence"), 1e-6);
    assertEquals("null", tok.get(token, "current"));
    assertEquals(0.0, tok.get(token, "totalCount"));
    // NOTE(review): unlike the other missing-value tests in this class, the target-status check
    // checkFirstDataFieldOfTypeStatus(<Fld9 fact type>, true, false, "Missing", "tgtX") is
    // disabled here and the reason is not recorded. Confirm whether it should be restored
    // (together with the Fld9 fact-type lookup) or dropped for good.
    checkGeneratedRules();
}
Aggregations