Search in sources :

Example 6 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

Method encodeNode of class RangerConverter.

/**
 * Recursively populates {@code node} from the ranger tree entry at position {@code index}.
 *
 * <p>An entry whose left and right child IDs are both 0 is treated as terminal and is handed to
 * {@code scoreEncoder}. Any other entry becomes a binary split with two new child nodes.
 * - A categorical split sends the first {@code floor(splitValue)} levels left and the remaining levels right.
 * - A continuous split sends values {@code <= splitValue} left and {@code > splitValue} right.
 * Both children are then encoded recursively, left before right.
 *
 * @param node the PMML node to fill in
 * @param index position of the current entry in the parallel ranger vectors
 * @param scoreEncoder receives terminal entries together with their split value and class counts
 * @param leftChildIDs left-child ID per entry
 * @param rightChildIDs right-child ID per entry
 * @param splitVarIDs split variable ID per entry; one is subtracted to index into {@code schema}
 * @param splitValues split value per entry (passed to {@code scoreEncoder} for terminal entries)
 * @param terminalClassCounts per-entry class count vectors, or {@code null} when absent
 * @param schema feature schema used to resolve split variables
 */
private void encodeNode(Node node, int index, ScoreEncoder scoreEncoder, RNumberVector<?> leftChildIDs, RNumberVector<?> rightChildIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, Schema schema) {
    int leftId = ValueUtil.asInt(leftChildIDs.getValue(index));
    int rightId = ValueUtil.asInt(rightChildIDs.getValue(index));
    Number splitValue = splitValues.getValue(index);

    RNumberVector<?> classCounts = null;
    if (terminalClassCounts != null) {
        classCounts = (RNumberVector<?>) terminalClassCounts.getValue(index);
    }

    // No child on either side: this entry is a leaf.
    boolean terminal = (leftId == 0 && rightId == 0);
    if (terminal) {
        scoreEncoder.encode(node, splitValue, classCounts);
        return;
    }

    // Split variable IDs are converted to a zero-based schema index by subtracting one.
    int featureIndex = ValueUtil.asInt(splitVarIDs.getValue(index)) - 1;
    Feature splitFeature = schema.getFeature(featureIndex);

    Predicate toLeft;
    Predicate toRight;
    if (!(splitFeature instanceof CategoricalFeature)) {
        ContinuousFeature numeric = splitFeature.toContinuousFeature();
        String threshold = ValueUtil.formatValue(splitValue);
        toLeft = createSimplePredicate(numeric, SimplePredicate.Operator.LESS_OR_EQUAL, threshold);
        toRight = createSimplePredicate(numeric, SimplePredicate.Operator.GREATER_THAN, threshold);
    } else {
        CategoricalFeature categorical = (CategoricalFeature) splitFeature;
        // The first floor(splitValue) levels go left; all remaining levels go right.
        int cut = ValueUtil.asInt(Math.floor(splitValue.doubleValue()));
        List<String> levels = categorical.getValues();
        toLeft = createSimpleSetPredicate(categorical, levels.subList(0, cut));
        toRight = createSimpleSetPredicate(categorical, levels.subList(cut, levels.size()));
    }

    Node leftNode = new Node().setPredicate(toLeft);
    encodeNode(leftNode, leftId, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, schema);

    Node rightNode = new Node().setPredicate(toRight);
    encodeNode(rightNode, rightId, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, schema);

    node.addNodes(leftNode, rightNode);
}
Also used: ContinuousFeature(org.jpmml.converter.ContinuousFeature) Node(org.dmg.pmml.tree.Node) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate)

Example 7 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

Method encodeTreeModel of class RangerConverter.

/**
 * Builds a binary-split PMML {@link TreeModel} from one ranger tree.
 *
 * @param miningFunction mining function of the resulting model
 * @param scoreEncoder encodes terminal nodes
 * @param childNodeIDs two child-ID vectors: position 0 holds left children, position 1 right children
 * @param splitVarIDs split variable ID per tree entry
 * @param splitValues split value per tree entry
 * @param terminalClassCounts per-entry class counts, or {@code null}
 * @param schema feature and label schema
 * @return the encoded tree model
 */
private TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder scoreEncoder, RGenericVector childNodeIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, Schema schema) {
    RNumberVector<?> lefts = (RNumberVector<?>) childNodeIDs.getValue(0);
    RNumberVector<?> rights = (RNumberVector<?>) childNodeIDs.getValue(1);

    // Encoding starts at entry 0; the root accepts every record.
    Node rootNode = new Node().setPredicate(new True());
    encodeNode(rootNode, 0, scoreEncoder, lefts, rights, splitVarIDs, splitValues, terminalClassCounts, schema);

    return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) Node(org.dmg.pmml.tree.Node) True(org.dmg.pmml.True)

