Search in sources:

Example 6 with LeafNode

use of org.dmg.pmml.tree.LeafNode in project drools by kiegroup.

In class KiePMMLTreeModelNodeASTFactoryTest, method isFinalLeaf.

/**
 * Checks {@code isFinalLeaf} against three node shapes: a fresh {@link LeafNode},
 * a fresh {@link ClassifierNode} with no children, and the same {@link ClassifierNode}
 * after a child has been added. The first two are expected to be reported as final
 * leaves, the last one is not.
 */
@Test
public void isFinalLeaf() {
    Node node = new LeafNode();
    DATA_TYPE targetType = DATA_TYPE.STRING;
    // A newly created LeafNode must be reported as a final leaf.
    // A fresh factory is built for every assertion so that no state is shared between checks.
    assertTrue(KiePMMLTreeModelNodeASTFactory.factory(new HashMap<>(), Collections.emptyList(), TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION, targetType).isFinalLeaf(node));
    // A ClassifierNode to which no child has been added is also a final leaf.
    node = new ClassifierNode();
    assertTrue(KiePMMLTreeModelNodeASTFactory.factory(new HashMap<>(), Collections.emptyList(), TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION, targetType).isFinalLeaf(node));
    // Once a child node is attached, the same node must no longer be a final leaf.
    node.addNodes(new LeafNode());
    assertFalse(KiePMMLTreeModelNodeASTFactory.factory(new HashMap<>(), Collections.emptyList(), TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION, targetType).isFinalLeaf(node));
}
Also used : HashMap(java.util.HashMap) LeafNode(org.dmg.pmml.tree.LeafNode) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) LeafNode(org.dmg.pmml.tree.LeafNode) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) DATA_TYPE(org.kie.pmml.api.enums.DATA_TYPE) Test(org.junit.Test)

Example 7 with LeafNode

use of org.dmg.pmml.tree.LeafNode in project jpmml-r by jpmml.

In class RandomForestConverter, method encodeNode.

/**
 * Recursively builds the PMML node for position {@code i} of the parallel tree lists.
 *
 * <p>The node id is {@code i + 1}. An entry of {@code 0} in {@code bestvar} yields a
 * {@link LeafNode} scored with the encoded {@code nodepred} value; otherwise a
 * {@link BranchNode} is created and the children referenced by {@code leftDaughter} /
 * {@code rightDaughter} (values are offset by one; {@code 0} means "no child") are
 * encoded recursively with the predicates derived from the split feature.
 *
 * @param predicate the predicate attached to the node being created
 * @param i zero-based position of the node in the parallel lists
 * @param scoreEncoder converts a raw node prediction into a PMML score
 * @param leftDaughter per-node reference to the left child ({@code 0} if absent)
 * @param rightDaughter per-node reference to the right child ({@code 0} if absent)
 * @param bestvar per-node split variable reference ({@code 0} marks a leaf)
 * @param xbestsplit per-node split value
 * @param nodepred per-node prediction
 * @param categoryManager tracks which category values are still reachable on this path
 * @param schema supplies the features referenced by {@code bestvar}
 * @return the encoded node, including all of its descendants
 * @throws IllegalArgumentException if a boolean feature is split at a value other than {@code 0.5}
 */
private <P extends Number> Node encodeNode(Predicate predicate, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, CategoryManager categoryManager, Schema schema) {
    Integer id = Integer.valueOf(i + 1);
    int splitVar = ValueUtil.asInt(bestvar.get(i));
    // No split variable: this position is a terminal node.
    if (splitVar == 0) {
        P leafPrediction = nodepred.get(i);
        return new LeafNode(scoreEncoder.encode(leafPrediction), predicate).setId(id);
    }
    // Unless the split is categorical, both children inherit the current category state.
    CategoryManager leftCategoryManager = categoryManager;
    CategoryManager rightCategoryManager = categoryManager;
    Predicate leftPredicate;
    Predicate rightPredicate;
    Feature feature = schema.getFeature(splitVar - 1);
    Double split = xbestsplit.get(i);
    // The BooleanFeature test must stay ahead of the CategoricalFeature test.
    if (feature instanceof BooleanFeature) {
        BooleanFeature booleanFeature = (BooleanFeature) feature;
        if (split.doubleValue() != 0.5d) {
            throw new IllegalArgumentException();
        }
        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
    } else if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        String featureName = categoricalFeature.getName();
        List<?> allValues = categoricalFeature.getValues();
        java.util.function.Predicate<Object> reachable = categoryManager.getValueFilter(featureName);
        List<Object> leftValues = selectValues(allValues, reachable, split, true);
        List<Object> rightValues = selectValues(allValues, reachable, split, false);
        // Each side continues with only the category values routed to it.
        leftCategoryManager = categoryManager.fork(featureName, leftValues);
        rightCategoryManager = categoryManager.fork(featureName, rightValues);
        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, split);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, split);
    }
    Node branch = new BranchNode(null, predicate).setId(id);
    // Fetch the child list up front, exactly once, before any recursion.
    List<Node> children = branch.getNodes();
    int leftRef = ValueUtil.asInt(leftDaughter.get(i));
    if (leftRef != 0) {
        children.add(encodeNode(leftPredicate, leftRef - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, leftCategoryManager, schema));
    }
    int rightRef = ValueUtil.asInt(rightDaughter.get(i));
    if (rightRef != 0) {
        children.add(encodeNode(rightPredicate, rightRef - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, rightCategoryManager, schema));
    }
    return branch;
}
Also used : Node(org.dmg.pmml.tree.Node) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) BranchNode(org.dmg.pmml.tree.BranchNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) LeafNode(org.dmg.pmml.tree.LeafNode) ArrayList(java.util.ArrayList) List(java.util.List) CategoryManager(org.jpmml.converter.CategoryManager)

