Search in sources:

Example 1 with LeafNode

Use of org.dmg.pmml.tree.LeafNode in project jpmml-r by jpmml.

The class IForestConverter, method encodeNode.

/**
 * Recursively encodes the isolation-tree node at {@code index} (and its whole subtree) as a PMML {@link Node}.
 *
 * <p>All node lists are parallel and indexed by zero-based node index. The daughter indices and the
 * split attribute indices stored in them are one-based, hence the {@code - 1} adjustments below.
 *
 * @param index zero-based index of the node to encode
 * @param predicate predicate that guards entry into this node
 * @param depth depth of this node; children are encoded at {@code depth + 1}
 * @param nodeStatus per-node status code: {@code -3} marks an interior node, {@code -1} a terminal node
 * @param nodeSize per-node size, passed to {@code avgPathLength} when scoring terminal nodes
 * @param leftDaughter per-node one-based index of the left child
 * @param rightDaughter per-node one-based index of the right child
 * @param splitAtt per-node one-based index of the split feature within {@code schema}
 * @param splitValue per-node split threshold
 * @param schema feature schema; split features are cast to {@link ContinuousFeature}
 * @return a {@link BranchNode} for interior nodes, or a {@link LeafNode} for terminal nodes
 * @throws IllegalArgumentException if a node status is neither {@code -3} nor {@code -1}
 */
private Node encodeNode(int index, Predicate predicate, int depth, List<Integer> nodeStatus, List<Integer> nodeSize, List<Integer> leftDaughter, List<Integer> rightDaughter, List<Integer> splitAtt, List<Double> splitValue, Schema schema) {
    // PMML node ids are one-based
    Integer id = Integer.valueOf(index + 1);
    int status = nodeStatus.get(index);
    int size = nodeSize.get(index);
    if (status == -3) {
        // Interior node: "x < value" goes left, "x >= value" goes right
        int att = splitAtt.get(index);
        ContinuousFeature feature = (ContinuousFeature) schema.getFeature(att - 1);
        Double value = splitValue.get(index);
        Predicate leftPredicate = createSimplePredicate(feature, SimplePredicate.Operator.LESS_THAN, value);
        Predicate rightPredicate = createSimplePredicate(feature, SimplePredicate.Operator.GREATER_OR_EQUAL, value);
        Node leftChild = encodeNode(leftDaughter.get(index) - 1, leftPredicate, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);
        Node rightChild = encodeNode(rightDaughter.get(index) - 1, rightPredicate, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);
        Node result = new BranchNode(null, predicate).setId(id).addNodes(leftChild, rightChild);
        return result;
    } else if (status == -1) {
        // Terminal node: the score is the depth reached plus avgPathLength(size)
        // (presumably the usual isolation-forest correction for records left unsplit - confirm)
        Node result = new LeafNode(depth + avgPathLength(size), predicate).setId(id);
        return result;
    } else {
        throw new IllegalArgumentException("Unsupported node status " + status + " for node " + id + " (expected -3 or -1)");
    }
}
Also used : BranchNode(org.dmg.pmml.tree.BranchNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Node(org.dmg.pmml.tree.Node) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) LeafNode(org.dmg.pmml.tree.LeafNode) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate)

Example 2 with LeafNode

Use of org.dmg.pmml.tree.LeafNode in project jpmml-r by jpmml.

The class PartyConverter, method encodeNode.

/**
 * Recursively encodes a partykit tree node (and its whole subtree) as a PMML {@link Node}.
 *
 * @param partyNode R list describing the node; its {@code id}, {@code split}, {@code kids},
 *        {@code surrogates} and {@code info} elements are read
 * @param predicate predicate that guards entry into this node
 * @param response per-node response values, indexed by one-based node id; an {@link RFactorVector}
 *        selects classification encoding
 * @param prob per-node class probabilities (classification only), read via {@code FortranMatrixUtil.getRow}
 *        as a matrix with {@code response.size()} rows and one column per target category
 * @param schema feature and label schema
 * @return a {@link ClassifierNode} for classification, otherwise a {@link LeafNode} (no kids)
 *         or a {@link BranchNode} (with kids)
 * @throws IllegalArgumentException if the node has surrogate splits, has kids but no split,
 *         or its split has an unsupported shape
 */
