Usage of `org.dmg.pmml.tree.TreeModel` in project jpmml-sparkml (by jpmml).
Class `GBTRegressionModelConverter`, method `encodeModel`:
/**
 * Converts the Spark GBT regression ensemble into a PMML {@code MiningModel}.
 *
 * <p>Each tree of the ensemble becomes one segment; segment outputs are combined with
 * {@code WEIGHTED_SUM} using the ensemble's own tree weights.
 *
 * @param schema the feature/label schema used to encode every tree
 * @return a regression {@code MiningModel} wrapping the encoded trees
 */
@Override
public MiningModel encodeModel(Schema schema) {
    GBTRegressionModel gbtModel = getTransformer();

    // One PMML TreeModel per tree in the ensemble, all encoded against the same schema.
    List<TreeModel> segments = TreeModelUtil.encodeDecisionTreeEnsemble(gbtModel, schema);

    MiningModel result = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()));

    // Tree weights line up positionally with the encoded segments.
    List<Double> weights = Doubles.asList(gbtModel.treeWeights());
    result.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, segments, weights));

    return result;
}
Usage of `org.dmg.pmml.tree.TreeModel` in project jpmml-sparkml (by jpmml).
Class `TreeModelCompactor`, method `visit`:
/**
 * Rewrites a freshly encoded binary-split tree model into its compacted, multi-split form.
 *
 * <p>The tree model must use exactly {@code missingValueStrategy=NONE},
 * {@code noTrueChildStrategy=RETURN_NULL_PREDICTION} and {@code splitCharacteristic=BINARY_SPLIT}.
 * Any other configuration is rejected, because the compaction is only defined for that starting point.
 * Afterwards the model is switched to {@code NULL_PREDICTION} / {@code MULTI_SPLIT}. Regression
 * models are also switched to {@code RETURN_LAST_PREDICTION}; classification models keep
 * {@code RETURN_NULL_PREDICTION}.
 *
 * @param treeModel the tree model to rewrite; it is modified in place
 * @return the result of the superclass visit
 * @throws IllegalArgumentException if the model's strategies do not match the expected input
 *     configuration, or if its mining function is neither REGRESSION nor CLASSIFICATION
 */
@Override
public VisitorAction visit(TreeModel treeModel) {
    TreeModel.MissingValueStrategy missingValueStrategy = treeModel.getMissingValueStrategy();
    TreeModel.NoTrueChildStrategy noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
    TreeModel.SplitCharacteristic splitCharacteristic = treeModel.getSplitCharacteristic();

    // Guard: only the exact configuration this compactor was written for is accepted.
    if (!(TreeModel.MissingValueStrategy.NONE).equals(missingValueStrategy)
            || !(TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION).equals(noTrueChildStrategy)
            || !(TreeModel.SplitCharacteristic.BINARY_SPLIT).equals(splitCharacteristic)) {
        throw new IllegalArgumentException("Expected missingValueStrategy=" + TreeModel.MissingValueStrategy.NONE
                + ", noTrueChildStrategy=" + TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION
                + ", splitCharacteristic=" + TreeModel.SplitCharacteristic.BINARY_SPLIT
                + "; got missingValueStrategy=" + missingValueStrategy
                + ", noTrueChildStrategy=" + noTrueChildStrategy
                + ", splitCharacteristic=" + splitCharacteristic);
    }

    treeModel.setMissingValueStrategy(TreeModel.MissingValueStrategy.NULL_PREDICTION).setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);

    MiningFunction miningFunction = treeModel.getMiningFunction();
    switch (miningFunction) {
        case REGRESSION:
            treeModel.setNoTrueChildStrategy(TreeModel.NoTrueChildStrategy.RETURN_LAST_PREDICTION);
            break;
        case CLASSIFICATION:
            // Keeps RETURN_NULL_PREDICTION, which the guard above already verified.
            break;
        default:
            throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
    }

    return super.visit(treeModel);
}
Usage of `org.dmg.pmml.tree.TreeModel` in project pyramid (by cheng-li).
Class `PMMLConverter`, method `createMiningModel`:
/**
 * Builds a regression {@code MiningModel} that sums the outputs of the given regression trees.
 *
 * <p>Every tree is encoded against the anonymous form of {@code schema}. The segments are
 * combined with {@code SUM} under {@code MathContext.FLOAT}. {@code base_score} is passed to
 * {@code ModelUtil.createRescaleTargets} with a {@code null} first argument, presumably so it acts
 * as an additive offset with no rescale factor (confirm against jpmml-converter).
 *
 * @param regTrees   the trees to encode, in segment order
 * @param base_score value handed to the rescale targets (see above)
 * @param schema     schema whose label must be a {@code ContinuousLabel}
 * @return the assembled mining model
 */
