Search in sources:

Example 1 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

The class GBMConverter, method encodeTreeModel.

/**
 * Converts a single GBM tree into a PMML {@code TreeModel}.
 *
 * <p>The root node gets id {@code "1"} and an always-true predicate. The rest of the tree is
 * filled in recursively by {@code encodeNode}, starting from row 0 of {@code tree}.
 *
 * @param miningFunction mining function declared on the resulting model
 * @param tree R representation of the tree to convert
 * @param c_splits categorical split table passed through to {@code encodeNode}
 *        (presumably GBM's {@code c.splits}; TODO confirm)
 * @param schema feature and label schema
 * @return a tree model whose split characteristic is declared as multi-split
 */
private TreeModel encodeTreeModel(MiningFunction miningFunction, RGenericVector tree, RGenericVector c_splits, Schema schema) {
    Node rootNode = new Node()
        .setId("1")
        .setPredicate(new True());

    encodeNode(rootNode, 0, tree, c_splits, schema);

    return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) Node(org.dmg.pmml.tree.Node) True(org.dmg.pmml.True)

Example 2 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

The class IForestConverter, method encodeTreeModel.

/**
 * Converts one tree of an isolation forest into a PMML regression {@code TreeModel}.
 *
 * <p>Each per-node attribute in {@code trees} is stored as a Fortran-order matrix with
 * {@code nrnodes} rows and {@code ntree} columns. The column belonging to tree {@code index} is
 * sliced out of every matrix before the nodes are encoded recursively from row 0.
 *
 * @param trees R list holding the forest's node matrices and dimensions
 * @param index zero-based column (tree) to convert
 * @param schema feature and label schema
 * @return a regression tree model whose split characteristic is declared as binary-split
 */
private TreeModel encodeTreeModel(RGenericVector trees, int index, Schema schema) {
    RIntegerVector nrnodes = (RIntegerVector) trees.getValue("nrnodes");
    RIntegerVector ntree = (RIntegerVector) trees.getValue("ntree");
    RIntegerVector statusMatrix = (RIntegerVector) trees.getValue("nodeStatus");
    RIntegerVector leftMatrix = (RIntegerVector) trees.getValue("lDaughter");
    RIntegerVector rightMatrix = (RIntegerVector) trees.getValue("rDaughter");
    RIntegerVector attMatrix = (RIntegerVector) trees.getValue("splitAtt");
    RDoubleVector pointMatrix = (RDoubleVector) trees.getValue("splitPoint");
    RIntegerVector sizeMatrix = (RIntegerVector) trees.getValue("nSam");

    int nodeCount = nrnodes.asScalar();
    int treeCount = ntree.asScalar();

    Node rootNode = new Node().setPredicate(new True());

    // Slice this tree's column out of every node matrix.
    List<Integer> statusColumn = FortranMatrixUtil.getColumn(statusMatrix.getValues(), nodeCount, treeCount, index);
    List<Integer> sizeColumn = FortranMatrixUtil.getColumn(sizeMatrix.getValues(), nodeCount, treeCount, index);
    List<Integer> leftColumn = FortranMatrixUtil.getColumn(leftMatrix.getValues(), nodeCount, treeCount, index);
    List<Integer> rightColumn = FortranMatrixUtil.getColumn(rightMatrix.getValues(), nodeCount, treeCount, index);
    List<Integer> attColumn = FortranMatrixUtil.getColumn(attMatrix.getValues(), nodeCount, treeCount, index);
    List<Double> pointColumn = FortranMatrixUtil.getColumn(pointMatrix.getValues(), nodeCount, treeCount, index);

    encodeNode(rootNode, 0, 0, statusColumn, sizeColumn, leftColumn, rightColumn, attColumn, pointColumn, schema);

    return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) Node(org.dmg.pmml.tree.Node) True(org.dmg.pmml.True)

Example 3 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

The class IForestConverter, method encodeNode.