private Node encodeNode(RGenericVector partyNode, Predicate predicate, RVector<?> response, RDoubleVector prob, Schema schema) {
    RIntegerVector id = partyNode.getIntegerElement("id");
    RGenericVector split = partyNode.getGenericElement("split");
    RGenericVector kids = partyNode.getGenericElement("kids");
    RGenericVector surrogates = partyNode.getGenericElement("surrogates");
    // NOTE(review): not used below; kept so the element is still accessed exactly as before
    RGenericVector info = partyNode.getGenericElement("info");
    if (surrogates != null) {
        throw new IllegalArgumentException("Surrogate splits are not supported (node " + id.asScalar() + ")");
    }
    Label label = schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();
    Node result;
    if (response instanceof RFactorVector) {
        // ClassifierNode is used for leaves and branches alike so that score distributions can be attached below
        result = new ClassifierNode(null, predicate);
    } else {
        if (kids == null) {
            result = new LeafNode(null, predicate);
        } else {
            result = new BranchNode(null, predicate);
        }
    }
    result.setId(Integer.valueOf(id.asScalar()));
    if (response instanceof RFactorVector) {
        RFactorVector factor = (RFactorVector) response;
        int index = id.asScalar() - 1;
        result.setScore(factor.getFactorValue(index));
        CategoricalLabel categoricalLabel = (CategoricalLabel) label;
        List<Double> probabilities = FortranMatrixUtil.getRow(prob.getValues(), response.size(), categoricalLabel.size(), index);
        List<ScoreDistribution> scoreDistributions = result.getScoreDistributions();
        for (int i = 0; i < categoricalLabel.size(); i++) {
            Object value = categoricalLabel.getValue(i);
            Double probability = probabilities.get(i);
            ScoreDistribution scoreDistribution = new ScoreDistribution(value, probability);
            scoreDistributions.add(scoreDistribution);
        }
    } else {
        result.setScore(response.getValue(id.asScalar() - 1));
    }
    if (kids == null) {
        return result;
    }
    // A node with kids must describe how records are routed to them
    if (split == null) {
        throw new IllegalArgumentException("Node " + id.asScalar() + " has kids but no split");
    }
    RIntegerVector varid = split.getIntegerElement("varid");
    RDoubleVector breaks = split.getDoubleElement("breaks");
    RIntegerVector index = split.getIntegerElement("index");
    RBooleanVector right = split.getBooleanElement("right");
    // varid is one-based
    Feature feature = features.get(varid.asScalar() - 1);
    if (breaks != null && index == null) {
        // Numeric binary split on a single break point
        ContinuousFeature continuousFeature = (ContinuousFeature) feature;
        if (kids.size() != 2) {
            throw new IllegalArgumentException("Expected exactly 2 kids for a numeric split, got " + kids.size());
        }
        if (breaks.size() != 1) {
            throw new IllegalArgumentException("Expected exactly 1 break point for a numeric split, got " + breaks.size());
        }
        Predicate leftPredicate;
        Predicate rightPredicate;
        Double value = breaks.asScalar();
        // "right" selects whether the break point belongs to the left interval
        if (right.asScalar()) {
            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        } else {
            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, value);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, value);
        }
        Node leftChild = encodeNode(kids.getGenericValue(0), leftPredicate, response, prob, schema);
        Node rightChild = encodeNode(kids.getGenericValue(1), rightPredicate, response, prob, schema);
        result.addNodes(leftChild, rightChild);
    } else if (breaks == null && index != null) {
        // Categorical multiway split: "index" maps each category to a one-based kid number
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        if (kids.size() < 2) {
            throw new IllegalArgumentException("Expected at least 2 kids for a categorical split, got " + kids.size());
        }
        // Loop-invariant: only right = TRUE is supported for categorical splits
        if (!right.asScalar()) {
            throw new IllegalArgumentException("Categorical splits with right = FALSE are not supported");
        }
        List<?> values = categoricalFeature.getValues();
        for (int i = 0; i < kids.size(); i++) {
            Predicate childPredicate = createPredicate(categoricalFeature, selectValues(values, index, i + 1));
            Node child = encodeNode(kids.getGenericValue(i), childPredicate, response, prob, schema);
            result.addNodes(child);
        }
    } else {
        throw new IllegalArgumentException("Split must define exactly one of 'breaks' and 'index'");
    }
    return result;
}
Also used : BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Label(org.jpmml.converter.Label) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ScoreDistribution(org.dmg.pmml.ScoreDistribution) BranchNode(org.dmg.pmml.tree.BranchNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalLabel(org.jpmml.converter.CategoricalLabel) LeafNode(org.dmg.pmml.tree.LeafNode) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) ArrayList(java.util.ArrayList) List(java.util.List)

