Search in sources:

Example 1 with CategoryManager

Use of org.jpmml.converter.CategoryManager in project jpmml-r by jpmml.

In class GBMConverter, method encodeNode:

/**
 * Recursively encodes GBM tree node {@code i} (and its subtree) as a PMML node.
 *
 * <p>The tree is read column-wise from {@code tree}: column 0 holds split variable indices,
 * 1 split codes/predictions, 2/3/4 the left/right/missing child rows and 7 the leaf predictions.
 * A split variable of -1 marks a leaf. Branch nodes get up to three children, in the order
 * missing, left, right. A child is omitted when its row is -1, or when an ancestor split has
 * already decided the missingness of the split feature on this path.
 *
 * @param i 0-based row of the node; the PMML node id is {@code i + 1}
 * @param predicate predicate guarding this node
 * @param tree GBM tree columns
 * @param c_splits categorical split definitions, indexed by the split code of categorical nodes
 * @param flagManager per-feature missingness decisions made by ancestor splits
 * @param categoryManager per-feature category restrictions made by ancestor splits
 * @param schema feature schema indexed by split variable
 * @return the encoded node
 */
private Node encodeNode(int i, Predicate predicate, RGenericVector tree, RGenericVector c_splits, FlagManager flagManager, CategoryManager categoryManager, Schema schema) {
    // Every column is indexed by the 0-based node row i
    RIntegerVector splitVarColumn = tree.getIntegerValue(0);
    RDoubleVector splitCodePredColumn = tree.getDoubleValue(1);
    RIntegerVector leftColumn = tree.getIntegerValue(2);
    RIntegerVector rightColumn = tree.getIntegerValue(3);
    RIntegerVector missingColumn = tree.getIntegerValue(4);
    RDoubleVector predictionColumn = tree.getDoubleValue(7);

    Integer nodeId = Integer.valueOf(i + 1);

    Integer splitVarIndex = splitVarColumn.getValue(i);
    // A split variable of -1 marks a terminal node
    if (splitVarIndex == -1) {
        Double score = predictionColumn.getValue(i);
        return new LeafNode(score, predicate).setId(nodeId);
    }

    Feature feature = schema.getFeature(splitVarIndex);
    String name = feature.getName();

    // null = missingness of this feature not yet decided on the current path;
    // TRUE/FALSE = an ancestor split already routed on it
    Boolean isMissing = flagManager.getValue(name);
    boolean undecided = (isMissing == null);

    FlagManager missingFlagManager = undecided ? flagManager.fork(name, Boolean.TRUE) : flagManager;
    FlagManager nonMissingFlagManager = undecided ? flagManager.fork(name, Boolean.FALSE) : flagManager;

    Predicate missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);

    Double split = splitCodePredColumn.getValue(i);

    CategoryManager leftCategoryManager;
    CategoryManager rightCategoryManager;
    Predicate leftPredicate;
    Predicate rightPredicate;

    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<?> values = categoricalFeature.getValues();

        // For categorical nodes the split code is an index into c_splits
        RIntegerVector c_split = c_splits.getIntegerValue(ValueUtil.asInt(split));
        List<Integer> splitValues = c_split.getValues();

        // Only consider categories still reachable after the ancestor splits
        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
        List<Object> leftValues = selectValues(values, valueFilter, splitValues, true);
        List<Object> rightValues = selectValues(values, valueFilter, splitValues, false);

        leftCategoryManager = categoryManager.fork(name, leftValues);
        rightCategoryManager = categoryManager.fork(name, rightValues);

        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        leftCategoryManager = categoryManager;
        rightCategoryManager = categoryManager;

        // Continuous split: x < split goes left, x >= split goes right
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, split);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, split);
    }

    Node result = new BranchNode(null, predicate).setId(nodeId);
    List<Node> children = result.getNodes();

    // Skip branches already ruled out by an ancestor's missingness decision
    boolean mayBeMissing = undecided || isMissing;
    boolean mayBePresent = undecided || !isMissing;

    Integer missingRow = missingColumn.getValue(i);
    if (missingRow != -1 && mayBeMissing) {
        // Missing values do not narrow the category set, so the incoming manager is passed on
        children.add(encodeNode(missingRow, missingPredicate, tree, c_splits, missingFlagManager, categoryManager, schema));
    }

    Integer leftRow = leftColumn.getValue(i);
    if (leftRow != -1 && mayBePresent) {
        children.add(encodeNode(leftRow, leftPredicate, tree, c_splits, nonMissingFlagManager, leftCategoryManager, schema));
    }

    Integer rightRow = rightColumn.getValue(i);
    if (rightRow != -1 && mayBePresent) {
        children.add(encodeNode(rightRow, rightPredicate, tree, c_splits, nonMissingFlagManager, rightCategoryManager, schema));
    }

    return result;
}
Also used : Node(org.dmg.pmml.tree.Node) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) FlagManager(org.jpmml.converter.FlagManager) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) BranchNode(org.dmg.pmml.tree.BranchNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) LeafNode(org.dmg.pmml.tree.LeafNode) CategoryManager(org.jpmml.converter.CategoryManager)

