Usage of `org.kie.dmg.pmml.pmml_4_2.descr.TreeModel` in the Drools project (kiegroup).
Class `DecisionTreeTest`, method `testMissingTreeAllMissingDefault`.
/**
 * Runs the tree model from {@code source2} with the DEFAULT_CHILD missing-value strategy.
 * After inserting -1.0, -1.0 and "miss" into the three input entry points, the resulting
 * TreeToken is expected to report confidence 1.0, current node "null" and total count 0.0.
 */
@Test
public void testMissingTreeAllMissingDefault() throws Exception {
    PMML4Compiler compiler = new PMML4Compiler();
    PMML pmml;
    // Close the classpath stream explicitly once parsing is done instead of leaking it.
    try (java.io.InputStream in = ResourceFactory.newClassPathResource(source2).getInputStream()) {
        pmml = compiler.loadModel(PMML, in);
    }
    // Force the strategy under test on every tree model contained in the document.
    for (Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (o instanceof TreeModel) {
            TreeModel tree = (TreeModel) o;
            tree.setMissingValueStrategy(MISSINGVALUESTRATEGY.DEFAULT_CHILD);
        }
    }
    String theory = compiler.generateTheory(pmml);
    if (VERBOSE) {
        System.out.println(theory);
    }
    KieSession kSession = getSession(theory);
    setKSession(kSession);
    setKbase(getKSession().getKieBase());
    // init model
    kSession.fireAllRules();
    FactType tok = kSession.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
    kSession.getEntryPoint("in_Fld1").insert(-1.0);
    kSession.getEntryPoint("in_Fld2").insert(-1.0);
    kSession.getEntryPoint("in_Fld3").insert("miss");
    kSession.fireAllRules();
    Object token = getToken(kSession);
    assertEquals(1.0, (Double) tok.get(token, "confidence"), 1e-6);
    assertEquals("null", tok.get(token, "current"));
    assertEquals(0.0, tok.get(token, "totalCount"));
    // NOTE(review): the target-field (Fld9) status check was disabled for this strategy;
    // re-add checkFirstDataFieldOfTypeStatus(...) here if it should be verified.
    checkGeneratedRules();
}
Usage of `org.kie.dmg.pmml.pmml_4_2.descr.TreeModel` in the Drools project (kiegroup).
Class `DecisionTreeTest`, method `testMissingTreeNone`.
/**
 * Runs the tree model from {@code source2} with the NONE missing-value strategy.
 * After inserting -1.0, -1.0 and "miss" into the three input entry points, the resulting
 * TreeToken is expected to report confidence 0.6, current node "null" and total count 100.0,
 * and the target field Fld9 is checked for the "Missing" / "tgtX" status.
 */
@Test
public void testMissingTreeNone() throws Exception {
    PMML4Compiler compiler = new PMML4Compiler();
    PMML pmml;
    // Close the classpath stream explicitly once parsing is done instead of leaking it.
    try (java.io.InputStream in = ResourceFactory.newClassPathResource(source2).getInputStream()) {
        pmml = compiler.loadModel(PMML, in);
    }
    // Force the strategy under test on every tree model contained in the document.
    for (Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (o instanceof TreeModel) {
            TreeModel tree = (TreeModel) o;
            tree.setMissingValueStrategy(MISSINGVALUESTRATEGY.NONE);
        }
    }
    String theory = compiler.generateTheory(pmml);
    if (VERBOSE) {
        System.out.println(theory);
    }
    KieSession kSession = getSession(theory);
    setKSession(kSession);
    setKbase(getKSession().getKieBase());
    // init model
    kSession.fireAllRules();
    FactType tgt = kSession.getKieBase().getFactType(packageName, "Fld9");
    FactType tok = kSession.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
    kSession.getEntryPoint("in_Fld1").insert(-1.0);
    kSession.getEntryPoint("in_Fld2").insert(-1.0);
    kSession.getEntryPoint("in_Fld3").insert("miss");
    kSession.fireAllRules();
    Object token = getToken(kSession);
    assertEquals(0.6, (Double) tok.get(token, "confidence"), 1e-6);
    assertEquals("null", tok.get(token, "current"));
    assertEquals(100.0, tok.get(token, "totalCount"));
    checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtX");
    checkGeneratedRules();
}
Usage of `org.kie.dmg.pmml.pmml_4_2.descr.TreeModel` in the Drools project (kiegroup).
Class `DecisionTreeTest`, method `testMissingAggregate`.
/**
 * Runs the tree model from {@code source2} with the AGGREGATE_NODES missing-value strategy.
 * After inserting 45.0, 90.0 and "miss" into the three input entry points, the resulting
 * TreeToken is expected to report confidence of about 0.47, current node "null" and total
 * count 60.0, and the target field Fld9 is checked for the "Missing" / "tgtY" status.
 */