Example 3 with LeafNode

Use of org.dmg.pmml.tree.LeafNode in project jpmml-r by jpmml.

The class GBMConverter, method encodeNode.

/**
 * Recursively encodes node {@code i} of a gbm tree (and its whole subtree) as a PMML {@link Node}.
 *
 * <p>The tree is read column-wise by position: split variable (0), split code or prediction (1),
 * left child (2), right child (3), missing child (4) and prediction (7). A split variable of
 * {@code -1} marks a terminal node; a child index of {@code -1} means "no child".
 *
 * @param i zero-based index of the node to encode
 * @param predicate predicate that guards entry into this node
 * @param tree gbm tree structure
 * @param c_splits categorical split table; for categorical features the split code selects a row of it
 * @param flagManager records, per feature, whether it is already known to be missing on the current path
 * @param categoryManager records the categories still reachable on the current path
 * @param schema feature schema
 * @return a {@link LeafNode} for terminal nodes, otherwise a {@link BranchNode} whose children are
 *         added in missing, left, right order (unreachable or absent children are skipped)
 */
private Node encodeNode(int i, Predicate predicate, RGenericVector tree, RGenericVector c_splits, FlagManager flagManager, CategoryManager categoryManager, Schema schema) {
    Integer nodeId = Integer.valueOf(i + 1);

    RIntegerVector splitVars = tree.getIntegerValue(0);
    RDoubleVector splitCodePreds = tree.getDoubleValue(1);
    RIntegerVector leftNodes = tree.getIntegerValue(2);
    RIntegerVector rightNodes = tree.getIntegerValue(3);
    RIntegerVector missingNodes = tree.getIntegerValue(4);
    RDoubleVector predictions = tree.getDoubleValue(7);

    int splitVarIndex = splitVars.getValue(i);
    if (splitVarIndex == -1) {
        // Terminal node: the stored prediction is the score
        return new LeafNode(predictions.getValue(i), predicate).setId(nodeId);
    }

    Feature feature = schema.getFeature(splitVarIndex);
    String featureName = feature.getName();

    // An ancestor split on the same feature may already have fixed its missingness;
    // otherwise fork so that each branch remembers which side it took
    Boolean isMissing = flagManager.getValue(featureName);
    FlagManager missingFlagManager = (isMissing != null) ? flagManager : flagManager.fork(featureName, Boolean.TRUE);
    FlagManager nonMissingFlagManager = (isMissing != null) ? flagManager : flagManager.fork(featureName, Boolean.FALSE);

    Predicate missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);

    Double splitCode = splitCodePreds.getValue(i);

    CategoryManager leftCategoryManager;
    CategoryManager rightCategoryManager;
    Predicate leftPredicate;
    Predicate rightPredicate;
    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<?> categories = categoricalFeature.getValues();

        RIntegerVector categorySplit = c_splits.getIntegerValue(ValueUtil.asInt(splitCode));
        List<Integer> categoryDirections = categorySplit.getValues();

        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(featureName);
        List<Object> leftValues = selectValues(categories, valueFilter, categoryDirections, true);
        List<Object> rightValues = selectValues(categories, valueFilter, categoryDirections, false);

        leftCategoryManager = categoryManager.fork(featureName, leftValues);
        rightCategoryManager = categoryManager.fork(featureName, rightValues);

        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        leftCategoryManager = categoryManager;
        rightCategoryManager = categoryManager;

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, splitCode);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, splitCode);
    }

    Node result = new BranchNode(null, predicate).setId(nodeId);
    List<Node> children = result.getNodes();

    // Children that contradict the already-known missingness of this feature are unreachable
    boolean missingReachable = (isMissing == null || isMissing);
    boolean presentReachable = (isMissing == null || !isMissing);

    int missingIndex = missingNodes.getValue(i);
    if (missingIndex != -1 && missingReachable) {
        children.add(encodeNode(missingIndex, missingPredicate, tree, c_splits, missingFlagManager, categoryManager, schema));
    }

    int leftIndex = leftNodes.getValue(i);
    if (leftIndex != -1 && presentReachable) {
        children.add(encodeNode(leftIndex, leftPredicate, tree, c_splits, nonMissingFlagManager, leftCategoryManager, schema));
    }

    int rightIndex = rightNodes.getValue(i);
    if (rightIndex != -1 && presentReachable) {
        children.add(encodeNode(rightIndex, rightPredicate, tree, c_splits, nonMissingFlagManager, rightCategoryManager, schema));
    }

    return result;
}
Also used : Node(org.dmg.pmml.tree.Node) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) FlagManager(org.jpmml.converter.FlagManager) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) BranchNode(org.dmg.pmml.tree.BranchNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) LeafNode(org.dmg.pmml.tree.LeafNode) CategoryManager(org.jpmml.converter.CategoryManager)

