Search in sources:

Example 6 with BranchNode

Use of org.dmg.pmml.tree.BranchNode in project jpmml-r by jpmml.

Class RandomForestConverter, method encodeNode.

/**
 * Recursively encodes the randomForest tree row at zero-based index {@code i} as a PMML {@link Node}.
 *
 * <p>The tree is supplied as parallel per-row lists. {@code bestvar}, {@code leftDaughter} and
 * {@code rightDaughter} hold one-based references, where {@code 0} means "none". A row whose
 * {@code bestvar} is {@code 0} is encoded as a {@link LeafNode}. Any other row is encoded as a
 * {@link BranchNode} whose children are encoded recursively.
 *
 * @param predicate the predicate guarding this node (built by the parent's split)
 * @param i zero-based row index; the emitted node id is {@code i + 1}
 * @param scoreEncoder converts a leaf's {@code nodepred} value into the leaf score
 * @param leftDaughter one-based row index of each row's left child, or 0 for none
 * @param rightDaughter one-based row index of each row's right child, or 0 for none
 * @param bestvar one-based index into {@code schema}'s features of each row's split feature, or 0 for a leaf
 * @param xbestsplit per-row split value: a threshold for continuous features, 0.5 for boolean
 *        features, and an encoded category split (interpreted by {@code selectValues}) for categorical features
 * @param nodepred per-row prediction (only read for leaf rows)
 * @param categoryManager category bookkeeping for the current path; forked per child with the
 *        category values routed to that child
 * @param schema feature schema that {@code bestvar} indexes into
 * @return the encoded node
 * @throws IllegalArgumentException if a boolean feature is split at a value other than 0.5
 */
private <P extends Number> Node encodeNode(Predicate predicate, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, CategoryManager categoryManager, Schema schema) {
    Integer id = Integer.valueOf(i + 1);

    int var = ValueUtil.asInt(bestvar.get(i));

    // A zero split variable marks a terminal row
    if (var == 0) {
        P prediction = nodepred.get(i);

        return new LeafNode(scoreEncoder.encode(prediction), predicate).setId(id);
    }

    // By default both children inherit the current category bookkeeping; only categorical splits fork it
    CategoryManager leftCategoryManager = categoryManager;
    CategoryManager rightCategoryManager = categoryManager;

    Predicate leftPredicate;
    Predicate rightPredicate;

    Feature feature = schema.getFeature(var - 1);

    Double split = xbestsplit.get(i);

    if (feature instanceof BooleanFeature) {
        BooleanFeature booleanFeature = (BooleanFeature) feature;

        // Presumably boolean features are encoded as 0/1 on the R side, so 0.5 is the only
        // threshold that cleanly separates the two values - TODO confirm against randomForest encoding
        if (split != 0.5d) {
            throw new IllegalArgumentException("Expected split value 0.5 for boolean feature (bestvar " + var + ", node " + id + "), got " + split);
        }

        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
    } else if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

        String name = categoricalFeature.getName();
        List<?> values = categoricalFeature.getValues();

        // Only values still reachable along this path are eligible for either side of the split
        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

        List<Object> leftValues = selectValues(values, valueFilter, split, true);
        List<Object> rightValues = selectValues(values, valueFilter, split, false);

        leftCategoryManager = categoryManager.fork(name, leftValues);
        rightCategoryManager = categoryManager.fork(name, rightValues);

        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        // Left branch takes values <= split, right branch takes values > split
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, split);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, split);
    }

    Node result = new BranchNode(null, predicate).setId(id);

    List<Node> nodes = result.getNodes();

    // Daughter references are one-based; 0 means the child is absent
    int left = ValueUtil.asInt(leftDaughter.get(i));
    if (left != 0) {
        Node leftChild = encodeNode(leftPredicate, left - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, leftCategoryManager, schema);

        nodes.add(leftChild);
    }

    int right = ValueUtil.asInt(rightDaughter.get(i));
    if (right != 0) {
        Node rightChild = encodeNode(rightPredicate, right - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, rightCategoryManager, schema);

        nodes.add(rightChild);
    }

    return result;
}
Also used: Node (org.dmg.pmml.tree.Node), BranchNode (org.dmg.pmml.tree.BranchNode), LeafNode (org.dmg.pmml.tree.LeafNode), ContinuousFeature (org.jpmml.converter.ContinuousFeature), Feature (org.jpmml.converter.Feature), BooleanFeature (org.jpmml.converter.BooleanFeature), CategoricalFeature (org.jpmml.converter.CategoricalFeature), Predicate (org.dmg.pmml.Predicate), SimplePredicate (org.dmg.pmml.SimplePredicate), ArrayList (java.util.ArrayList), List (java.util.List), CategoryManager (org.jpmml.converter.CategoryManager)

Example 7 with BranchNode

Use of org.dmg.pmml.tree.BranchNode in project jpmml-r by jpmml.

Class BinaryTreeConverter, method encodeNode.

