Search in sources :

Example 11 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

the class GBMConverter method encodeNode.

/**
 * Recursively fills in {@code node} from row {@code i} of the tabular tree representation.
 *
 * <p>Columns 0-4 and 7 of {@code tree} are read as the split variable, split code/threshold,
 * left child, right child, missing child and prediction vectors respectively. A split variable
 * of {@code -1} is treated as a leaf: the node receives the formatted prediction as its score and
 * no children. Otherwise up to three children are appended, in the order missing, left, right;
 * a child index of {@code -1} means that child is skipped.</p>
 *
 * @param node the node to populate; its id and predicate are expected to be set by the caller
 * @param i the row of {@code tree} that describes {@code node}
 * @param tree the column-wise tree table
 * @param c_splits the categorical split definitions, indexed by the split code of a categorical split
 * @param schema the schema used to resolve the split variable to a feature
 */
private void encodeNode(Node node, int i, RGenericVector tree, RGenericVector c_splits, Schema schema) {
    RIntegerVector splitVar = (RIntegerVector) tree.getValue(0);
    RDoubleVector splitCodePred = (RDoubleVector) tree.getValue(1);
    RIntegerVector leftNode = (RIntegerVector) tree.getValue(2);
    RIntegerVector rightNode = (RIntegerVector) tree.getValue(3);
    RIntegerVector missingNode = (RIntegerVector) tree.getValue(4);
    RDoubleVector prediction = (RDoubleVector) tree.getValue(7);

    Integer featureIndex = splitVar.getValue(i);

    // Leaf: there is nothing to split on, so only the score is recorded.
    if (featureIndex == -1) {
        Double leafValue = prediction.getValue(i);
        node.setScore(ValueUtil.formatValue(leafValue));
        return;
    }

    Feature feature = schema.getFeature(featureIndex);
    Predicate missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);
    Double split = splitCodePred.getValue(i);

    Predicate leftPredicate;
    Predicate rightPredicate;
    if (!(feature instanceof CategoricalFeature)) {
        // Numeric split: left takes values below the threshold, right takes the rest.
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        String threshold = ValueUtil.formatValue(split);
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, threshold);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, threshold);
    } else {
        // Categorical split: the split code is an index into c_splits.
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<String> categories = categoricalFeature.getValues();
        int splitIndex = ValueUtil.asInt(split);
        RIntegerVector categorySplit = (RIntegerVector) c_splits.getValue(splitIndex);
        List<Integer> directions = categorySplit.getValues();
        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(categories, directions, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(categories, directions, false));
    }

    // Children are emitted in a fixed order: missing, left, right.
    RIntegerVector[] childColumns = { missingNode, leftNode, rightNode };
    Predicate[] childPredicates = { missingPredicate, leftPredicate, rightPredicate };
    for (int k = 0; k < childColumns.length; k++) {
        Integer childRow = childColumns[k].getValue(i);
        if (childRow == -1) {
            continue;
        }
        // Node ids are the row index shifted by one.
        Node child = new Node().setId(String.valueOf(childRow + 1)).setPredicate(childPredicates[k]);
        encodeNode(child, childRow, tree, c_splits, schema);
        node.addNodes(child);
    }
}
Also used : Node(org.dmg.pmml.tree.Node) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate)

Example 12 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

the class BinaryTreeConverter method encodeTreeModel.

/**
 * Builds a binary-split {@link TreeModel} from the given tree structure.
 *
 * <p>The root node carries an always-true predicate and is populated by
 * {@code encodeNode(Node, RGenericVector, Schema)} before the model is assembled.</p>
 *
 * @param tree the tree structure to convert
 * @param schema the schema supplying the label for the mining schema
 * @return the tree model, using this converter's mining function
 */