/**
 * Recursively populates {@code node} from row {@code index} of one isolation tree's node columns.
 *
 * <p>The node id is set to {@code index + 1}. A status of -3 marks an interior node: its two
 * children split on the continuous feature {@code splitAtt[index] - 1}, with the left child taking
 * {@code feature < splitValue} and the right child taking {@code feature >= splitValue}. A status
 * of -1 marks a terminal node, scored with {@code depth + avgPathLength(size)}. That is presumably
 * the usual isolation-forest correction for samples left unresolved in the leaf; confirm against
 * {@code avgPathLength}.
 *
 * @param node node to populate
 * @param index zero-based row of the node in the column lists
 * @param depth depth of {@code node} (root is 0)
 * @param nodeStatus per-node status codes
 * @param nodeSize per-node sample counts
 * @param leftDaughter one-based row of each node's left child
 * @param rightDaughter one-based row of each node's right child
 * @param splitAtt one-based feature index of each node's split
 * @param splitValue threshold of each node's split
 * @param schema feature schema; split features are expected to be {@link ContinuousFeature}s
 * @throws IllegalArgumentException if a node status is neither -3 nor -1
 */
private void encodeNode(Node node, int index, int depth, List<Integer> nodeStatus, List<Integer> nodeSize, List<Integer> leftDaughter, List<Integer> rightDaughter, List<Integer> splitAtt, List<Double> splitValue, Schema schema) {
    int status = nodeStatus.get(index);
    int size = nodeSize.get(index);

    node.setId(String.valueOf(index + 1));

    if (status == -3) {
        // Interior node
        int att = splitAtt.get(index);
        ContinuousFeature feature = (ContinuousFeature) schema.getFeature(att - 1);
        String value = ValueUtil.formatValue(splitValue.get(index));

        Predicate leftPredicate = createSimplePredicate(feature, SimplePredicate.Operator.LESS_THAN, value);
        Node leftChild = new Node().setPredicate(leftPredicate);
        // Daughter indices are one-based.
        int leftIndex = (leftDaughter.get(index) - 1);
        encodeNode(leftChild, leftIndex, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);

        Predicate rightPredicate = createSimplePredicate(feature, SimplePredicate.Operator.GREATER_OR_EQUAL, value);
        Node rightChild = new Node().setPredicate(rightPredicate);
        int rightIndex = (rightDaughter.get(index) - 1);
        encodeNode(rightChild, rightIndex, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);

        node.addNodes(leftChild, rightChild);
    } else if (status == -1) {
        // Terminal node
        node.setScore(ValueUtil.formatValue(depth + avgPathLength(size)));
    } else {
        throw new IllegalArgumentException("Unsupported node status " + status + " at node index " + index + " (expected -3 or -1)");
    }
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) Node(org.dmg.pmml.tree.Node) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate)

Example 4 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

The class RandomForestConverter, method encodeNode.

/**
 * Recursively populates {@code node} from row {@code i} of a randomForest tree.
 *
 * <p>A {@code bestvar} of 0 marks a terminal node. Such a node is scored by passing
 * {@code nodepred[i]} through {@code scoreEncoder}, and recursion stops there. Otherwise the node
 * splits on feature {@code bestvar[i] - 1}:
 * <ul>
 *   <li>boolean features: the split value must be 0.5; left is {@code == value(0)}, right is
 *       {@code == value(1)};</li>
 *   <li>categorical features: left/right are set predicates chosen by {@code selectValues};</li>
 *   <li>anything else: converted to continuous; left is {@code <= split}, right is
 *       {@code > split}.</li>
 * </ul>
 * Each non-zero daughter reference becomes a child node whose id is that one-based reference.
 *
 * @param node node to populate
 * @param i zero-based row of the node
 * @param scoreEncoder converts a terminal prediction into a score
 * @param leftDaughter one-based row of each node's left child (0 = none)
 * @param rightDaughter one-based row of each node's right child (0 = none)
 * @param bestvar one-based split feature index of each node (0 = terminal)
 * @param xbestsplit split value of each node
 * @param nodepred prediction of each node
 * @param schema feature schema
 * @throws IllegalArgumentException if a boolean feature is split at a value other than 0.5
 */
