Search in sources:

Example 21 with TreeModel

Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.

In class RangerConverter, method encodeForest.

/**
 * Encodes every tree of a ranger forest object as a standalone {@link TreeModel}.
 *
 * <p>The per-tree node structure is read from the parallel lists {@code child.nodeIDs},
 * {@code split.varIDs} and {@code split.values}, all indexed by tree position. The optional
 * {@code terminal.class.counts} list is forwarded per tree when present, and {@code null} otherwise.
 *
 * @param forest the R forest element holding the tree structure lists
 * @param miningFunction the mining function assigned to every generated tree model
 * @param scoreEncoder callback used to attach leaf scores to nodes
 * @param schema the model schema; each tree is encoded against its anonymous variant
 * @return the encoded tree models, in forest order
 */
private List<TreeModel> encodeForest(RGenericVector forest, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema) {
    RNumberVector<?> numTrees = forest.getNumericElement("num.trees");
    RGenericVector childNodeIDs = forest.getGenericElement("child.nodeIDs");
    RGenericVector splitVarIDs = forest.getGenericElement("split.varIDs");
    RGenericVector splitValues = forest.getGenericElement("split.values");
    // Looked up with a second argument of false; the null check below shows this element may be absent.
    RGenericVector terminalClassCounts = forest.getGenericElement("terminal.class.counts", false);

    Schema segmentSchema = schema.toAnonymousSchema();

    // The tree count does not change inside the loop, so resolve it once
    // instead of re-converting the scalar on every loop-condition check.
    int treeCount = ValueUtil.asInt(numTrees.asScalar());

    List<TreeModel> treeModels = new ArrayList<>();
    for (int i = 0; i < treeCount; i++) {
        TreeModel treeModel = encodeTreeModel(
            miningFunction,
            scoreEncoder,
            childNodeIDs.getGenericValue(i),
            splitVarIDs.getNumericValue(i),
            splitValues.getNumericValue(i),
            (terminalClassCounts != null ? terminalClassCounts.getGenericValue(i) : null),
            segmentSchema
        );

        treeModels.add(treeModel);
    }

    return treeModels;
}
Also used: TreeModel(org.dmg.pmml.tree.TreeModel) Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList)

Example 22 with TreeModel

Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.

In class RangerConverter, method encodeClassification.

/**
 * Encodes a ranger classification forest as a {@link MiningModel} whose segments are the
 * forest's trees, combined by majority vote.
 *
 * <p>Each leaf's split value is decoded as an index into the forest's {@code levels} vector,
 * and the corresponding level string becomes the leaf score.
 *
 * @param forest the R forest element
 * @param schema the model schema; its label is used for the mining schema
 * @return the majority-vote classification mining model
 * @throws IllegalArgumentException if a leaf carries terminal class counts, which this
 *         encoder does not support
 */
private MiningModel encodeClassification(RGenericVector forest, Schema schema) {
    RStringVector levels = forest.getStringElement("levels");

    ScoreEncoder scoreEncoder = new ScoreEncoder() {

        @Override
        public Node encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            // Validate the precondition before decoding anything from the leaf.
            if (terminalClassCount != null) {
                throw new IllegalArgumentException("Terminal class counts are not supported by the majority-vote classification encoder");
            }

            int index = ValueUtil.asInt(splitValue);

            // The "- 1" presumably converts a 1-based (R-style) level index to a 0-based one — confirm against ranger output.
            node.setScore(levels.getValue(index - 1));

            return node;
        }
    };

    List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

    MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels));

    return miningModel;
}
Also used: TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode)

Example 23 with TreeModel

Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.

In class AdaConverter, method encodeModel.

