Search in sources :

Example 21 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

the class GBMConverter method encodeNode.

/**
 * Recursively fills in {@code node} from row {@code i} of the given tree table.
 *
 * <p>Elements 0, 1, 2, 3, 4 and 7 of {@code tree} are read as the split variable,
 * split code / prediction, left child, right child, missing child and prediction
 * vectors, respectively. A split variable of {@code -1} marks a leaf: the node
 * receives the formatted prediction as its score and no children. Otherwise up to
 * three children are appended, in the order missing, left, right; a child index of
 * {@code -1} means that the child is absent. Child ids are the child row index plus one.
 *
 * @param node the node to populate
 * @param i the row of the tree table that describes {@code node}
 * @param tree the tree table
 * @param c_splits the categorical split definitions, indexed by the split code
 * @param schema the schema that resolves split variable indices to features
 */
private void encodeNode(Node node, int i, RGenericVector tree, RGenericVector c_splits, Schema schema) {
    RIntegerVector splitVariables = (RIntegerVector) tree.getValue(0);
    RDoubleVector splitCodes = (RDoubleVector) tree.getValue(1);
    RIntegerVector leftChildren = (RIntegerVector) tree.getValue(2);
    RIntegerVector rightChildren = (RIntegerVector) tree.getValue(3);
    RIntegerVector missingChildren = (RIntegerVector) tree.getValue(4);
    RDoubleVector predictions = (RDoubleVector) tree.getValue(7);

    Integer featureIndex = splitVariables.getValue(i);

    // Leaf row: record the score and stop recursing.
    if (featureIndex == -1) {
        Double score = predictions.getValue(i);
        node.setScore(ValueUtil.formatValue(score));
        return;
    }

    Feature feature = schema.getFeature(featureIndex);
    Predicate whenMissing = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);
    Double splitCode = splitCodes.getValue(i);

    Predicate whenLeft;
    Predicate whenRight;
    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categorical = (CategoricalFeature) feature;
        List<String> categories = categorical.getValues();
        // For categorical features the split code is an index into c_splits.
        RIntegerVector categorySplit = (RIntegerVector) c_splits.getValue(ValueUtil.asInt(splitCode));
        List<Integer> directions = categorySplit.getValues();
        whenLeft = createSimpleSetPredicate(categorical, selectValues(categories, directions, true));
        whenRight = createSimpleSetPredicate(categorical, selectValues(categories, directions, false));
    } else {
        ContinuousFeature continuous = feature.toContinuousFeature();
        String threshold = ValueUtil.formatValue(splitCode);
        whenLeft = createSimplePredicate(continuous, SimplePredicate.Operator.LESS_THAN, threshold);
        whenRight = createSimplePredicate(continuous, SimplePredicate.Operator.GREATER_OR_EQUAL, threshold);
    }

    // Children are emitted in a fixed order: missing, left, right.
    RIntegerVector[] childColumns = {missingChildren, leftChildren, rightChildren};
    Predicate[] childPredicates = {whenMissing, whenLeft, whenRight};
    for (int k = 0; k < childColumns.length; k++) {
        Integer childRow = childColumns[k].getValue(i);
        if (childRow == -1) {
            continue;
        }
        Node child = new Node().setId(String.valueOf(childRow + 1)).setPredicate(childPredicates[k]);
        encodeNode(child, childRow, tree, c_splits, schema);
        node.addNodes(child);
    }
}
Also used : Node(org.dmg.pmml.tree.Node) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousFeature(org.jpmml.converter.ContinuousFeature)

Example 22 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

the class BinaryTreeConverter method encodeNode.

/**
 * Recursively fills in {@code node} from a nested tree element.
 *
 * <p>The node id is taken from the {@code nodeID} element. If the {@code terminal}
 * element is {@code TRUE}, the node is scored via {@code encodeScore} and no children
 * are created. Otherwise the primary split ({@code psplit}) is turned into a pair of
 * complementary predicates, and the {@code left} and {@code right} elements are
 * encoded recursively as the two children of {@code node}.
 *
 * @param node the node to populate
 * @param tree the tree element that describes {@code node}
 * @param schema the schema that resolves feature indices to features
 * @throws IllegalArgumentException if the {@code ssplits} element is non-empty, or if
 *         the split variable has no entry in {@code featureIndexes}
 */