Example 8 with LeafNode

use of org.dmg.pmml.tree.LeafNode in project jpmml-r by jpmml.

In class BinaryTreeConverter, method encodeNode.

/**
 * Recursively converts one node of the R tree structure into a PMML {@link Node}.
 *
 * <p>A node whose {@code terminal} flag is {@code TRUE} becomes a {@link LeafNode} that is
 * scored via {@code encodeScore}. Any other node becomes a {@link BranchNode} whose two
 * children are built from the {@code left} and {@code right} elements, using predicates
 * derived from the primary split ({@code psplit}).
 *
 * @param tree the R list describing the current node
 * @param predicate the predicate attached to the node being created
 * @param schema supplies the feature referenced by the split variable
 * @return the encoded node, including all of its descendants
 * @throws IllegalArgumentException if the node carries a non-empty {@code ssplits} element,
 *         or if the split variable name has no entry in {@code featureIndexes}
 */
private Node encodeNode(RGenericVector tree, Predicate predicate, Schema schema) {
    RIntegerVector nodeId = tree.getIntegerElement("nodeID");
    RBooleanVector terminal = tree.getBooleanElement("terminal");
    RGenericVector psplit = tree.getGenericElement("psplit");
    RGenericVector ssplits = tree.getGenericElement("ssplits");
    RDoubleVector prediction = tree.getDoubleElement("prediction");
    RGenericVector left = tree.getGenericElement("left");
    RGenericVector right = tree.getGenericElement("right");
    Integer id = nodeId.asScalar();
    // Terminal nodes carry no split; their score is filled in by encodeScore.
    if ((Boolean.TRUE).equals(terminal.asScalar())) {
        Node result = new LeafNode(null, predicate).setId(id);
        return encodeScore(result, prediction, schema);
    }
    RNumberVector<?> splitpoint = psplit.getNumericElement("splitpoint");
    RStringVector variableName = psplit.getStringElement("variableName");
    // Only the primary split is translated; a populated "ssplits" element is rejected.
    if (ssplits.size() > 0) {
        throw new IllegalArgumentException("Expected an empty 'ssplits' element in node " + id + ", but it has " + ssplits.size() + " element(s)");
    }
    Predicate leftPredicate;
    Predicate rightPredicate;
    String name = variableName.asScalar();
    Integer index = this.featureIndexes.get(name);
    if (index == null) {
        throw new IllegalArgumentException("No feature index is registered for split variable '" + name + "' in node " + id);
    }
    Feature feature = schema.getFeature(index);
    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<?> values = categoricalFeature.getValues();
        // NOTE(review): assumes the split point of a categorical split holds integers - confirm.
        List<Integer> splitValues = (List<Integer>) splitpoint.getValues();
        leftPredicate = createPredicate(categoricalFeature, selectValues(values, splitValues, true));
        rightPredicate = createPredicate(categoricalFeature, selectValues(values, splitValues, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        Number value = splitpoint.asScalar();
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }
    Node leftChild = encodeNode(left, leftPredicate, schema);
    Node rightChild = encodeNode(right, rightPredicate, schema);
    Node result = new BranchNode(null, predicate).setId(id).addNodes(leftChild, rightChild);
    return result;
}
Also used : Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) BranchNode(org.dmg.pmml.tree.BranchNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) LeafNode(org.dmg.pmml.tree.LeafNode) ArrayList(java.util.ArrayList) List(java.util.List)

Example 9 with LeafNode

use of org.dmg.pmml.tree.LeafNode in project jpmml-sparkml by jpmml.

In class TreeModelUtil, method encodeDecisionTree.