Example 4 with LeafNode

Use of org.dmg.pmml.tree.LeafNode in project jpmml-r by jpmml.

The class RangerConverter, method encodeNode.

/**
 * Recursively encodes node {@code index} of a ranger tree (and its whole subtree) as a PMML {@link Node}.
 *
 * <p>A node whose left and right child ids are both {@code 0} is terminal; its split value and
 * (optional) class counts are handed to {@code scoreEncoder}.
 *
 * @param predicate predicate that guards entry into this node
 * @param index index of the node to encode
 * @param scoreEncoder turns a terminal node into a scored PMML node
 * @param leftChildIDs per-node left child id
 * @param rightChildIDs per-node right child id
 * @param splitVarIDs per-node split variable id
 * @param splitValues per-node split value (for terminal nodes, passed to the score encoder)
 * @param terminalClassCounts per-node class counts, or {@code null}
 * @param categoryManager records the categories still reachable on the current path
 * @param schema feature schema
 * @return the encoded node
 */
private Node encodeNode(Predicate predicate, int index, ScoreEncoder scoreEncoder, RNumberVector<?> leftChildIDs, RNumberVector<?> rightChildIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, CategoryManager categoryManager, Schema schema) {
    int leftChildIndex = ValueUtil.asInt(leftChildIDs.getValue(index));
    int rightChildIndex = ValueUtil.asInt(rightChildIDs.getValue(index));
    Number splitValue = splitValues.getValue(index);

    RNumberVector<?> terminalClassCount = null;
    if (terminalClassCounts != null) {
        terminalClassCount = terminalClassCounts.getNumericValue(index);
    }

    boolean terminal = (leftChildIndex == 0 && rightChildIndex == 0);
    if (terminal) {
        return scoreEncoder.encode(new LeafNode(null, predicate), splitValue, terminalClassCount);
    }

    int splitVarIndex = ValueUtil.asInt(splitVarIDs.getValue(index));
    // Split variable ids are shifted by one when hasDependentVar is set
    // (presumably the dependent variable occupies the first slot - confirm)
    int featureIndex = this.hasDependentVar ? (splitVarIndex - 1) : splitVarIndex;
    Feature feature = schema.getFeature(featureIndex);

    CategoryManager leftCategoryManager;
    CategoryManager rightCategoryManager;
    Predicate leftPredicate;
    Predicate rightPredicate;
    if (!(feature instanceof CategoricalFeature)) {
        // Numeric split: "x <= value" goes left, "x > value" goes right
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        leftCategoryManager = categoryManager;
        rightCategoryManager = categoryManager;

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, splitValue);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, splitValue);
    } else {
        // Categorical split: the first floor(splitValue) categories go left, the rest go right
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        int cutPoint = ValueUtil.asInt(Math.floor(splitValue.doubleValue()));

        String featureName = categoricalFeature.getName();
        List<?> categories = categoricalFeature.getValues();
        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(featureName);

        List<Object> leftValues = filterValues(categories.subList(0, cutPoint), valueFilter);
        List<Object> rightValues = filterValues(categories.subList(cutPoint, categories.size()), valueFilter);

        leftCategoryManager = categoryManager.fork(featureName, leftValues);
        rightCategoryManager = categoryManager.fork(featureName, rightValues);

        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    }

    Node leftChild = encodeNode(leftPredicate, leftChildIndex, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, leftCategoryManager, schema);
    Node rightChild = encodeNode(rightPredicate, rightChildIndex, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, rightCategoryManager, schema);

    return new BranchNode(null, predicate).addNodes(leftChild, rightChild);
}
Also used : Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) BranchNode(org.dmg.pmml.tree.BranchNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) LeafNode(org.dmg.pmml.tree.LeafNode) CategoryManager(org.jpmml.converter.CategoryManager)