Example 2 with CategoryManager

Use of org.jpmml.converter.CategoryManager in project jpmml-r by jpmml.

In class GBMConverter, method encodeTreeModel:

/**
 * Converts a single GBM tree into a PMML tree model.
 *
 * <p>The tree is marked {@code MULTI_SPLIT} because a branch may carry up to three children
 * (missing, left, right).
 *
 * @param miningFunction mining function of the resulting model
 * @param tree GBM tree columns
 * @param c_splits categorical split definitions
 * @param schema feature/label schema
 * @return the tree model rooted at node row 0
 */
private TreeModel encodeTreeModel(MiningFunction miningFunction, RGenericVector tree, RGenericVector c_splits, Schema schema) {
    // The root starts with no missingness decisions and no category restrictions
    FlagManager rootFlagManager = new FlagManager();
    CategoryManager rootCategoryManager = new CategoryManager();

    Node root = encodeNode(0, True.INSTANCE, tree, c_splits, rootFlagManager, rootCategoryManager, schema);

    return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) Node(org.dmg.pmml.tree.Node) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) FlagManager(org.jpmml.converter.FlagManager) CategoryManager(org.jpmml.converter.CategoryManager)

Example 3 with CategoryManager

Use of org.jpmml.converter.CategoryManager in project jpmml-r by jpmml.

In class RangerConverter, method encodeNode:

/**
 * Recursively encodes the ranger tree node at {@code index} (and its subtree) as a PMML node.
 *
 * <p>A node whose left and right child ids are both 0 is terminal; its score is attached by
 * {@code scoreEncoder}. Any other node is a binary split. For continuous features,
 * {@code x <= splitValue} goes left and {@code x > splitValue} goes right. For categorical
 * features, the first {@code floor(splitValue)} levels go left and the remaining levels go right.
 * Both sides are restricted to the levels still reachable on the current path.
 *
 * @param predicate predicate guarding this node
 * @param index row of the node in the per-node vectors
 * @param scoreEncoder encodes terminal node scores
 * @param leftChildIDs left child row per node
 * @param rightChildIDs right child row per node
 * @param splitVarIDs split variable id per node
 * @param splitValues split value per node (passed to {@code scoreEncoder} for terminal nodes)
 * @param terminalClassCounts per-node class counts, or {@code null} when not available
 * @param categoryManager per-feature category restrictions made by ancestor splits
 * @param schema feature schema
 * @return the encoded node
 */
private Node encodeNode(Predicate predicate, int index, ScoreEncoder scoreEncoder, RNumberVector<?> leftChildIDs, RNumberVector<?> rightChildIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, CategoryManager categoryManager, Schema schema) {
    int leftId = ValueUtil.asInt(leftChildIDs.getValue(index));
    int rightId = ValueUtil.asInt(rightChildIDs.getValue(index));
    Number splitValue = splitValues.getValue(index);

    RNumberVector<?> terminalClassCount = null;
    if (terminalClassCounts != null) {
        terminalClassCount = terminalClassCounts.getNumericValue(index);
    }

    boolean isTerminal = (leftId == 0 && rightId == 0);
    if (isTerminal) {
        Node leaf = new LeafNode(null, predicate);
        return scoreEncoder.encode(leaf, splitValue, terminalClassCount);
    }

    int featureIndex = ValueUtil.asInt(splitVarIDs.getValue(index));
    if (this.hasDependentVar) {
        // NOTE(review): presumably split variable ids count the dependent variable too,
        // while the schema does not - confirm against the ranger model layout
        featureIndex = featureIndex - 1;
    }
    Feature feature = schema.getFeature(featureIndex);

    Predicate leftPredicate;
    Predicate rightPredicate;
    CategoryManager leftCategoryManager;
    CategoryManager rightCategoryManager;

    if (!(feature instanceof CategoricalFeature)) {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, splitValue);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, splitValue);

        leftCategoryManager = categoryManager;
        rightCategoryManager = categoryManager;
    } else {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

        // Levels [0, splitLevelIndex) go left, [splitLevelIndex, size) go right
        int splitLevelIndex = ValueUtil.asInt(Math.floor(splitValue.doubleValue()));
        String name = categoricalFeature.getName();
        List<?> levels = categoricalFeature.getValues();

        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
        List<Object> leftValues = filterValues(levels.subList(0, splitLevelIndex), valueFilter);
        List<Object> rightValues = filterValues(levels.subList(splitLevelIndex, levels.size()), valueFilter);

        leftCategoryManager = categoryManager.fork(name, leftValues);
        rightCategoryManager = categoryManager.fork(name, rightValues);

        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    }

    Node leftChild = encodeNode(leftPredicate, leftId, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, leftCategoryManager, schema);
    Node rightChild = encodeNode(rightPredicate, rightId, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, rightCategoryManager, schema);

    return new BranchNode(null, predicate).addNodes(leftChild, rightChild);
}
Also used : Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) BranchNode(org.dmg.pmml.tree.BranchNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) LeafNode(org.dmg.pmml.tree.LeafNode) CategoryManager(org.jpmml.converter.CategoryManager)