private <P extends Number> void encodeNode(Node node, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, Schema schema) {
    int var = ValueUtil.asInt(bestvar.get(i));

    // Terminal node: score it and stop recursing.
    if (var == 0) {
        P prediction = nodepred.get(i);
        node.setScore(scoreEncoder.encode(prediction));
        return;
    }

    Feature feature = schema.getFeature(var - 1);
    Double split = xbestsplit.get(i);

    Predicate leftPredicate;
    Predicate rightPredicate;

    if (feature instanceof BooleanFeature) {
        BooleanFeature booleanFeature = (BooleanFeature) feature;
        if (split != 0.5d) {
            throw new IllegalArgumentException("Expected a split value of 0.5 for a boolean feature at node index " + i + ", got " + split);
        }
        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
    } else if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<String> values = categoricalFeature.getValues();
        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, split, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, split, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        String value = ValueUtil.formatValue(split);
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }

    int left = ValueUtil.asInt(leftDaughter.get(i));
    if (left != 0) {
        Node leftChild = new Node().setId(String.valueOf(left)).setPredicate(leftPredicate);
        encodeNode(leftChild, left - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);
        node.addNodes(leftChild);
    }

    int right = ValueUtil.asInt(rightDaughter.get(i));
    if (right != 0) {
        Node rightChild = new Node().setId(String.valueOf(right)).setPredicate(rightPredicate);
        encodeNode(rightChild, right - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);
        node.addNodes(rightChild);
    }
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) Node(org.dmg.pmml.tree.Node) ArrayList(java.util.ArrayList) List(java.util.List) Feature(org.jpmml.converter.Feature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate)

Example 5 with Node

use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.

The class RangerConverter, method encodeClassification.

/**
 * Encodes a ranger classification forest as a majority-vote PMML {@code MiningModel}.
 *
 * <p>Each terminal node is scored with the class label at position {@code splitValue - 1} of the
 * forest's {@code levels} vector. Terminal class counts are not supported here.
 *
 * @param ranger R ranger model object
 * @param schema feature and label schema
 * @return a classification mining model combining one tree model per tree by majority vote
 * @throws IllegalArgumentException (from the score encoder) if a node carries terminal class counts
 */
private MiningModel encodeClassification(RGenericVector ranger, Schema schema) {
    RGenericVector forest = (RGenericVector) ranger.getValue("forest");
    final RStringVector classLabels = (RStringVector) forest.getValue("levels");

    ScoreEncoder labelEncoder = new ScoreEncoder() {

        @Override
        public void encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            int labelPosition = ValueUtil.asInt(splitValue);

            if (terminalClassCount != null) {
                throw new IllegalArgumentException();
            }

            // Label positions are one-based.
            node.setScore(classLabels.getValue(labelPosition - 1));
        }
    };

    List<TreeModel> trees = encodeForest(forest, MiningFunction.CLASSIFICATION, labelEncoder, schema);

    return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, trees));
}
Also used : TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) Node(org.dmg.pmml.tree.Node)

Aggregations

Node (org.dmg.pmml.tree.Node): 27; Predicate (org.dmg.pmml.Predicate): 9; SimplePredicate (org.dmg.pmml.SimplePredicate): 9; True (org.dmg.pmml.True): 9; TreeModel (org.dmg.pmml.tree.TreeModel): 9; ContinuousFeature (org.jpmml.converter.ContinuousFeature): 6; Test (org.junit.Test): 6; HashMap (java.util.HashMap): 5; ClassifierNode (org.dmg.pmml.tree.ClassifierNode): 5; LeafNode (org.dmg.pmml.tree.LeafNode): 5; CategoricalFeature (org.jpmml.converter.CategoricalFeature): 5; Feature (org.jpmml.converter.Feature): 5; DATA_TYPE (org.kie.pmml.api.enums.DATA_TYPE): 5; ArrayList (java.util.ArrayList): 4; Field (org.dmg.pmml.Field): 4; FieldName (org.dmg.pmml.FieldName): 4; KiePMMLOriginalTypeGeneratedType (org.kie.pmml.models.drools.tuples.KiePMMLOriginalTypeGeneratedType): 4; List (java.util.List): 3; SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate): 3; MiningModel (org.dmg.pmml.mining.MiningModel): 3