/**
 * Recursively encodes one node of a party {@code BinaryTree} as a PMML {@link Node}.
 *
 * <p>{@code tree} is the R list for a single node. Its {@code nodeID}, {@code terminal},
 * {@code psplit}, {@code ssplits}, {@code prediction}, {@code left} and {@code right} elements
 * are read. {@code left} and {@code right} are themselves node lists and are encoded recursively.
 * Terminal nodes are delegated to {@code encodeScore}.
 *
 * @param tree R generic vector describing this node
 * @param predicate the predicate guarding this node (built by the parent's split)
 * @param schema feature schema; features are looked up via {@code this.featureIndexes}
 * @return the encoded node
 * @throws IllegalArgumentException if the node has a non-empty {@code ssplits} element, or if its
 *         split variable name is not mapped in {@code this.featureIndexes}
 */
private Node encodeNode(RGenericVector tree, Predicate predicate, Schema schema) {
    RIntegerVector nodeId = tree.getIntegerElement("nodeID");
    RBooleanVector terminal = tree.getBooleanElement("terminal");
    RGenericVector psplit = tree.getGenericElement("psplit");
    RGenericVector ssplits = tree.getGenericElement("ssplits");
    RDoubleVector prediction = tree.getDoubleElement("prediction");
    RGenericVector left = tree.getGenericElement("left");
    RGenericVector right = tree.getGenericElement("right");

    Integer id = nodeId.asScalar();

    // Null-safe check: only an explicit TRUE marks a leaf
    if ((Boolean.TRUE).equals(terminal.asScalar())) {
        Node result = new LeafNode(null, predicate).setId(id);

        return encodeScore(result, prediction, schema);
    }

    RNumberVector<?> splitpoint = psplit.getNumericElement("splitpoint");
    RStringVector variableName = psplit.getStringElement("variableName");

    // Presumably "ssplits" holds surrogate splits, which this converter cannot express - TODO confirm
    if (ssplits.size() > 0) {
        throw new IllegalArgumentException("Non-empty 'ssplits' element is not supported (node " + id + ")");
    }

    Predicate leftPredicate;
    Predicate rightPredicate;

    String name = variableName.asScalar();

    Integer index = this.featureIndexes.get(name);
    if (index == null) {
        throw new IllegalArgumentException("Split variable " + name + " is not mapped to a feature (node " + id + ")");
    }

    Feature feature = schema.getFeature(index);

    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

        List<?> values = categoricalFeature.getValues();

        // splitpoint values are consumed as integers by selectValues (presumably level indexes) - TODO confirm
        @SuppressWarnings("unchecked")
        List<Integer> splitValues = (List<Integer>) splitpoint.getValues();

        leftPredicate = createPredicate(categoricalFeature, selectValues(values, splitValues, true));
        rightPredicate = createPredicate(categoricalFeature, selectValues(values, splitValues, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        Number value = splitpoint.asScalar();

        // Left branch takes values <= splitpoint, right branch takes values > splitpoint
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }

    Node leftChild = encodeNode(left, leftPredicate, schema);
    Node rightChild = encodeNode(right, rightPredicate, schema);

    return new BranchNode(null, predicate).setId(id).addNodes(leftChild, rightChild);
}
Also used: Node (org.dmg.pmml.tree.Node), ClassifierNode (org.dmg.pmml.tree.ClassifierNode), BranchNode (org.dmg.pmml.tree.BranchNode), LeafNode (org.dmg.pmml.tree.LeafNode), ContinuousFeature (org.jpmml.converter.ContinuousFeature), Feature (org.jpmml.converter.Feature), CategoricalFeature (org.jpmml.converter.CategoricalFeature), Predicate (org.dmg.pmml.Predicate), SimplePredicate (org.dmg.pmml.SimplePredicate), ArrayList (java.util.ArrayList), List (java.util.List)

Aggregations

Predicate (org.dmg.pmml.Predicate): 7, SimplePredicate (org.dmg.pmml.SimplePredicate): 7, BranchNode (org.dmg.pmml.tree.BranchNode): 7, LeafNode (org.dmg.pmml.tree.LeafNode): 7, Node (org.dmg.pmml.tree.Node): 7, ContinuousFeature (org.jpmml.converter.ContinuousFeature): 7, CategoricalFeature (org.jpmml.converter.CategoricalFeature): 6, Feature (org.jpmml.converter.Feature): 6, ArrayList (java.util.ArrayList): 4, List (java.util.List): 4, ClassifierNode (org.dmg.pmml.tree.ClassifierNode): 4, CategoryManager (org.jpmml.converter.CategoryManager): 4, BooleanFeature (org.jpmml.converter.BooleanFeature): 2, CategoricalSplit (org.apache.spark.ml.tree.CategoricalSplit): 1, ContinuousSplit (org.apache.spark.ml.tree.ContinuousSplit): 1, Split (org.apache.spark.ml.tree.Split): 1, DataType (org.dmg.pmml.DataType): 1, FieldName (org.dmg.pmml.FieldName): 1, ScoreDistribution (org.dmg.pmml.ScoreDistribution): 1, BinaryFeature (org.jpmml.converter.BinaryFeature): 1