Example 4 with CategoryManager

Use of org.jpmml.converter.CategoryManager in project jpmml-r by jpmml.

In class RangerConverter, method encodeTreeModel:

/**
 * Converts a single ranger tree into a binary-split PMML tree model.
 *
 * @param miningFunction mining function of the resulting model
 * @param scoreEncoder encodes terminal node scores
 * @param childNodeIDs two vectors: element 0 = left child ids, element 1 = right child ids
 * @param splitVarIDs split variable id per node
 * @param splitValues split value per node
 * @param terminalClassCounts per-node class counts, or {@code null}
 * @param schema feature/label schema
 * @return the tree model rooted at node 0
 */
private TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder scoreEncoder, RGenericVector childNodeIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, Schema schema) {
    RNumberVector<?> leftIds = childNodeIDs.getNumericValue(0);
    RNumberVector<?> rightIds = childNodeIDs.getNumericValue(1);

    // The root starts with no category restrictions
    CategoryManager rootCategoryManager = new CategoryManager();

    Node root = encodeNode(True.INSTANCE, 0, scoreEncoder, leftIds, rightIds, splitVarIDs, splitValues, terminalClassCounts, rootCategoryManager, schema);

    return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) CategoryManager(org.jpmml.converter.CategoryManager)

Example 5 with CategoryManager

Use of org.jpmml.converter.CategoryManager in project jpmml-sparkml by jpmml.

In class TreeModelUtil, method encodeNode:

/**
 * Recursively converts a Spark ML decision tree node into a PMML node.
 *
 * <p>Leaf nodes are handed to {@code scoreEncoder}. Internal nodes become a branch node with
 * exactly two children (left, then right), whose predicates are derived from the Spark split:
 * <ul>
 *   <li>continuous split on a boolean feature: threshold must be 0.5; each side tests equality
 *       with the corresponding boolean value;</li>
 *   <li>continuous split on any other feature: {@code x <= t} left, {@code x > t} right
 *       ({@code t} is floored for integer-typed features);</li>
 *   <li>categorical split on a binary feature: one side tests {@code == value}, the other
 *       {@code != value};</li>
 *   <li>categorical split on a categorical feature: each side tests membership in its category
 *       subset, restricted to the categories still reachable on the current path.</li>
 * </ul>
 *
 * @param predicate predicate guarding this node
 * @param scoreEncoder encodes leaf node scores
 * @param sparkNode the Spark node to convert
 * @param predicateManager used to create the child predicates
 * @param categoryManager per-feature category restrictions made by ancestor splits
 * @param schema feature schema indexed by Spark feature index
 * @return the encoded node
 * @throws IllegalArgumentException if the node, split or feature type is unsupported, or if the
 *     split is inconsistent with its feature
 */
