Search in sources :

Example 36 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

The class RangerConverter, method encodeRegression.

/**
 * Encodes a ranger regression forest as a PMML mining model whose member trees are averaged.
 *
 * @param forest the R object describing the forest
 * @param schema the schema providing the label and features
 * @return a regression {@link MiningModel} with an AVERAGE segmentation over the encoded trees
 */
private MiningModel encodeRegression(RGenericVector forest, Schema schema) {
    // For regression, the value handed to the encoder is stored as the node score unchanged;
    // the terminal class counts are not needed.
    ScoreEncoder regressionScorer = new ScoreEncoder() {

        @Override
        public Node encode(Node target, Number splitValue, RNumberVector<?> terminalClassCount) {
            target.setScore(splitValue);
            return target;
        }
    };

    List<TreeModel> trees = encodeForest(forest, MiningFunction.REGRESSION, regressionScorer, schema);

    MiningModel ensemble = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()));
    ensemble.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, trees));

    return ensemble;
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode)

Example 37 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

The class RandomForestCompactor, method isDefinedField.

/**
 * Tells whether the field referenced by the given predicate is already tested by an ancestor node.
 *
 * @param hasFieldReference a predicate (or other element) that references a field
 * @return {@code true} if {@code getAncestorNode} finds a node whose predicate references the same field
 */
private boolean isDefinedField(HasFieldReference<?> hasFieldReference) {
    final String fieldName = hasFieldReference.requireField();

    // NOTE(review): getAncestorNode presumably walks the nodes above the current one and returns
    // the first match (or null when none matches) — confirm against its definition.
    return getAncestorNode(ancestor -> hasFieldReference(ancestor.requirePredicate(), fieldName)) != null;
}
Also used : Node(org.dmg.pmml.tree.Node)

Example 38 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

The class RandomForestCompactor, method enterNode.

/**
 * Validates one tree node and, where it is safe, relaxes the predicate of a branch's second child to {@code True}.
 *
 * <p>Branch nodes must have exactly two children, no score, and a supported pair of child predicates.
 * Leaf nodes must have a score. On success the node id is cleared.
 *
 * @param node the node being visited; must still carry its id
 * @throws IllegalArgumentException if the node has no id, if a branch does not have exactly two children or carries
 *         a score, if the two child predicates form an unsupported combination, or if a leaf has no score
 */
@Override
public void enterNode(Node node) {
    Object id = node.getId();
    Object score = node.getScore();
    // Every node must still carry its id on entry; it is cleared at the end of this method.
    if (id == null) {
        throw new IllegalArgumentException();
    }
    if (node.hasNodes()) {
        List<Node> children = node.getNodes();
        // Only strictly binary splits are accepted, and a branch node must not carry a score itself.
        if (children.size() != 2 || score != null) {
            throw new IllegalArgumentException();
        }
        Node firstChild = children.get(0);
        Node secondChild = children.get(1);
        Predicate firstPredicate = firstChild.requirePredicate();
        Predicate secondPredicate = secondChild.requirePredicate();
        // NOTE(review): presumably verifies that both child predicates test the same field — confirm in checkFieldReference.
        checkFieldReference(firstPredicate, secondPredicate);
        // By default, the second child is only relaxed when an ancestor already splits on this field.
        boolean update = isDefinedField((HasFieldReference<?>) firstPredicate);
        if (hasOperator(firstPredicate, SimplePredicate.Operator.EQUAL) && hasOperator(secondPredicate, SimplePredicate.Operator.EQUAL)) {
        // Ignored
        } else if (hasOperator(firstPredicate, SimplePredicate.Operator.LESS_OR_EQUAL) && hasOperator(secondPredicate, SimplePredicate.Operator.GREATER_THAN)) {
            // "<= v" / "> v" are complementary for non-missing values, so the second child is always relaxed.
            // NOTE(review): relaxing to True also routes missing values to the second child — presumably intended.
            update = true;
        } else if (hasOperator(firstPredicate, SimplePredicate.Operator.EQUAL) && hasBooleanOperator(secondPredicate, SimpleSetPredicate.BooleanOperator.IS_IN)) {
        // Ignored
        } else if (hasBooleanOperator(firstPredicate, SimpleSetPredicate.BooleanOperator.IS_IN) && hasOperator(secondPredicate, SimplePredicate.Operator.EQUAL)) {
            if (update) {
                // Put the EQUAL child first so that the IS_IN child ends up second and is the one relaxed below.
                // NOTE(review): assumes swapChildren reverses the two children in place and returns the new list.
                children = swapChildren(node);
                firstChild = children.get(0);
                secondChild = children.get(1);
            }
        } else if (hasBooleanOperator(firstPredicate, SimpleSetPredicate.BooleanOperator.IS_IN) && hasBooleanOperator(secondPredicate, SimpleSetPredicate.BooleanOperator.IS_IN)) {
        // Ignored
        } else {
            // Any other predicate pairing is not supported by this compactor.
            throw new IllegalArgumentException();
        }
        if (update) {
            // Presumably relies on PMML's first-match evaluation of children: the second child is only
            // reached when the first child's predicate does not match, so it can be made unconditional.
            secondChild.setPredicate(True.INSTANCE);
        }
    } else {
        // Leaves must carry a score.
        if (score == null) {
            throw new IllegalArgumentException();
        }
    }
    node.setId(null);
}
Also used : Node(org.dmg.pmml.tree.Node) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate)