private TreeModel encodeTreeModel(RGenericVector tree, Schema schema) {
    Node root = new Node()
        .setPredicate(new True());

    encodeNode(root, tree, schema);

    return new TreeModel(this.miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) Node(org.dmg.pmml.tree.Node) True(org.dmg.pmml.True)

Example 13 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

the class BinaryTreeConverter method encodeNode.

/**
 * Recursively populates {@code node} from a nested tree structure.
 *
 * <p>The node id is taken from the {@code "nodeID"} element. If the {@code "terminal"} element
 * is {@code TRUE}, the node is scored via {@code encodeScore} and gets no children. Otherwise the
 * primary split ({@code "psplit"}) is turned into a pair of complementary predicates, and the
 * {@code "left"} and {@code "right"} elements are encoded as the two child nodes.</p>
 *
 * @param node the node to populate; its predicate is expected to be set by the caller
 * @param tree the tree structure describing this node and its descendants
 * @param schema the schema used to resolve the split variable to a feature
 * @throws IllegalArgumentException if the {@code "ssplits"} element is non-empty, or if the
 *         split variable has no entry in {@code featureIndexes}
 */
private void encodeNode(Node node, RGenericVector tree, Schema schema) {
    RIntegerVector nodeId = (RIntegerVector) tree.getValue("nodeID");
    RBooleanVector terminal = (RBooleanVector) tree.getValue("terminal");
    RGenericVector psplit = (RGenericVector) tree.getValue("psplit");
    RGenericVector ssplits = (RGenericVector) tree.getValue("ssplits");
    RDoubleVector prediction = (RDoubleVector) tree.getValue("prediction");
    RGenericVector left = (RGenericVector) tree.getValue("left");
    RGenericVector right = (RGenericVector) tree.getValue("right");
    node.setId(String.valueOf(nodeId.asScalar()));
    if ((Boolean.TRUE).equals(terminal.asScalar())) {
        // The result is not captured: reassigning the parameter right before returning had no effect.
        encodeScore(node, prediction, schema);
        return;
    }
    RNumberVector<?> splitpoint = (RNumberVector<?>) psplit.getValue("splitpoint");
    RStringVector variableName = (RStringVector) psplit.getValue("variableName");
    if (ssplits.size() > 0) {
        throw new IllegalArgumentException("Node " + node.getId() + " has a non-empty 'ssplits' element (size " + ssplits.size() + "), which is not supported");
    }
    Predicate leftPredicate;
    Predicate rightPredicate;
    FieldName name = FieldName.create(variableName.asScalar());
    Integer index = this.featureIndexes.get(name);
    if (index == null) {
        throw new IllegalArgumentException("Node " + node.getId() + " splits on variable '" + variableName.asScalar() + "', which has no known feature index");
    }
    Feature feature = schema.getFeature(index);
    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<String> values = categoricalFeature.getValues();
        // NOTE(review): assumes a categorical split point is integer-valued; the wildcard type cannot prove it - confirm.
        @SuppressWarnings("unchecked")
        List<Integer> splitValues = (List<Integer>) splitpoint.getValues();
        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        String value = ValueUtil.formatValue((Double) splitpoint.asScalar());
        // Left is inclusive of the split point; right is strictly greater.
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }
    Node leftChild = new Node().setPredicate(leftPredicate);
    encodeNode(leftChild, left, schema);
    Node rightChild = new Node().setPredicate(rightPredicate);
    encodeNode(rightChild, right, schema);
    node.addNodes(leftChild, rightChild);
}
Also used : Node(org.dmg.pmml.tree.Node) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ArrayList(java.util.ArrayList) List(java.util.List) FieldName(org.dmg.pmml.FieldName)

Example 14 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

the class RangerConverter method encodeRegression.

/**
 * Encodes the {@code "forest"} element of a ranger object as a regression ensemble.
 *
 * <p>Each tree's leaf score is the formatted split value handed to the score encoder; the
 * per-tree results are combined using the {@code AVERAGE} multiple-model method.</p>
 *
 * @param ranger the ranger model object
 * @param schema the schema supplying the label and features
 * @return a regression mining model wrapping the encoded trees
 */