Example 8 with Node

use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.

Method encodeNode of class TreeModelUtil.

/**
 * Recursively converts a Spark ML decision tree node into a PMML {@link Node}.
 *
 * <p>This method does not set a predicate on the node it returns; for an internal node it creates
 * both children and sets their predicates, so the caller is responsible only for the predicate of
 * the returned node itself.
 *
 * @param node the Spark node to convert; must be an {@link InternalNode} or a {@link LeafNode}
 * @param predicateManager creates the split predicates
 * @param parentFieldValues for categorical fields already split on along the current path, the
 *        values still reachable; used to drop unreachable categories from child predicates. This
 *        map is not modified; children receive copies when a categorical split narrows it.
 * @param miningFunction {@code REGRESSION} or {@code CLASSIFICATION}
 * @param schema schema against which split feature indices and the label are resolved
 * @return the encoded PMML node
 * @throws IllegalArgumentException if the node, split or feature type is unsupported, or the split
 *         is inconsistent with its feature
 * @throws UnsupportedOperationException if {@code miningFunction} is not supported
 */
public static Node encodeNode(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentFieldValues, MiningFunction miningFunction, Schema schema) {
    if (node instanceof InternalNode) {
        InternalNode internalNode = (InternalNode) node;
        // Only categorical splits narrow these maps; all other splits pass the parent's map through.
        Map<FieldName, Set<String>> leftFieldValues = parentFieldValues;
        Map<FieldName, Set<String>> rightFieldValues = parentFieldValues;
        Predicate leftPredicate;
        Predicate rightPredicate;
        Split split = internalNode.split();
        Feature feature = schema.getFeature(split.featureIndex());
        if (split instanceof ContinuousSplit) {
            ContinuousSplit continuousSplit = (ContinuousSplit) split;
            double threshold = continuousSplit.threshold();
            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;
                // Only a threshold of exactly 0.5 is accepted for a boolean feature.
                if (threshold != 0.5d) {
                    throw new IllegalArgumentException("Expected a split threshold of 0.5 for boolean feature " + feature + ", got " + threshold);
                }
                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                String value = ValueUtil.formatValue(threshold);
                // Values <= threshold go left, values > threshold go right.
                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;
            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;
                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;
                // A binary feature splits on its single indicator value; the side holding TRUE tests EQUAL.
                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Unexpected split categories for binary feature " + feature + ": left=" + Arrays.toString(leftCategories) + ", right=" + Arrays.toString(rightCategories));
                }
                String value = ValueUtil.formatValue(binaryFeature.getValue());
                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
                FieldName name = categoricalFeature.getName();
                List<String> values = categoricalFeature.getValues();
                int splitCategoryCount = leftCategories.length + rightCategories.length;
                if (values.size() != splitCategoryCount) {
                    throw new IllegalArgumentException("Expected " + values.size() + " categories for feature " + name + ", but the split covers " + splitCategoryCount);
                }
                // Keep only values still reachable along the current path (all values if this field was not split on yet).
                final Set<String> parentValues = parentFieldValues.get(name);
                com.google.common.base.Predicate<String> valueFilter = new com.google.common.base.Predicate<String>() {

                    @Override
                    public boolean apply(String value) {
                        if (parentValues != null) {
                            return parentValues.contains(value);
                        }
                        return true;
                    }
                };
                List<String> leftValues = selectValues(values, leftCategories, valueFilter);
                List<String> rightValues = selectValues(values, rightCategories, valueFilter);
                // Copy before narrowing so the parent's map (shared with sibling subtrees) stays intact.
                leftFieldValues = new HashMap<>(parentFieldValues);
                leftFieldValues.put(name, new HashSet<>(leftValues));
                rightFieldValues = new HashMap<>(parentFieldValues);
                rightFieldValues.put(name, new HashSet<>(rightValues));
                leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftValues);
                rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException("Unsupported feature for a categorical split: " + feature);
            }
        } else {
            throw new IllegalArgumentException("Unsupported split type " + split.getClass().getName());
        }
        Node result = new Node();
        Node leftChild = encodeNode(internalNode.leftChild(), predicateManager, leftFieldValues, miningFunction, schema).setPredicate(leftPredicate);
        Node rightChild = encodeNode(internalNode.rightChild(), predicateManager, rightFieldValues, miningFunction, schema).setPredicate(rightPredicate);
        result.addNodes(leftChild, rightChild);
        return result;
    } else if (node instanceof LeafNode) {
        Node result = new Node();
        switch(miningFunction) {
            case REGRESSION:
                {
                    String score = ValueUtil.formatValue(node.prediction());
                    result.setScore(score);
                }
                break;
            case CLASSIFICATION:
                {
                    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
                    // The prediction is used as an index into the label's values.
                    int index = ValueUtil.asInt(node.prediction());
                    result.setScore(categoricalLabel.getValue(index));
                    ImpurityCalculator impurityCalculator = node.impurityStats();
                    result.setRecordCount((double) impurityCalculator.count());
                    // NOTE(review): assumes stats[i] is the (weighted) count for label value i — confirm for the impurity in use.
                    double[] stats = impurityCalculator.stats();
                    for (int i = 0; i < stats.length; i++) {
                        ScoreDistribution scoreDistribution = new ScoreDistribution(categoricalLabel.getValue(i), stats[i]);
                        result.addScoreDistributions(scoreDistribution);
                    }
                }
                break;
            default:
                throw new UnsupportedOperationException("Unsupported mining function " + miningFunction);
        }
        return result;
    } else {
        throw new IllegalArgumentException("Unsupported node type " + (node != null ? node.getClass().getName() : null));
    }
}
Also used: HashSet(java.util.HashSet) Set(java.util.Set) InternalNode(org.apache.spark.ml.tree.InternalNode) Node(org.dmg.pmml.tree.Node) LeafNode(org.apache.spark.ml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) FieldName(org.dmg.pmml.FieldName) ScoreDistribution(org.dmg.pmml.ScoreDistribution) ImpurityCalculator(org.apache.spark.mllib.tree.impurity.ImpurityCalculator) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Split(org.apache.spark.ml.tree.Split)