/**
 * Encodes a Spark decision tree model as a PMML {@link TreeModel}.
 *
 * <p>A {@link DecisionTreeRegressionModel} is encoded with leaf scores taken from the Spark
 * leaf prediction. A {@link DecisionTreeClassificationModel} is encoded with
 * {@link ClassifierNode} leaves that carry the predicted label value, the record count and
 * the score distributions taken from the leaf's impurity statistics. When the
 * {@code HasTreeOptions.OPTION_COMPACT} option resolves to {@code true} (the default passed
 * here is {@code Boolean.TRUE}), a {@link TreeModelCompactor} is applied to the result.
 *
 * @param converter source of the conversion options
 * @param model the Spark model to encode
 * @param predicateManager passed through to {@code encodeTreeModel}
 * @param scoreDistributionManager used to attach score distributions to classification leaves
 * @param schema the schema; its label must be a {@link CategoricalLabel} for classification
 * @return the encoded tree model
 * @throws IllegalArgumentException if {@code model} is neither a regression nor a classification tree
 */
private static <M extends Model<M> & DecisionTreeModel> TreeModel encodeDecisionTree(ModelConverter<?> converter, M model, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
    MiningFunction miningFunction;
    ScoreEncoder scoreEncoder;
    if (model instanceof DecisionTreeRegressionModel) {
        miningFunction = MiningFunction.REGRESSION;
        scoreEncoder = new ScoreEncoder() {

            @Override
            public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode) {
                // Regression leaves keep the incoming node and only receive a score.
                node.setScore(leafNode.prediction());
                return node;
            }
        };
    } else if (model instanceof DecisionTreeClassificationModel) {
        miningFunction = MiningFunction.CLASSIFICATION;
        // Resolved once, before the encoder is created, and shared by every leaf.
        final CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
        scoreEncoder = new ScoreEncoder() {

            @Override
            public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode) {
                // Classification leaves are replaced by a ClassifierNode that reuses the predicate.
                Node classifierNode = new ClassifierNode(null, node.getPredicate());
                int labelIndex = ValueUtil.asInt(leafNode.prediction());
                classifierNode.setScore(categoricalLabel.getValue(labelIndex));
                ImpurityCalculator impurityStats = leafNode.impurityStats();
                classifierNode.setRecordCount(ValueUtil.narrow(impurityStats.count()));
                scoreDistributionManager.addScoreDistributions(classifierNode, categoricalLabel.getValues(), impurityStats.stats());
                return classifierNode;
            }
        };
    } else {
        throw new IllegalArgumentException();
    }
    TreeModel treeModel = encodeTreeModel(miningFunction, scoreEncoder, model, predicateManager, schema);
    Boolean compact = (Boolean) converter.getOption(HasTreeOptions.OPTION_COMPACT, Boolean.TRUE);
    // Null-safe check: a null option value leaves the tree uncompacted.
    if (Boolean.TRUE.equals(compact)) {
        Visitor compactor = new TreeModelCompactor();
        compactor.applyTo(treeModel);
    }
    return treeModel;
}
Also used : DecisionTreeRegressionModel(org.apache.spark.ml.regression.DecisionTreeRegressionModel) Visitor(org.dmg.pmml.Visitor) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) DecisionTreeClassificationModel(org.apache.spark.ml.classification.DecisionTreeClassificationModel) TreeModelCompactor(org.jpmml.sparkml.visitors.TreeModelCompactor) TreeModel(org.dmg.pmml.tree.TreeModel) DecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel) ImpurityCalculator(org.apache.spark.mllib.tree.impurity.ImpurityCalculator) CategoricalLabel(org.jpmml.converter.CategoricalLabel) LeafNode(org.dmg.pmml.tree.LeafNode) ClassifierNode(org.dmg.pmml.tree.ClassifierNode)

Aggregations

LeafNode (org.dmg.pmml.tree.LeafNode) 9, Node (org.dmg.pmml.tree.Node) 9, BranchNode (org.dmg.pmml.tree.BranchNode) 8, Predicate (org.dmg.pmml.Predicate) 7, SimplePredicate (org.dmg.pmml.SimplePredicate) 7, ContinuousFeature (org.jpmml.converter.ContinuousFeature) 7, ClassifierNode (org.dmg.pmml.tree.ClassifierNode) 6, CategoricalFeature (org.jpmml.converter.CategoricalFeature) 6, Feature (org.jpmml.converter.Feature) 6, ArrayList (java.util.ArrayList) 4, List (java.util.List) 4, CategoryManager (org.jpmml.converter.CategoryManager) 4, BooleanFeature (org.jpmml.converter.BooleanFeature) 2, CategoricalLabel (org.jpmml.converter.CategoricalLabel) 2, HashMap (java.util.HashMap) 1, DecisionTreeClassificationModel (org.apache.spark.ml.classification.DecisionTreeClassificationModel) 1, DecisionTreeRegressionModel (org.apache.spark.ml.regression.DecisionTreeRegressionModel) 1, CategoricalSplit (org.apache.spark.ml.tree.CategoricalSplit) 1, ContinuousSplit (org.apache.spark.ml.tree.ContinuousSplit) 1, DecisionTreeModel (org.apache.spark.ml.tree.DecisionTreeModel) 1