Search in sources:

Example 6 with CategoryManager

Use of org.jpmml.converter.CategoryManager in project jpmml-r (by jpmml).

Class RandomForestConverter, method encodeTreeModel.

/**
 * Encodes a single randomForest tree as a PMML {@code TreeModel}.
 *
 * <p>The tree is described by parallel, node-indexed lists; {@code encodeNode} walks them
 * recursively starting from node index 0 (the root). The resulting model uses the
 * {@code NULL_PREDICTION} missing value strategy and declares binary splits. When
 * {@code this.compact} is set, the model is post-processed with a
 * {@code RandomForestCompactor} visitor.
 *
 * @param miningFunction the PMML mining function of the resulting model
 * @param scoreEncoder encodes a leaf node's prediction into the PMML leaf score
 * @param leftDaughter per-node 1-based index of the left child; 0 means "no left child"
 * @param rightDaughter per-node 1-based index of the right child; 0 means "no right child"
 * @param nodepred per-node prediction, used for leaf nodes
 * @param bestvar per-node 1-based index of the split feature in {@code schema}; 0 marks a leaf
 * @param xbestsplit per-node split value
 * @param schema the feature schema the split variables refer to
 * @return the encoded (and optionally compacted) tree model
 */
private <P extends Number> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<P> nodepred, List<? extends Number> bestvar, List<Double> xbestsplit, Schema schema) {
    // Start from the root (node index 0) with a fresh, unrestricted CategoryManager.
    Node root = encodeNode(True.INSTANCE, 0, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, new CategoryManager(), schema);
    TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root).setMissingValueStrategy(TreeModel.MissingValueStrategy.NULL_PREDICTION).setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
    if (this.compact) {
        Visitor visitor = new RandomForestCompactor();
        visitor.applyTo(treeModel);
    }
    return treeModel;
}
Also used: TreeModel(org.dmg.pmml.tree.TreeModel) Visitor(org.dmg.pmml.Visitor) Node(org.dmg.pmml.tree.Node) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) CategoryManager(org.jpmml.converter.CategoryManager) RandomForestCompactor(org.jpmml.rexp.visitors.RandomForestCompactor)

Example 7 with CategoryManager

Use of org.jpmml.converter.CategoryManager in project jpmml-r (by jpmml).

Class RandomForestConverter, method encodeNode.

/**
 * Recursively encodes node {@code i} of a randomForest tree, together with all nodes below it,
 * as PMML nodes.
 *
 * <p>A node whose {@code bestvar} entry is 0 becomes a {@code LeafNode} scored from
 * {@code nodepred}. Any other node becomes a {@code BranchNode}, whose children get
 * complementary predicates derived from the split feature:
 * <ul>
 *   <li>boolean feature: the split value must be 0.5; the left child tests for the feature's
 *       first value and the right child tests for its second value;</li>
 *   <li>categorical feature: the feature's values are partitioned into left/right sets by
 *       {@code selectValues} (filtered through the current {@code CategoryManager}), and each
 *       child receives a {@code CategoryManager} forked with its own value set;</li>
 *   <li>any other feature: converted to a continuous feature; left is {@code <= split},
 *       right is {@code > split}.</li>
 * </ul>
 *
 * @param predicate the predicate guarding this node
 * @param i 0-based node index; the emitted PMML node id is {@code i + 1}
 * @param scoreEncoder encodes a leaf prediction into the PMML leaf score
 * @param leftDaughter per-node 1-based index of the left child; 0 means "no left child"
 * @param rightDaughter per-node 1-based index of the right child; 0 means "no right child"
 * @param bestvar per-node 1-based index of the split feature in {@code schema}; 0 marks a leaf
 * @param xbestsplit per-node split value
 * @param nodepred per-node prediction, used for leaf nodes
 * @param categoryManager tracks the categorical values still in play along the current path
 * @param schema the feature schema the split variables refer to
 * @return the encoded node
 * @throws IllegalArgumentException if a boolean feature is split at a value other than 0.5
 */