private MiningModel encodeRegression(RGenericVector ranger, Schema schema) {
    RGenericVector forest = (RGenericVector) ranger.getValue("forest");

    // For regression the value passed as splitValue is written directly as the node score.
    ScoreEncoder leafScorer = new ScoreEncoder() {

        @Override
        public void encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            String score = ValueUtil.formatValue(splitValue);
            node.setScore(score);
        }
    };

    List<TreeModel> trees = encodeForest(forest, MiningFunction.REGRESSION, leafScorer, schema);

    MiningModel ensemble = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()));
    Segmentation segmentation = MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, trees);
    return ensemble.setSegmentation(segmentation);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) Node(org.dmg.pmml.tree.Node)

Example 15 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

the class RangerConverter method encodeProbabilityForest.

/**
 * Encodes the {@code "forest"} element of a ranger object as a probability-producing
 * classification ensemble.
 *
 * <p>For every leaf, one score distribution entry is added per level of {@code forest$levels},
 * and the node score is set to the level with the strictly greatest value (on ties the earliest
 * level wins). The per-tree results are combined using the {@code AVERAGE} multiple-model
 * method, and a probability output is attached to the model.</p>
 *
 * @param ranger the ranger model object
 * @param schema the schema; its label is cast to a categorical label
 * @return a classification mining model wrapping the encoded trees
 */
private MiningModel encodeProbabilityForest(RGenericVector ranger, Schema schema) {
    RGenericVector forest = (RGenericVector) ranger.getValue("forest");
    final RStringVector levels = (RStringVector) forest.getValue("levels");
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    ScoreEncoder leafScorer = new ScoreEncoder() {

        @Override
        public void encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            // A leaf must carry a zero split value ...
            if (splitValue.doubleValue() != 0d) {
                throw new IllegalArgumentException();
            }
            // ... and exactly one class count per level.
            if (terminalClassCount == null || terminalClassCount.size() != levels.size()) {
                throw new IllegalArgumentException();
            }

            Double bestProbability = null;
            int classCount = terminalClassCount.size();
            for (int k = 0; k < classCount; k++) {
                String level = levels.getValue(k);
                Double probability = ValueUtil.asDouble(terminalClassCount.getValue(k));

                boolean isNewBest = (bestProbability == null) || (bestProbability.compareTo(probability) < 0);
                if (isNewBest) {
                    node.setScore(level);
                    bestProbability = probability;
                }

                node.addScoreDistributions(new ScoreDistribution(level, probability));
            }
        }
    };

    List<TreeModel> trees = encodeForest(forest, MiningFunction.CLASSIFICATION, leafScorer, schema);

    MiningModel ensemble = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel));
    Segmentation segmentation = MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, trees);
    return ensemble
        .setSegmentation(segmentation)
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
}
Also used : Node(org.dmg.pmml.tree.Node) ScoreDistribution(org.dmg.pmml.ScoreDistribution) TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) CategoricalLabel(org.jpmml.converter.CategoricalLabel)

Aggregations

Node (org.dmg.pmml.tree.Node)20 Predicate (org.dmg.pmml.Predicate)9 SimplePredicate (org.dmg.pmml.SimplePredicate)9 True (org.dmg.pmml.True)9 TreeModel (org.dmg.pmml.tree.TreeModel)9 ContinuousFeature (org.jpmml.converter.ContinuousFeature)6 CategoricalFeature (org.jpmml.converter.CategoricalFeature)5 Feature (org.jpmml.converter.Feature)5 FieldName (org.dmg.pmml.FieldName)4 List (java.util.List)3 SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate)3 MiningModel (org.dmg.pmml.mining.MiningModel)3 ArrayList (java.util.ArrayList)2 HashSet (java.util.HashSet)2 Set (java.util.Set)2 InternalNode (org.apache.spark.ml.tree.InternalNode)2 LeafNode (org.apache.spark.ml.tree.LeafNode)2 PMMLObject (org.dmg.pmml.PMMLObject)2 ScoreDistribution (org.dmg.pmml.ScoreDistribution)2 BooleanFeature (org.jpmml.converter.BooleanFeature)2