@Test
public void testMissingAggregate() throws Exception {
    PMML4Compiler compiler = new PMML4Compiler();
    PMML pmml;
    // Close the classpath stream explicitly once parsing is done instead of leaking it.
    try (java.io.InputStream in = ResourceFactory.newClassPathResource(source2).getInputStream()) {
        pmml = compiler.loadModel(PMML, in);
    }
    // Force the strategy under test on every tree model contained in the document.
    for (Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (o instanceof TreeModel) {
            TreeModel tree = (TreeModel) o;
            tree.setMissingValueStrategy(MISSINGVALUESTRATEGY.AGGREGATE_NODES);
        }
    }
    String theory = compiler.generateTheory(pmml);
    if (VERBOSE) {
        System.out.println(theory);
    }
    KieSession kSession = getSession(theory);
    setKSession(kSession);
    setKbase(getKSession().getKieBase());
    // init model
    kSession.fireAllRules();
    FactType tgt = kSession.getKieBase().getFactType(packageName, "Fld9");
    FactType tok = kSession.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
    kSession.getEntryPoint("in_Fld1").insert(45.0);
    kSession.getEntryPoint("in_Fld2").insert(90.0);
    kSession.getEntryPoint("in_Fld3").insert("miss");
    kSession.fireAllRules();
    Object token = getToken(kSession);
    // Aggregated confidence is only checked to two decimal places.
    assertEquals(0.47, (Double) tok.get(token, "confidence"), 1e-2);
    assertEquals("null", tok.get(token, "current"));
    assertEquals(60.0, tok.get(token, "totalCount"));
    checkFirstDataFieldOfTypeStatus(tgt, true, false, "Missing", "tgtY");
    checkGeneratedRules();
}
Usage of `org.kie.dmg.pmml.pmml_4_2.descr.TreeModel` in the Drools project (kiegroup).
Class `DecisionTreeTest`, method `testMissingTreeNull`.
/**
 * Runs the tree model from {@code source2} with the NULL_PREDICTION missing-value strategy.
 * After inserting -1.0, -1.0 and "optA" into the three input entry points, the resulting
 * TreeToken is expected to report confidence 0.0, current node "null" and total count 0.0,
 * and no instance of the target fact type Fld9 may be present in the session.
 */
@Test
public void testMissingTreeNull() throws Exception {
    PMML4Compiler compiler = new PMML4Compiler();
    PMML pmml;
    // Close the classpath stream explicitly once parsing is done instead of leaking it.
    try (java.io.InputStream in = ResourceFactory.newClassPathResource(source2).getInputStream()) {
        pmml = compiler.loadModel(PMML, in);
    }
    // Force the strategy under test on every tree model contained in the document.
    for (Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (o instanceof TreeModel) {
            TreeModel tree = (TreeModel) o;
            tree.setMissingValueStrategy(MISSINGVALUESTRATEGY.NULL_PREDICTION);
        }
    }
    String theory = compiler.generateTheory(pmml);
    if (VERBOSE) {
        System.out.println(theory);
    }
    KieSession kSession = getSession(theory);
    setKSession(kSession);
    setKbase(getKSession().getKieBase());
    // init model
    kSession.fireAllRules();
    FactType tgt = kSession.getKieBase().getFactType(packageName, "Fld9");
    FactType tok = kSession.getKieBase().getFactType(PMML4Helper.pmmlDefaultPackageName(), "TreeToken");
    kSession.getEntryPoint("in_Fld1").insert(-1.0);
    kSession.getEntryPoint("in_Fld2").insert(-1.0);
    kSession.getEntryPoint("in_Fld3").insert("optA");
    kSession.fireAllRules();
    Object token = getToken(kSession);
    assertEquals(0.0, (Double) tok.get(token, "confidence"), 1e-6);
    assertEquals("null", tok.get(token, "current"));
    assertEquals(0.0, tok.get(token, "totalCount"));
    // A null prediction must not leave any target-field fact in working memory.
    assertEquals(0, getKSession().getObjects(new ClassObjectFilter(tgt.getFactClass())).size());
    checkGeneratedRules();
}
Usage of `org.kie.dmg.pmml.pmml_4_2.descr.TreeModel` in the Drools project (kiegroup).
Class `PMML4ModelFactory`, method `getModel`.
/**
 * Wraps the model embedded in a segment in the matching {@link PMML4Model} implementation.
 * The segment's model elements are probed in a fixed order (mining model, regression model,
 * scorecard, tree model) and the first non-null one wins. Every wrapper is built from the
 * embedded model's name, the embedded model itself, the owner reported by
 * {@code segmentation} and a {@code null} final constructor argument.
 *
 * @param segment      the segment whose embedded model is wrapped
 * @param segmentation supplies the owner handed to the wrapper
 * @return the wrapper for the first embedded model found, or {@code null} when the segment
 *         carries none of the four supported model kinds
 */
public PMML4Model getModel(Segment segment, MiningSegmentation segmentation) {
    MiningModel mining = segment.getMiningModel();
    if (mining != null) {
        return new Miningmodel(mining.getModelName(), mining, segmentation.getOwner(), null);
    }
    RegressionModel regression = segment.getRegressionModel();
    if (regression != null) {
        return new Regression(regression.getModelName(), regression, segmentation.getOwner(), null);
    }
    Scorecard card = segment.getScorecard();
    if (card != null) {
        return new ScorecardModel(card.getModelName(), card, segmentation.getOwner(), null);
    }
    TreeModel tree = segment.getTreeModel();
    if (tree != null) {
        return new Treemodel(tree.getModelName(), tree, segmentation.getOwner(), null);
    }
    // None of the supported model kinds is present on this segment.
    return null;
}
Aggregations