Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.
Class RangerConverter, method encodeForest.
/**
 * Encodes every tree of a ranger forest as a PMML {@code TreeModel}.
 *
 * @param forest the R forest list; the elements "num.trees", "child.nodeIDs", "split.varIDs"
 *     and "split.values" are read, plus "terminal.class.counts" when it is present
 * @param miningFunction the mining function passed on to every encoded tree model
 * @param scoreEncoder callback passed on to {@code encodeTreeModel} for assigning node scores
 * @param schema the model schema; trees are encoded against its anonymous form
 * @return the encoded tree models, one per tree, in forest order
 */
private List<TreeModel> encodeForest(RGenericVector forest, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema) {
    RNumberVector<?> numTrees = forest.getNumericElement("num.trees");
    RGenericVector childNodeIDs = forest.getGenericElement("child.nodeIDs");
    RGenericVector splitVarIDs = forest.getGenericElement("split.varIDs");
    RGenericVector splitValues = forest.getGenericElement("split.values");
    // Looked up as optional: the null check in the loop below shows this may be absent.
    RGenericVector terminalClassCounts = forest.getGenericElement("terminal.class.counts", false);

    Schema segmentSchema = schema.toAnonymousSchema();

    // Loop-invariant: decode the tree count once rather than on every loop-condition check.
    int treeCount = ValueUtil.asInt(numTrees.asScalar());

    List<TreeModel> treeModels = new ArrayList<>();

    for (int i = 0; i < treeCount; i++) {
        RGenericVector treeTerminalClassCounts = (terminalClassCounts != null ? terminalClassCounts.getGenericValue(i) : null);

        TreeModel treeModel = encodeTreeModel(miningFunction, scoreEncoder, childNodeIDs.getGenericValue(i), splitVarIDs.getNumericValue(i), splitValues.getNumericValue(i), treeTerminalClassCounts, segmentSchema);

        treeModels.add(treeModel);
    }

    return treeModels;
}
Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.
Class RangerConverter, method encodeClassification.
/**
 * Encodes a ranger classification forest as a majority-vote ensemble of tree models.
 *
 * @param forest the R forest list; its "levels" element supplies the class labels
 * @param schema the model schema
 * @return a CLASSIFICATION {@code MiningModel} whose segmentation combines the trees by MAJORITY_VOTE
 */
private MiningModel encodeClassification(RGenericVector forest, Schema schema) {
    RStringVector levels = forest.getStringElement("levels");

    ScoreEncoder scoreEncoder = new ScoreEncoder() {

        @Override
        public Node encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            // This encoder scores a node with a single class label; per-class counts are rejected.
            if (terminalClassCount != null) {
                throw new IllegalArgumentException("Terminal class counts are not supported when encoding a majority-vote classification forest");
            }

            // splitValue is converted to an int and used as a 1-based (R-style) index into levels.
            int index = ValueUtil.asInt(splitValue);

            node.setScore(levels.getValue(index - 1));

            return node;
        }
    };

    List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

    MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels));

    return miningModel;
}
Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.
Class AdaConverter, method encodeModel.
/**
 * Converts an "ada" boosted model to PMML.
 *
 * The encoded trees are combined by a WEIGHTED_SUM segmentation, weighted by the model's
 * "alpha" values, into a continuous "adaValue" output. That regression model is then handed
 * to {@code MiningModelUtil.createBinaryLogisticClassification} (with arguments 2d, 0d and
 * LOGIT normalization) to produce the final classifier.
 *
 * @param schema the model schema used for the binary classification wrapper
 * @return the binary classification model
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector adaObject = getObject();

    RGenericVector boostedModel = adaObject.getGenericElement("model");
    RGenericVector boostedTrees = boostedModel.getGenericElement("trees");
    RDoubleVector treeWeights = boostedModel.getDoubleElement("alpha");

    List<TreeModel> segments = encodeTreeModels(boostedTrees);

    MiningModel weightedSum = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null));
    weightedSum.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, segments, treeWeights.getValues()));
    weightedSum.setOutput(ModelUtil.createPredictedOutput("adaValue", OpType.CONTINUOUS, DataType.DOUBLE));

    return MiningModelUtil.createBinaryLogisticClassification(weightedSum, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, true, schema);
}
Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.
Class BaggingConverter, method encodeModel.
/**
 * Converts a bagging ensemble to a PMML classification model.
 *
 * The trees are combined by MAJORITY_VOTE, and a double-typed probability output is
 * attached for each category of the label.
 *
 * @param schema the model schema; its label is cast to {@code CategoricalLabel}
 * @return the classification {@code MiningModel}
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector baggingObject = getObject();

    RGenericVector baggedTrees = baggingObject.getGenericElement("trees");
    CategoricalLabel label = (CategoricalLabel) schema.getLabel();

    List<TreeModel> segments = encodeTreeModels(baggedTrees);

    MiningModel result = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label));
    result.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, segments));
    result.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, label));

    return result;
}
Use of org.dmg.pmml.tree.TreeModel in project jpmml-sparkml by jpmml.
Class RandomForestClassificationModelConverter, method encodeModel.
/**
 * Converts a Spark ML random forest classifier to a PMML classification model.
 *
 * Each decision tree of the ensemble becomes a segment, and the segments are combined
 * with the AVERAGE method.
 *
 * @param schema the model schema; its label drives the mining schema
 * @return the classification {@code MiningModel}
 */
@Override
public MiningModel encodeModel(Schema schema) {
    RandomForestClassificationModel forest = getTransformer();

    List<TreeModel> segments = TreeModelUtil.encodeDecisionTreeEnsemble(forest, schema);

    MiningModel ensemble = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()));
    ensemble.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, segments));

    return ensemble;
}
Aggregations