Example 5 with LeafNode

Use of org.dmg.pmml.tree.LeafNode in project jpmml-sparkml by jpmml.

The class TreeModelUtil, method encodeNode.

/**
 * Recursively encodes a Spark ML decision tree node (and its whole subtree) as a PMML {@link Node}.
 *
 * @param predicate predicate that guards entry into this node
 * @param scoreEncoder turns a Spark leaf node into a scored PMML node
 * @param sparkNode the Spark ML node to encode
 * @param predicateManager factory for (possibly shared) PMML predicates
 * @param categoryManager records the categories still reachable on the current path
 * @param schema feature schema, indexed by Spark feature index
 * @return the encoded node
 * @throws IllegalArgumentException if the node, split or feature type is unsupported, or the split
 *         is inconsistent with the feature (e.g. a boolean threshold other than 0.5)
 */
private static Node encodeNode(Predicate predicate, ScoreEncoder scoreEncoder, org.apache.spark.ml.tree.Node sparkNode, PredicateManager predicateManager, CategoryManager categoryManager, Schema schema) {
    if (sparkNode instanceof org.apache.spark.ml.tree.LeafNode) {
        org.apache.spark.ml.tree.LeafNode leafNode = (org.apache.spark.ml.tree.LeafNode) sparkNode;
        Node result = new LeafNode(null, predicate);
        return scoreEncoder.encode(result, leafNode);
    } else if (sparkNode instanceof org.apache.spark.ml.tree.InternalNode) {
        org.apache.spark.ml.tree.InternalNode internalNode = (org.apache.spark.ml.tree.InternalNode) sparkNode;
        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;
        Predicate leftPredicate;
        Predicate rightPredicate;
        Split split = internalNode.split();
        Feature feature = schema.getFeature(split.featureIndex());
        if (split instanceof ContinuousSplit) {
            ContinuousSplit continuousSplit = (ContinuousSplit) split;
            // Primitive double: avoids boxed comparisons against the 0.5 sentinel below
            double threshold = continuousSplit.threshold();
            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;
                // Boolean features are split halfway between their two encoded values
                if (threshold != 0.5d) {
                    throw new IllegalArgumentException("Invalid split threshold value " + threshold + " for a boolean feature");
                }
                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                DataType dataType = continuousFeature.getDataType();
                switch (dataType) {
                    case INTEGER:
                        // For integer x, "x <= t" is equivalent to "x <= floor(t)"
                        threshold = Math.floor(threshold);
                        break;
                    default:
                        break;
                }
                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, threshold);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, threshold);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;
            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;
                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;
                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Invalid categorical split (left " + Arrays.toString(leftCategories) + ", right " + Arrays.toString(rightCategories) + ") for a binary feature");
                }
                Object value = binaryFeature.getValue();
                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
                FieldName name = categoricalFeature.getName();
                List<?> values = categoricalFeature.getValues();
                // Every category must be routed to exactly one side
                if (values.size() != (leftCategories.length + rightCategories.length)) {
                    throw new IllegalArgumentException("Expected " + values.size() + " categories in split, got " + (leftCategories.length + rightCategories.length));
                }
                java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
                List<Object> leftValues = selectValues(values, leftCategories, valueFilter);
                List<Object> rightValues = selectValues(values, rightCategories, valueFilter);
                leftCategoryManager = categoryManager.fork(name, leftValues);
                rightCategoryManager = categoryManager.fork(name, rightValues);
                leftPredicate = predicateManager.createPredicate(categoricalFeature, leftValues);
                rightPredicate = predicateManager.createPredicate(categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException("Categorical split is not supported for feature " + feature);
            }
        } else {
            throw new IllegalArgumentException("Unsupported split type " + split.getClass().getName());
        }
        Node leftChild = encodeNode(leftPredicate, scoreEncoder, internalNode.leftChild(), predicateManager, leftCategoryManager, schema);
        Node rightChild = encodeNode(rightPredicate, scoreEncoder, internalNode.rightChild(), predicateManager, rightCategoryManager, schema);
        Node result = new BranchNode(null, predicate).addNodes(leftChild, rightChild);
        return result;
    } else {
        throw new IllegalArgumentException("Unsupported node type " + (sparkNode != null ? sparkNode.getClass().getName() : null));
    }
}
Also used : Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) DataType(org.dmg.pmml.DataType) ArrayList(java.util.ArrayList) List(java.util.List) FieldName(org.dmg.pmml.FieldName) CategoryManager(org.jpmml.converter.CategoryManager) BinaryFeature(org.jpmml.converter.BinaryFeature) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit) Split(org.apache.spark.ml.tree.Split) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit)