Example 9 with Node

use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.

Method getParentNode of class TreeModelCompactor.

/**
 * Returns the first element of {@code getParents()} when it is a {@link Node}.
 *
 * @return that element cast to {@link Node}, or {@code null} if the deque is empty or its first
 *         element is some other PMML object
 */
private Node getParentNode() {
    PMMLObject candidate = getParents().peekFirst();
    return (candidate instanceof Node) ? (Node) candidate : null;
}
Also used : Node(org.dmg.pmml.tree.Node) PMMLObject(org.dmg.pmml.PMMLObject)

Example 10 with Node

use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.

Method swapChildren of class TreeModelCompactor.

/**
 * Swaps the first two entries of {@code children} in place; any further entries keep their positions.
 *
 * <p>Uses only {@code get}/{@code set}, so it does not structurally modify the list. With fewer
 * than two entries it throws {@link IndexOutOfBoundsException} before changing anything; the
 * previous remove/add approach dropped the first child in that case.
 *
 * @param children list holding at least two nodes; must support {@code set}
 * @throws IndexOutOfBoundsException if the list has fewer than two entries
 */
private static void swapChildren(List<Node> children) {
    // set(1, first) returns the old second element, which then becomes the first.
    children.set(0, children.set(1, children.get(0)));
}
Also used : Node(org.dmg.pmml.tree.Node)

Aggregations

Node (org.dmg.pmml.tree.Node): 20, Predicate (org.dmg.pmml.Predicate): 9, SimplePredicate (org.dmg.pmml.SimplePredicate): 9, True (org.dmg.pmml.True): 9, TreeModel (org.dmg.pmml.tree.TreeModel): 9, ContinuousFeature (org.jpmml.converter.ContinuousFeature): 6, CategoricalFeature (org.jpmml.converter.CategoricalFeature): 5, Feature (org.jpmml.converter.Feature): 5, FieldName (org.dmg.pmml.FieldName): 4, List (java.util.List): 3, SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate): 3, MiningModel (org.dmg.pmml.mining.MiningModel): 3, ArrayList (java.util.ArrayList): 2, HashSet (java.util.HashSet): 2, Set (java.util.Set): 2, InternalNode (org.apache.spark.ml.tree.InternalNode): 2, LeafNode (org.apache.spark.ml.tree.LeafNode): 2, PMMLObject (org.dmg.pmml.PMMLObject): 2, ScoreDistribution (org.dmg.pmml.ScoreDistribution): 2, BooleanFeature (org.jpmml.converter.BooleanFeature): 2