/**
 * Converts the wrapped {@code ada} boosting object into a binary logistic classification model.
 *
 * <p>The boosted trees become segments of a regression {@link MiningModel}. Their outputs are
 * combined as a weighted sum, using the {@code alpha} vector as segment weights, and exposed
 * as the continuous output field {@code adaValue}. That intermediate model is then passed to
 * {@link MiningModelUtil#createBinaryLogisticClassification} with the constants {@code 2d} and
 * {@code 0d} (presumably coefficient and intercept — confirm against MiningModelUtil) and
 * LOGIT normalization.
 *
 * @param schema the model schema
 * @return the binary logistic classification model
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector adaObject = getObject();

    RGenericVector fitted = adaObject.getGenericElement("model");
    RGenericVector treeList = fitted.getGenericElement("trees");
    RDoubleVector weights = fitted.getDoubleElement("alpha");

    List<TreeModel> segments = encodeTreeModels(treeList);

    // The inner regression model has no target field of its own.
    MiningModel boosted = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null));
    boosted.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, segments, weights.getValues()));
    boosted.setOutput(ModelUtil.createPredictedOutput("adaValue", OpType.CONTINUOUS, DataType.DOUBLE));

    return MiningModelUtil.createBinaryLogisticClassification(boosted, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, true, schema);
}
Also used: TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel)

Example 24 with TreeModel

Use of org.dmg.pmml.tree.TreeModel in project jpmml-r by jpmml.

In class BaggingConverter, method encodeModel.

/**
 * Converts the wrapped bagging object into a majority-vote classification {@link MiningModel}.
 *
 * <p>Every element of the object's {@code trees} list becomes one segment. The resulting model
 * declares a probability output for each category of the schema's label.
 *
 * @param schema the model schema; its label must be a {@link CategoricalLabel}
 * @return the classification mining model
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector baggingObject = getObject();
    RGenericVector treeList = baggingObject.getGenericElement("trees");

    // Bagging here is classification-only, so the label is expected to be categorical.
    CategoricalLabel label = (CategoricalLabel) schema.getLabel();

    List<TreeModel> segments = encodeTreeModels(treeList);

    MiningModel ensemble = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label));
    ensemble.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, segments));
    ensemble.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, label));

    return ensemble;
}
Also used: TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) CategoricalLabel(org.jpmml.converter.CategoricalLabel)

Example 25 with TreeModel

Use of org.dmg.pmml.tree.TreeModel in project jpmml-sparkml by jpmml.

In class RandomForestClassificationModelConverter, method encodeModel.

/**
 * Converts a Spark {@code RandomForestClassificationModel} into a classification
 * {@link MiningModel}.
 *
 * <p>The ensemble's decision trees are encoded as segments, and their results are combined
 * with the {@code AVERAGE} multiple-model method.
 *
 * @param schema the model schema; its label is used for the mining schema
 * @return the classification mining model
 */
@Override
public MiningModel encodeModel(Schema schema) {
    RandomForestClassificationModel forest = getTransformer();

    List<TreeModel> segments = TreeModelUtil.encodeDecisionTreeEnsemble(forest, schema);

    MiningModel ensemble = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()));
    ensemble.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, segments));

    return ensemble;
}
Also used: TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) RandomForestClassificationModel(org.apache.spark.ml.classification.RandomForestClassificationModel)

Aggregations

TreeModel (org.dmg.pmml.tree.TreeModel): 48; MiningModel (org.dmg.pmml.mining.MiningModel): 17; Node (org.dmg.pmml.tree.Node): 12; Test (org.junit.Test): 12; ArrayList (java.util.ArrayList): 11; BranchNode (org.dmg.pmml.tree.BranchNode): 9; LeafNode (org.dmg.pmml.tree.LeafNode): 9; Schema (org.jpmml.converter.Schema): 9; ClassifierNode (org.dmg.pmml.tree.ClassifierNode): 8; CategoricalLabel (org.jpmml.converter.CategoricalLabel): 8; KiePMMLTreeModel (org.kie.pmml.models.drools.tree.model.KiePMMLTreeModel): 8; KnowledgeBuilderImpl (org.drools.compiler.builder.impl.KnowledgeBuilderImpl): 6; HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock): 6; PMML (org.dmg.pmml.PMML): 5; HasKnowledgeBuilderMock (org.kie.pmml.models.drools.commons.implementations.HasKnowledgeBuilderMock): 5; KiePMMLTreeModel (org.kie.pmml.models.tree.model.KiePMMLTreeModel): 5; ConstructorDeclaration (com.github.javaparser.ast.body.ConstructorDeclaration): 4; Expression (com.github.javaparser.ast.expr.Expression): 4; NameExpr (com.github.javaparser.ast.expr.NameExpr): 4; StringLiteralExpr (com.github.javaparser.ast.expr.StringLiteralExpr): 4