Example 39 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

The class BinaryTreeConverter, method encodeNode.

/**
 * Recursively encodes one node of an R binary tree, together with all of its descendants, as a PMML {@link Node}.
 *
 * <p>The R list is expected to provide the elements {@code nodeID}, {@code terminal}, {@code psplit},
 * {@code ssplits}, {@code prediction}, {@code left} and {@code right}. A terminal node becomes a {@link LeafNode}
 * scored via {@code encodeScore}. Any other node becomes a {@link BranchNode} with two children whose predicates
 * are derived from the primary split ({@code psplit}).
 *
 * @param tree the R list describing this node
 * @param predicate the predicate guarding this node (the caller passes {@code True} for the root)
 * @param schema the schema whose features are referenced by {@code featureIndexes}
 * @return the encoded node
 * @throws IllegalArgumentException if the node has a non-empty {@code ssplits} element, or if its split variable
 *         is not present in {@code featureIndexes}
 */
private Node encodeNode(RGenericVector tree, Predicate predicate, Schema schema) {
    RIntegerVector nodeId = tree.getIntegerElement("nodeID");
    RBooleanVector terminal = tree.getBooleanElement("terminal");
    RGenericVector psplit = tree.getGenericElement("psplit");
    RGenericVector ssplits = tree.getGenericElement("ssplits");
    RDoubleVector prediction = tree.getDoubleElement("prediction");
    RGenericVector left = tree.getGenericElement("left");
    RGenericVector right = tree.getGenericElement("right");

    Integer id = nodeId.asScalar();

    if ((Boolean.TRUE).equals(terminal.asScalar())) {
        Node result = new LeafNode(null, predicate).setId(id);
        return encodeScore(result, prediction, schema);
    }

    RNumberVector<?> splitpoint = psplit.getNumericElement("splitpoint");
    RStringVector variableName = psplit.getStringElement("variableName");

    // Only the primary split is encoded; a node with additional (presumably surrogate) splits is rejected.
    if (ssplits.size() > 0) {
        throw new IllegalArgumentException("Node " + id + ": non-empty 'ssplits' element is not supported");
    }

    String name = variableName.asScalar();
    Integer index = this.featureIndexes.get(name);
    if (index == null) {
        throw new IllegalArgumentException("Node " + id + ": split variable '" + name + "' is not a known feature");
    }

    Feature feature = schema.getFeature(index);

    Predicate leftPredicate;
    Predicate rightPredicate;
    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<?> values = categoricalFeature.getValues();

        // For categorical splits the split point holds integer level codes.
        @SuppressWarnings("unchecked")
        List<Integer> splitValues = (List<Integer>) splitpoint.getValues();

        // The two children receive complementary selections of the category values.
        leftPredicate = createPredicate(categoricalFeature, selectValues(values, splitValues, true));
        rightPredicate = createPredicate(categoricalFeature, selectValues(values, splitValues, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        Number value = splitpoint.asScalar();

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }

    Node leftChild = encodeNode(left, leftPredicate, schema);
    Node rightChild = encodeNode(right, rightPredicate, schema);

    Node result = new BranchNode(null, predicate).setId(id).addNodes(leftChild, rightChild);
    return result;
}
Also used : Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ArrayList(java.util.ArrayList) List(java.util.List)

Example 40 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

The class BinaryTreeConverter, method encodeTreeModel.

/**
 * Encodes a single R binary tree as a PMML {@link TreeModel} with binary splits.
 *
 * @param tree the R list describing the root node of the tree
 * @param schema the schema providing the label and features
 * @return the tree model, using this converter's mining function
 */
private TreeModel encodeTreeModel(RGenericVector tree, Schema schema) {
    // The root is guarded by the always-true predicate.
    Node rootNode = encodeNode(tree, True.INSTANCE, schema);

    TreeModel result = new TreeModel(this.miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode);
    result.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

    return result;
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode)

Aggregations

Node (org.dmg.pmml.tree.Node)40 LeafNode (org.dmg.pmml.tree.LeafNode)20 Predicate (org.dmg.pmml.Predicate)18 SimplePredicate (org.dmg.pmml.SimplePredicate)18 ClassifierNode (org.dmg.pmml.tree.ClassifierNode)17 BranchNode (org.dmg.pmml.tree.BranchNode)15 TreeModel (org.dmg.pmml.tree.TreeModel)12 ContinuousFeature (org.jpmml.converter.ContinuousFeature)12 CategoricalFeature (org.jpmml.converter.CategoricalFeature)11 Feature (org.jpmml.converter.Feature)11 ArrayList (java.util.ArrayList)8 List (java.util.List)7 CategoryManager (org.jpmml.converter.CategoryManager)6 Test (org.junit.Test)6 HashMap (java.util.HashMap)5 SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate)5 True (org.dmg.pmml.True)5 CategoricalLabel (org.jpmml.converter.CategoricalLabel)5 DATA_TYPE (org.kie.pmml.api.enums.DATA_TYPE)5 Field (org.dmg.pmml.Field)4