Usage of `org.dmg.pmml.True` in the Shifu project (ShifuML):
the `convert` method of class `TreeEnsemblePmmlCreator`.
/**
 * Converts an {@link IndependentTreeModel} into a PMML {@link MiningModel} whose segmentation holds one
 * {@code TreeModel} segment per tree, combined with the {@code weightedAverage} method.
 *
 * <p>Only models whose {@code getTrees()} contains exactly one tree list are supported (no bagging).
 * Every segment is always selected ({@code True} predicate) and weighted by the matching entry of
 * {@code treeModel.getWeights().get(0)}.
 *
 * @param treeModel the tree ensemble to convert
 * @return a mining model with mining schema, mining function, targets and one segment per tree
 * @throws UnsupportedOperationException if {@code treeModel.getTrees()} does not contain exactly one
 *         tree list (a {@link RuntimeException}, so existing {@code catch (RuntimeException)} still applies)
 */
public MiningModel convert(IndependentTreeModel treeModel) {
    // Validate up front so no PMML objects are built for an input that is going to be rejected anyway.
    int bagCount = treeModel.getTrees().size();
    if (bagCount != 1) {
        throw new UnsupportedOperationException(
                "Bagging model cannot be supported in PMML generation; expected exactly 1 tree list but got "
                        + bagCount + ".");
    }

    MiningModel gbt = new MiningModel();
    MiningSchema miningSchema = new TreeModelMiningSchemaCreator(this.modelConfig, this.columnConfigList).build(null);
    gbt.setMiningSchema(miningSchema);
    gbt.setMiningFunction(treeModel.isClassification() ? MiningFunction.CLASSIFICATION : MiningFunction.REGRESSION);
    gbt.setTargets(createTargets(this.modelConfig));

    Segmentation seg = new Segmentation();
    gbt.setSegmentation(seg);
    seg.setMultipleModelMethod(MultipleModelMethod.WEIGHTED_AVERAGE);
    List<Segment> list = seg.getSegments();

    int idCount = 0;
    for (TreeNode tn : treeModel.getTrees().get(0)) {
        // Fresh creators per tree, as before. NOTE(review): these could be hoisted out of the loop once
        // the creators are confirmed to keep no per-conversion state.
        TreeNodePmmlElementCreator tnec = new TreeNodePmmlElementCreator(this.modelConfig, this.columnConfigList, treeModel);
        org.dmg.pmml.tree.Node root = tnec.convert(tn.getNode());
        TreeModelPmmlElementCreator tmec = new TreeModelPmmlElementCreator(this.modelConfig, this.columnConfigList);
        org.dmg.pmml.tree.TreeModel tm = tmec.convert(treeModel, root);
        tm.setModelName(String.valueOf(idCount));

        Segment segment = new Segment();
        // The GBDT case used to scale this weight by getTrees().size(). The guard above pins that size to 1,
        // so GBDT and non-GBDT models both use the raw per-tree weight.
        segment.setWeight(treeModel.getWeights().get(0).get(idCount));
        // NOTE(review): "Segement" is a typo, kept verbatim so generated segment ids do not change.
        segment.setId("Segement" + idCount);
        segment.setPredicate(new True());
        segment.setModel(tm);
        list.add(segment);
        idCount++;
    }
    return gbt;
}
Aggregations