Aggregations

LeafNode (org.dmg.pmml.tree.LeafNode)9 Node (org.dmg.pmml.tree.Node)9 BranchNode (org.dmg.pmml.tree.BranchNode)8 Predicate (org.dmg.pmml.Predicate)7 SimplePredicate (org.dmg.pmml.SimplePredicate)7 ContinuousFeature (org.jpmml.converter.ContinuousFeature)7 ClassifierNode (org.dmg.pmml.tree.ClassifierNode)6 CategoricalFeature (org.jpmml.converter.CategoricalFeature)6 Feature (org.jpmml.converter.Feature)6 ArrayList (java.util.ArrayList)4 List (java.util.List)4 CategoryManager (org.jpmml.converter.CategoryManager)4 BooleanFeature (org.jpmml.converter.BooleanFeature)2 CategoricalLabel (org.jpmml.converter.CategoricalLabel)2 HashMap (java.util.HashMap)1 DecisionTreeClassificationModel (org.apache.spark.ml.classification.DecisionTreeClassificationModel)1 DecisionTreeRegressionModel (org.apache.spark.ml.regression.DecisionTreeRegressionModel)1 CategoricalSplit (org.apache.spark.ml.tree.CategoricalSplit)1 ContinuousSplit (org.apache.spark.ml.tree.ContinuousSplit)1 DecisionTreeModel (org.apache.spark.ml.tree.DecisionTreeModel)1