private <P extends Number> Node encodeNode(Predicate predicate, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, CategoryManager categoryManager, Schema schema) {
    Integer id = Integer.valueOf(i + 1);
    int var = ValueUtil.asInt(bestvar.get(i));
    // A zero split variable marks a terminal node.
    if (var == 0) {
        P prediction = nodepred.get(i);
        Node result = new LeafNode(scoreEncoder.encode(prediction), predicate).setId(id);
        return result;
    }
    // Both children inherit the parent's category state unless a categorical split narrows it.
    CategoryManager leftCategoryManager = categoryManager;
    CategoryManager rightCategoryManager = categoryManager;
    Predicate leftPredicate;
    Predicate rightPredicate;
    // bestvar is 1-based, schema features are 0-based.
    Feature feature = schema.getFeature(var - 1);
    Double split = xbestsplit.get(i);
    if (feature instanceof BooleanFeature) {
        BooleanFeature booleanFeature = (BooleanFeature) feature;
        // A boolean feature can only be split between its two values.
        if (split != 0.5d) {
            throw new IllegalArgumentException("Boolean feature " + booleanFeature.getName() + " has split value " + split + ", expected 0.5");
        }
        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
    } else if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        String name = categoricalFeature.getName();
        List<?> values = categoricalFeature.getValues();
        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
        List<Object> leftValues = selectValues(values, valueFilter, split, true);
        List<Object> rightValues = selectValues(values, valueFilter, split, false);
        // Each child only carries forward the categories routed to its side.
        leftCategoryManager = categoryManager.fork(name, leftValues);
        rightCategoryManager = categoryManager.fork(name, rightValues);
        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, split);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, split);
    }
    Node result = new BranchNode(null, predicate).setId(id);
    List<Node> nodes = result.getNodes();
    // Daughter indices are 1-based; 0 means the child is absent.
    int left = ValueUtil.asInt(leftDaughter.get(i));
    if (left != 0) {
        Node leftChild = encodeNode(leftPredicate, left - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, leftCategoryManager, schema);
        nodes.add(leftChild);
    }
    int right = ValueUtil.asInt(rightDaughter.get(i));
    if (right != 0) {
        Node rightChild = encodeNode(rightPredicate, right - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, rightCategoryManager, schema);
        nodes.add(rightChild);
    }
    return result;
}
Also used: Node(org.dmg.pmml.tree.Node) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ArrayList(java.util.ArrayList) List(java.util.List) CategoryManager(org.jpmml.converter.CategoryManager)

Example 8 with CategoryManager

Use of org.jpmml.converter.CategoryManager in project jpmml-sparkml (by jpmml).

Class TreeModelUtil, method encodeTreeModel.

/**
 * Converts a Spark decision tree into a PMML {@code TreeModel} that declares binary splits.
 *
 * <p>The Spark root node is encoded recursively by {@code encodeNode}, starting under a
 * {@code True} predicate and a fresh {@code CategoryManager}.
 *
 * @param miningFunction the PMML mining function of the resulting model
 * @param scoreEncoder encodes Spark leaf predictions into PMML scores
 * @param model the Spark decision tree model to convert
 * @param predicateManager shared factory/cache for the predicates created while encoding
 * @param schema the feature schema the tree's splits refer to
 * @return the encoded tree model
 */
private static <M extends Model<M> & DecisionTreeModel> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder scoreEncoder, M model, PredicateManager predicateManager, Schema schema) {
    CategoryManager initialCategories = new CategoryManager();
    Node rootNode = encodeNode(True.INSTANCE, scoreEncoder, model.rootNode(), predicateManager, initialCategories, schema);

    return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Also used: TreeModel(org.dmg.pmml.tree.TreeModel) DecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) CategoryManager(org.jpmml.converter.CategoryManager)

Aggregations

BranchNode (org.dmg.pmml.tree.BranchNode)8 LeafNode (org.dmg.pmml.tree.LeafNode)8 Node (org.dmg.pmml.tree.Node)8 CategoryManager (org.jpmml.converter.CategoryManager)8 Predicate (org.dmg.pmml.Predicate)4 SimplePredicate (org.dmg.pmml.SimplePredicate)4 ClassifierNode (org.dmg.pmml.tree.ClassifierNode)4 TreeModel (org.dmg.pmml.tree.TreeModel)4 CategoricalFeature (org.jpmml.converter.CategoricalFeature)4 ContinuousFeature (org.jpmml.converter.ContinuousFeature)4 Feature (org.jpmml.converter.Feature)4 ArrayList (java.util.ArrayList)2 List (java.util.List)2 BooleanFeature (org.jpmml.converter.BooleanFeature)2 FlagManager (org.jpmml.converter.FlagManager)2 CategoricalSplit (org.apache.spark.ml.tree.CategoricalSplit)1 ContinuousSplit (org.apache.spark.ml.tree.ContinuousSplit)1 DecisionTreeModel (org.apache.spark.ml.tree.DecisionTreeModel)1 Split (org.apache.spark.ml.tree.Split)1 DataType (org.dmg.pmml.DataType)1