private void encodeNode(Node node, RGenericVector tree, Schema schema) {
    RIntegerVector nodeId = (RIntegerVector) tree.getValue("nodeID");
    RBooleanVector terminal = (RBooleanVector) tree.getValue("terminal");
    RGenericVector psplit = (RGenericVector) tree.getValue("psplit");
    RGenericVector ssplits = (RGenericVector) tree.getValue("ssplits");
    RDoubleVector prediction = (RDoubleVector) tree.getValue("prediction");
    RGenericVector left = (RGenericVector) tree.getValue("left");
    RGenericVector right = (RGenericVector) tree.getValue("right");
    node.setId(String.valueOf(nodeId.asScalar()));
    if ((Boolean.TRUE).equals(terminal.asScalar())) {
        // The result of encodeScore was previously assigned to the parameter and then
        // discarded by the immediate return; the call is kept, the dead store is not.
        encodeScore(node, prediction, schema);
        return;
    }
    RNumberVector<?> splitpoint = (RNumberVector<?>) psplit.getValue("splitpoint");
    RStringVector variableName = (RStringVector) psplit.getValue("variableName");
    if (ssplits.size() > 0) {
        throw new IllegalArgumentException("Non-empty 'ssplits' element is not supported (found " + ssplits.size() + " entries)");
    }
    Predicate leftPredicate;
    Predicate rightPredicate;
    FieldName name = FieldName.create(variableName.asScalar());
    Integer index = this.featureIndexes.get(name);
    if (index == null) {
        throw new IllegalArgumentException("No feature index is registered for split variable '" + name + "'");
    }
    Feature feature = schema.getFeature(index);
    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<String> values = categoricalFeature.getValues();
        // NOTE(review): assumes that a categorical split point holds integer elements - confirm.
        @SuppressWarnings("unchecked")
        List<Integer> splitValues = (List<Integer>) splitpoint.getValues();
        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        String value = ValueUtil.formatValue((Double) splitpoint.asScalar());
        // Left branch takes values up to and including the split point.
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }
    Node leftChild = new Node().setPredicate(leftPredicate);
    encodeNode(leftChild, left, schema);
    Node rightChild = new Node().setPredicate(rightPredicate);
    encodeNode(rightChild, right, schema);
    node.addNodes(leftChild, rightChild);
}
Also used : Node(org.dmg.pmml.tree.Node) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousFeature(org.jpmml.converter.ContinuousFeature) ArrayList(java.util.ArrayList) List(java.util.List) FieldName(org.dmg.pmml.FieldName)

Example 23 with Node

use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.

the class TreeModelUtil method encodeTreeModel.

/**
 * Encodes a Spark ML decision tree as a binary-split PMML {@link TreeModel}.
 *
 * <p>The root node is encoded with an empty map as its third argument and is given
 * an always-true predicate. When {@code TreeModelOptions.COMPACT} parses as
 * {@code true}, a {@link TreeModelCompactor} is applied to the model before it is returned.
 *
 * @param node the root of the Spark ML tree
 * @param predicateManager the predicate manager passed through to node encoding
 * @param miningFunction the mining function of the resulting model
 * @param schema the schema; its label is used for building the mining schema
 * @return the encoded tree model
 */
public static TreeModel encodeTreeModel(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, MiningFunction miningFunction, Schema schema) {
    Node root = encodeNode(node, predicateManager, Collections.<FieldName, Set<String>>emptyMap(), miningFunction, schema)
        .setPredicate(new True());

    TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

    // NOTE(review): the constant is interpreted as the flag value itself - confirm that it is not an option key.
    // Boolean.parseBoolean yields false for null, so no separate null check is needed.
    String compactFlag = TreeModelOptions.COMPACT;
    if (!Boolean.parseBoolean(compactFlag)) {
        return treeModel;
    }

    new TreeModelCompactor().applyTo(treeModel);
    return treeModel;
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) DecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel) HashSet(java.util.HashSet) Set(java.util.Set) Visitor(org.dmg.pmml.Visitor) InternalNode(org.apache.spark.ml.tree.InternalNode) Node(org.dmg.pmml.tree.Node) LeafNode(org.apache.spark.ml.tree.LeafNode) True(org.dmg.pmml.True) TreeModelCompactor(org.jpmml.sparkml.visitors.TreeModelCompactor) FieldName(org.dmg.pmml.FieldName)