protected static MiningModel createMiningModel(List<RegressionTree> regTrees, float base_score, Schema schema) {
    ContinuousLabel label = (ContinuousLabel) schema.getLabel();
    Schema anonymousSchema = schema.toAnonymousSchema();

    List<TreeModel> segments = new ArrayList<>();
    regTrees.forEach(tree -> segments.add(tree.encodeTreeModel(anonymousSchema)));

    MiningModel result = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(label));
    result.setMathContext(MathContext.FLOAT);
    result.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, segments));
    result.setTargets(ModelUtil.createRescaleTargets(null, ValueUtil.floatToDouble(base_score), label));
    return result;
}
Usage of `org.dmg.pmml.tree.TreeModel` in project pyramid (by cheng-li).
Class `RegressionTree`, method `encodeTreeModel`:
// ====================== PMML ===========================
// This part follows the design of the jpmml package.

/**
 * Encodes this regression tree as a PMML {@code TreeModel}.
 *
 * <p>The root node always matches (predicate {@code True}) and is filled in recursively by
 * {@code encodeNode}, starting at index 0. The resulting model is declared binary-split, uses
 * {@code MissingValueStrategy.NONE}, and evaluates in {@code MathContext.FLOAT}.
 *
 * @param schema the schema used to resolve features and the label
 * @return the encoded regression tree model
 */
public TreeModel encodeTreeModel(Schema schema) {
    org.dmg.pmml.tree.Node rootNode = new org.dmg.pmml.tree.Node();
    rootNode.setPredicate(new True());
    encodeNode(rootNode, 0, schema);

    TreeModel result = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), rootNode);
    result.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
    result.setMissingValueStrategy(TreeModel.MissingValueStrategy.NONE);
    result.setMathContext(MathContext.FLOAT);
    return result;
}
Usage of `org.dmg.pmml.tree.TreeModel` in project pyramid (by cheng-li).
Class `PMMLConverter`, method `createMiningModel` (second occurrence):
/**
 * Assembles an additive ensemble of regression trees into a PMML {@code MiningModel}.
 *
 * <p>Trees are encoded, in list order, against the anonymous variant of {@code schema}. They are
 * combined with {@code MultipleModelMethod.SUM} and evaluated in {@code MathContext.FLOAT}.
 * {@code base_score} is passed through {@code ModelUtil.createRescaleTargets(null, ...)}. That
 * presumably makes it an additive constant with no rescale factor; verify against jpmml-converter.
 *
 * @param regTrees   regression trees forming the ensemble
 * @param base_score starting score handed to the rescale targets
 * @param schema     schema whose label is expected to be a {@code ContinuousLabel}
 * @return the ensemble mining model
 */
protected static MiningModel createMiningModel(List<RegressionTree> regTrees, float base_score, Schema schema) {
    ContinuousLabel target = (ContinuousLabel) schema.getLabel();
    Schema treeSchema = schema.toAnonymousSchema();

    List<TreeModel> encodedTrees = new ArrayList<>();
    for (int i = 0; i < regTrees.size(); i++) {
        encodedTrees.add(regTrees.get(i).encodeTreeModel(treeSchema));
    }

    MiningModel ensemble = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(target));
    ensemble.setMathContext(MathContext.FLOAT);
    ensemble.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, encodedTrees));
    ensemble.setTargets(ModelUtil.createRescaleTargets(null, ValueUtil.floatToDouble(base_score), target));
    return ensemble;
}
Aggregations