private static Node encodeNode(Predicate predicate, ScoreEncoder scoreEncoder, org.apache.spark.ml.tree.Node sparkNode, PredicateManager predicateManager, CategoryManager categoryManager, Schema schema) {
    if (sparkNode instanceof org.apache.spark.ml.tree.LeafNode) {
        org.apache.spark.ml.tree.LeafNode leafNode = (org.apache.spark.ml.tree.LeafNode) sparkNode;
        Node result = new LeafNode(null, predicate);
        return scoreEncoder.encode(result, leafNode);
    } else if (sparkNode instanceof org.apache.spark.ml.tree.InternalNode) {
        org.apache.spark.ml.tree.InternalNode internalNode = (org.apache.spark.ml.tree.InternalNode) sparkNode;
        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;
        Predicate leftPredicate;
        Predicate rightPredicate;
        Split split = internalNode.split();
        Feature feature = schema.getFeature(split.featureIndex());
        if (split instanceof ContinuousSplit) {
            ContinuousSplit continuousSplit = (ContinuousSplit) split;
            Double threshold = continuousSplit.threshold();
            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;
                if (threshold != 0.5d) {
                    throw new IllegalArgumentException("Invalid split threshold value " + threshold + " for a boolean feature");
                }
                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                DataType dataType = continuousFeature.getDataType();
                switch(dataType) {
                    case INTEGER:
                        // For integer x, "x <= t" is equivalent to "x <= floor(t)"; keep the threshold integral
                        threshold = Math.floor(threshold);
                        break;
                    default:
                        break;
                }
                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, threshold);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, threshold);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;
            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;
                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;
                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Invalid split categories " + Arrays.toString(leftCategories) + " (left) and " + Arrays.toString(rightCategories) + " (right) for a binary feature");
                }
                Object value = binaryFeature.getValue();
                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
                FieldName name = categoricalFeature.getName();
                List<?> values = categoricalFeature.getValues();
                // Every category of the feature must be assigned to exactly one side of the split
                if (values.size() != (leftCategories.length + rightCategories.length)) {
                    throw new IllegalArgumentException("Expected " + values.size() + " split categories for a categorical feature, got " + leftCategories.length + " (left) + " + rightCategories.length + " (right)");
                }
                java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
                List<Object> leftValues = selectValues(values, leftCategories, valueFilter);
                List<Object> rightValues = selectValues(values, rightCategories, valueFilter);
                leftCategoryManager = categoryManager.fork(name, leftValues);
                rightCategoryManager = categoryManager.fork(name, rightValues);
                leftPredicate = predicateManager.createPredicate(categoricalFeature, leftValues);
                rightPredicate = predicateManager.createPredicate(categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException("Unsupported feature " + feature + " for a categorical split");
            }
        } else {
            // split is non-null here: split.featureIndex() was already dereferenced above
            throw new IllegalArgumentException("Unsupported split type " + split.getClass().getName());
        }
        Node leftChild = encodeNode(leftPredicate, scoreEncoder, internalNode.leftChild(), predicateManager, leftCategoryManager, schema);
        Node rightChild = encodeNode(rightPredicate, scoreEncoder, internalNode.rightChild(), predicateManager, rightCategoryManager, schema);
        Node result = new BranchNode(null, predicate).addNodes(leftChild, rightChild);
        return result;
    } else {
        throw new IllegalArgumentException("Unsupported node type " + (sparkNode != null ? sparkNode.getClass().getName() : null));
    }
}
Also used : Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) DataType(org.dmg.pmml.DataType) ArrayList(java.util.ArrayList) List(java.util.List) FieldName(org.dmg.pmml.FieldName) CategoryManager(org.jpmml.converter.CategoryManager) BinaryFeature(org.jpmml.converter.BinaryFeature) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit) Split(org.apache.spark.ml.tree.Split) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit)

Aggregations

BranchNode (org.dmg.pmml.tree.BranchNode)8 LeafNode (org.dmg.pmml.tree.LeafNode)8 Node (org.dmg.pmml.tree.Node)8 CategoryManager (org.jpmml.converter.CategoryManager)8 Predicate (org.dmg.pmml.Predicate)4 SimplePredicate (org.dmg.pmml.SimplePredicate)4 ClassifierNode (org.dmg.pmml.tree.ClassifierNode)4 TreeModel (org.dmg.pmml.tree.TreeModel)4 CategoricalFeature (org.jpmml.converter.CategoricalFeature)4 ContinuousFeature (org.jpmml.converter.ContinuousFeature)4 Feature (org.jpmml.converter.Feature)4 ArrayList (java.util.ArrayList)2 List (java.util.List)2 BooleanFeature (org.jpmml.converter.BooleanFeature)2 FlagManager (org.jpmml.converter.FlagManager)2 CategoricalSplit (org.apache.spark.ml.tree.CategoricalSplit)1 ContinuousSplit (org.apache.spark.ml.tree.ContinuousSplit)1 DecisionTreeModel (org.apache.spark.ml.tree.DecisionTreeModel)1 Split (org.apache.spark.ml.tree.Split)1 DataType (org.dmg.pmml.DataType)1