Example 24 with Node

use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.

the class TreeModelCompactor method handleNodePop.

/**
 * Merges a node carrying an always-true predicate into its parent.
 *
 * <p>Nodes whose predicate is not {@link True}, and nodes without a parent, are left
 * untouched. For regression, the parent takes over the node's score, the node is
 * removed from the parent, and the node's children (if any) are appended to the
 * parent's children. For classification, the node is removed and replaced by its
 * children only if it has children; a childless node stays in place.
 *
 * @param node the node being popped
 * @throws IllegalArgumentException if the parent already has a score, if the node is
 *         not among the parent's children when it needs to be removed, or if the
 *         mining function is neither regression nor classification
 */
private void handleNodePop(Node node) {
    String nodeScore = node.getScore();

    if (!(node.getPredicate() instanceof True)) {
        return;
    }

    Node parent = getParentNode();
    if (parent == null) {
        return;
    }

    if (parent.getScore() != null) {
        throw new IllegalArgumentException();
    }

    boolean regression = (MiningFunction.REGRESSION).equals(this.miningFunction);
    if (regression) {
        parent.setScore(nodeScore);
    } else if (!(MiningFunction.CLASSIFICATION).equals(this.miningFunction)) {
        throw new IllegalArgumentException();
    }

    // A childless classification node is kept as-is; every other case detaches the node.
    if (!regression && !node.hasNodes()) {
        return;
    }

    List<Node> siblings = parent.getNodes();
    if (!siblings.remove(node)) {
        throw new IllegalArgumentException();
    }

    // Re-attach the grandchildren directly under the parent.
    if (node.hasNodes()) {
        siblings.addAll(node.getNodes());
    }
}
Also used : Node(org.dmg.pmml.tree.Node) True(org.dmg.pmml.True) List(java.util.List) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate)

Example 25 with Node

use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.

the class TreeModelCompactor method handleNodePush.

/**
 * Validates a node on entry and simplifies the predicate of its second child.
 *
 * <p>A node must not have an id. A node with children must have exactly two of them
 * and no score; a node without children must have a score. For a node with children,
 * the pair of child predicates is checked against a fixed set of accepted shapes, and
 * the second child's predicate is then replaced - either by {@link True}, or by the
 * result of {@code addCategoricalField}. In some shapes the children are swapped first.
 *
 * @param node the node being pushed
 * @throws IllegalArgumentException if the node or the pair of child predicates does
 *         not match any accepted shape
 */
private void handleNodePush(Node node) {
    String id = node.getId();
    String score = node.getScore();
    // Nodes are expected to arrive without an id.
    if (id != null) {
        throw new IllegalArgumentException();
    }
    if (node.hasNodes()) {
        List<Node> children = node.getNodes();
        // Only unscored nodes with exactly two children are accepted.
        if (children.size() != 2 || score != null) {
            throw new IllegalArgumentException();
        }
        Node firstChild = children.get(0);
        Node secondChild = children.get(1);
        Predicate firstPredicate = firstChild.getPredicate();
        Predicate secondPredicate = secondChild.getPredicate();
        // The label lets the EQUAL/EQUAL case bail out of the whole if-else chain
        // without modifying any predicate.
        predicate: if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimplePredicate) {
            SimplePredicate firstSimplePredicate = (SimplePredicate) firstPredicate;
            SimplePredicate secondSimplePredicate = (SimplePredicate) secondPredicate;
            SimplePredicate.Operator firstOperator = firstSimplePredicate.getOperator();
            SimplePredicate.Operator secondOperator = secondSimplePredicate.getOperator();
            // Both children must test the same field.
            if (!(firstSimplePredicate.getField()).equals(secondSimplePredicate.getField())) {
                throw new IllegalArgumentException();
            }
            if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
                // EQUAL/EQUAL is only rewritten when isCategoricalField accepts the field;
                // otherwise both predicates are left as they are.
                if (!isCategoricalField(firstSimplePredicate.getField())) {
                    break predicate;
                }
                secondChild.setPredicate(new True());
            } else {
                // All remaining operator pairs must compare against the same value.
                if (!(firstSimplePredicate.getValue()).equals(secondSimplePredicate.getValue())) {
                    throw new IllegalArgumentException();
                }
                if ((SimplePredicate.Operator.NOT_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
                    // NOTE(review): swapChildren presumably exchanges the two list entries so that
                    // the EQUAL child comes first - confirm. The locals are re-read afterwards.
                    swapChildren(children);
                    firstChild = children.get(0);
                    secondChild = children.get(1);
                } else if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.NOT_EQUAL).equals(secondOperator)) {
                // Accepted as-is; no reordering needed
                } else if ((SimplePredicate.Operator.LESS_OR_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.GREATER_THAN).equals(secondOperator)) {
                // Accepted as-is; no reordering needed
                } else {
                    throw new IllegalArgumentException();
                }
                secondChild.setPredicate(new True());
            }
        } else if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimpleSetPredicate) {
            SimplePredicate simplePredicate = (SimplePredicate) firstPredicate;
            SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) secondPredicate;
            // Requires the same field, an EQUAL first predicate and an IS_IN second predicate.
            if (!(simplePredicate.getField()).equals(simpleSetPredicate.getField()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator())) {
                throw new IllegalArgumentException();
            }
            // The set-based child is already second; its predicate is replaced by the
            // result of addCategoricalField.
            secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
        } else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimplePredicate) {
            SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) firstPredicate;
            SimplePredicate simplePredicate = (SimplePredicate) secondPredicate;
            // Mirror image of the previous case: IS_IN first, EQUAL second, same field.
            if (!(simpleSetPredicate.getField()).equals(simplePredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator())) {
                throw new IllegalArgumentException();
            }
            // After the swap and re-read, secondChild is taken from index 1 of the list;
            // it receives the predicate derived from the set predicate.
            swapChildren(children);
            firstChild = children.get(0);
            secondChild = children.get(1);
            secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
        } else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimpleSetPredicate) {
            SimpleSetPredicate firstSimpleSetPredicate = (SimpleSetPredicate) firstPredicate;
            SimpleSetPredicate secondSimpleSetPredicate = (SimpleSetPredicate) secondPredicate;
            // Both predicates must be IS_IN tests on the same field.
            if (!(firstSimpleSetPredicate.getField()).equals(secondSimpleSetPredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(firstSimpleSetPredicate.getBooleanOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(secondSimpleSetPredicate.getBooleanOperator())) {
                throw new IllegalArgumentException();
            }
            secondChild.setPredicate(addCategoricalField(secondSimpleSetPredicate));
        } else {
            // Any other combination of predicate types is rejected.
            throw new IllegalArgumentException();
        }
    } else {
        // A node without children must carry a score.
        if (score == null) {
            throw new IllegalArgumentException();
        }
    }
}
Also used : Node(org.dmg.pmml.tree.Node) True(org.dmg.pmml.True) SimplePredicate(org.dmg.pmml.SimplePredicate) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate)

Aggregations

Node (org.dmg.pmml.tree.Node)40 LeafNode (org.dmg.pmml.tree.LeafNode)20 Predicate (org.dmg.pmml.Predicate)18 SimplePredicate (org.dmg.pmml.SimplePredicate)18 ClassifierNode (org.dmg.pmml.tree.ClassifierNode)17 BranchNode (org.dmg.pmml.tree.BranchNode)15 TreeModel (org.dmg.pmml.tree.TreeModel)12 ContinuousFeature (org.jpmml.converter.ContinuousFeature)12 CategoricalFeature (org.jpmml.converter.CategoricalFeature)11 Feature (org.jpmml.converter.Feature)11 ArrayList (java.util.ArrayList)8 List (java.util.List)7 CategoryManager (org.jpmml.converter.CategoryManager)6 Test (org.junit.Test)6 HashMap (java.util.HashMap)5 SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate)5 True (org.dmg.pmml.True)5 CategoricalLabel (org.jpmml.converter.CategoricalLabel)5 DATA_TYPE (org.kie.pmml.api.enums.DATA_TYPE)5 Field (org.dmg.pmml.Field)4