Search in sources :

Example 1 with ClassifierNode

use of org.dmg.pmml.tree.ClassifierNode in project jpmml-r by jpmml.

the class PartyConverter method encodeNode.

/**
 * Recursively converts one node of an R tree object (and all of its kids) into a PMML {@link Node}.
 *
 * <p>When {@code response} is an {@link RFactorVector} the node becomes a {@link ClassifierNode}
 * carrying one {@link ScoreDistribution} per label category; otherwise it becomes a
 * {@link LeafNode} (no kids) or a {@link BranchNode} (has kids). The score is looked up in
 * {@code response} at position {@code id - 1}.</p>
 *
 * @param partyNode the R node; its "id", "split", "kids", "surrogates" and "info" elements are read
 * @param predicate the predicate attached to the created node
 * @param response the fitted response values, indexed by node id minus one
 * @param prob the probability values, read as a matrix with {@code response.size()} rows
 *        and one column per label category (only used for factor responses)
 * @param schema supplies the label and the features referenced by "varid"
 * @return the encoded node, with its children attached
 * @throws IllegalArgumentException if surrogates are present, if the split does not define exactly
 *         one of "breaks" and "index", if a "breaks" split does not have two kids and one break,
 *         or if an "index" split has fewer than two kids or a false "right" flag
 */
private Node encodeNode(RGenericVector partyNode, Predicate predicate, RVector<?> response, RDoubleVector prob, Schema schema) {
    RIntegerVector id = partyNode.getIntegerElement("id");
    RGenericVector split = partyNode.getGenericElement("split");
    RGenericVector kids = partyNode.getGenericElement("kids");
    RGenericVector surrogates = partyNode.getGenericElement("surrogates");
    // NOTE(review): "info" is fetched but never used below; kept in case the lookup itself validates the element
    RGenericVector info = partyNode.getGenericElement("info");

    // Surrogate splits are not supported
    if (surrogates != null) {
        throw new IllegalArgumentException();
    }

    Label label = schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();

    final boolean classification = (response instanceof RFactorVector);
    final boolean terminal = (kids == null);

    Node result;
    if (classification) {
        result = new ClassifierNode(null, predicate);
    } else if (terminal) {
        result = new LeafNode(null, predicate);
    } else {
        result = new BranchNode(null, predicate);
    }
    result.setId(Integer.valueOf(id.asScalar()));

    // Node ids are one-based; the response (and probability rows) are addressed zero-based
    int row = id.asScalar() - 1;

    if (classification) {
        RFactorVector factor = (RFactorVector) response;
        result.setScore(factor.getFactorValue(row));

        CategoricalLabel categoricalLabel = (CategoricalLabel) label;
        List<Double> rowProbabilities = FortranMatrixUtil.getRow(prob.getValues(), response.size(), categoricalLabel.size(), row);

        List<ScoreDistribution> distributions = result.getScoreDistributions();
        for (int category = 0; category < categoricalLabel.size(); category++) {
            Object categoryValue = categoricalLabel.getValue(category);
            Double categoryProbability = rowProbabilities.get(category);
            distributions.add(new ScoreDistribution(categoryValue, categoryProbability));
        }
    } else {
        result.setScore(response.getValue(row));
    }

    if (terminal) {
        return result;
    }

    RIntegerVector varid = split.getIntegerElement("varid");
    RDoubleVector breaks = split.getDoubleElement("breaks");
    RIntegerVector index = split.getIntegerElement("index");
    RBooleanVector right = split.getBooleanElement("right");

    Feature feature = features.get(varid.asScalar() - 1);

    // Exactly one of "breaks" (threshold split) and "index" (category split) must be present
    final boolean hasBreaks = (breaks != null);
    final boolean hasIndex = (index != null);
    if (hasBreaks == hasIndex) {
        throw new IllegalArgumentException();
    }

    if (hasBreaks) {
        ContinuousFeature continuousFeature = (ContinuousFeature) feature;

        // Only binary splits on a single threshold are handled
        if (kids.size() != 2 || breaks.size() != 1) {
            throw new IllegalArgumentException();
        }

        Double threshold = breaks.asScalar();

        // The "right" flag decides which side of the split includes the threshold itself
        boolean thresholdGoesLeft = right.asScalar();
        SimplePredicate.Operator leftOperator = thresholdGoesLeft ? SimplePredicate.Operator.LESS_OR_EQUAL : SimplePredicate.Operator.LESS_THAN;
        SimplePredicate.Operator rightOperator = thresholdGoesLeft ? SimplePredicate.Operator.GREATER_THAN : SimplePredicate.Operator.GREATER_OR_EQUAL;

        Predicate leftPredicate = createSimplePredicate(continuousFeature, leftOperator, threshold);
        Predicate rightPredicate = createSimplePredicate(continuousFeature, rightOperator, threshold);

        Node leftChild = encodeNode(kids.getGenericValue(0), leftPredicate, response, prob, schema);
        Node rightChild = encodeNode(kids.getGenericValue(1), rightPredicate, response, prob, schema);
        result.addNodes(leftChild, rightChild);
    } else {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

        if (kids.size() < 2) {
            throw new IllegalArgumentException();
        }

        List<?> values = categoricalFeature.getValues();

        // Kids are numbered from one when selecting the category values that lead to them
        for (int position = 1; position <= kids.size(); position++) {
            if (!right.asScalar()) {
                throw new IllegalArgumentException();
            }
            Predicate childPredicate = createPredicate(categoricalFeature, selectValues(values, index, position));

            Node child = encodeNode(kids.getGenericValue(position - 1), childPredicate, response, prob, schema);
            result.addNodes(child);
        }
    }

    return result;
}
Also used : BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Label(org.jpmml.converter.Label) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ScoreDistribution(org.dmg.pmml.ScoreDistribution) BranchNode(org.dmg.pmml.tree.BranchNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalLabel(org.jpmml.converter.CategoricalLabel) LeafNode(org.dmg.pmml.tree.LeafNode) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) ArrayList(java.util.ArrayList) List(java.util.List)

Example 2 with ClassifierNode

use of org.dmg.pmml.tree.ClassifierNode in project drools by kiegroup.

the class KiePMMLTreeModelNodeASTFactoryTest method isFinalLeaf.

/**
 * Checks {@code isFinalLeaf} against three node shapes: an empty {@link LeafNode} and an empty
 * {@link ClassifierNode} are expected to be final leaves, while a {@link ClassifierNode} that has
 * been given a child node is not.
 */
@Test
public void isFinalLeaf() {
    DATA_TYPE targetType = DATA_TYPE.STRING;

    Node node = new LeafNode();
    assertTrue(isFinalLeafWithFreshFactory(node, targetType));

    node = new ClassifierNode();
    assertTrue(isFinalLeafWithFreshFactory(node, targetType));

    // Once the node has a child it must no longer be reported as a final leaf
    node.addNodes(new LeafNode());
    assertFalse(isFinalLeafWithFreshFactory(node, targetType));
}

/**
 * Evaluates {@code isFinalLeaf} for {@code node} on a newly created factory, so that no factory
 * instance is shared between assertions.
 *
 * @param node the node under test
 * @param targetType the target data type handed to the factory
 * @return the value returned by {@code isFinalLeaf}
 */
private static boolean isFinalLeafWithFreshFactory(Node node, DATA_TYPE targetType) {
    return KiePMMLTreeModelNodeASTFactory.factory(new HashMap<>(), Collections.emptyList(), TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION, targetType).isFinalLeaf(node);
}
Also used : HashMap(java.util.HashMap) LeafNode(org.dmg.pmml.tree.LeafNode) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) LeafNode(org.dmg.pmml.tree.LeafNode) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) DATA_TYPE(org.kie.pmml.api.enums.DATA_TYPE) Test(org.junit.Test)

Example 3 with ClassifierNode

use of org.dmg.pmml.tree.ClassifierNode in project jpmml-r by jpmml.

the class RPartConverter method encodeClassification.

/**
 * Encodes a classification tree as a PMML {@link TreeModel}.
 *
 * <p>The "yval2" element of {@code frame} is read as a matrix with one row per entry of
 * {@code rowNames} and {@code 1 + 2 * categories + 1} columns. Column 0 holds the one-based
 * class index of each node; columns {@code 1..categories} hold per-category record counts,
 * which are only extracted when {@code hasScoreDistribution()} is true.</p>
 *
 * @param frame the R frame; only its "yval2" element is read here
 * @param rowNames row identifiers; their count is the matrix row count
 * @param var passed through to {@code encodeNode}
 * @param n per-node record counts
 * @param splitInfo passed through to {@code encodeNode}
 * @param splits passed through to {@code encodeNode}
 * @param csplit passed through to {@code encodeNode}
 * @param schema supplies the categorical label
 * @return the configured tree model; a probability output is attached when score
 *         distributions are enabled
 */
private TreeModel encodeClassification(RGenericVector frame, RIntegerVector rowNames, RVector<?> var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, Schema schema) {
    RDoubleVector yval2 = frame.getDoubleElement("yval2");

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    List<?> categories = categoricalLabel.getValues();

    boolean hasScoreDistribution = hasScoreDistribution();

    ScoreEncoder scoreEncoder = new ScoreEncoder() {

        // One-based class index per node, taken from column 0 of the "yval2" matrix
        private List<Integer> classes = null;

        // Per-category columns of record counts; stays null when score distributions are disabled
        private List<List<? extends Number>> recordCounts = null;

        {
            int rows = rowNames.size();
            int columns = 1 + (2 * categories.size()) + 1;

            List<Integer> classColumn = ValueUtil.asIntegers(FortranMatrixUtil.getColumn(yval2.getValues(), rows, columns, 0));
            this.classes = new ArrayList<>(classColumn);

            if (hasScoreDistribution) {
                this.recordCounts = new ArrayList<>();

                // Count columns follow the class column, hence the offset of one
                for (int category = 0; category < categories.size(); category++) {
                    List<? extends Number> countColumn = FortranMatrixUtil.getColumn(yval2.getValues(), rows, columns, 1 + category);
                    this.recordCounts.add(new ArrayList<>(countColumn));
                }
            }
        }

        @Override
        public Node encode(Node node, int offset) {
            Object score = categories.get(this.classes.get(offset) - 1);
            Integer recordCount = n.getValue(offset);
            node.setScore(score).setRecordCount(recordCount);

            if (!hasScoreDistribution) {
                return node;
            }

            // Re-wrap the node as a classifier node and attach one distribution per category
            Node classifierNode = new ClassifierNode(node);
            List<ScoreDistribution> distributions = classifierNode.getScoreDistributions();
            for (int category = 0; category < categories.size(); category++) {
                List<? extends Number> countColumn = this.recordCounts.get(category);
                ScoreDistribution distribution = new ScoreDistribution().setValue(categories.get(category)).setRecordCount(countColumn.get(offset));
                distributions.add(distribution);
            }
            return classifierNode;
        }
    };

    Node root = encodeNode(True.INSTANCE, 1, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);

    TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()), root);
    if (hasScoreDistribution) {
        treeModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    }

    return configureTreeModel(treeModel);
}
Also used : Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) CountingLeafNode(org.dmg.pmml.tree.CountingLeafNode) CountingBranchNode(org.dmg.pmml.tree.CountingBranchNode) ScoreDistribution(org.dmg.pmml.ScoreDistribution) TreeModel(org.dmg.pmml.tree.TreeModel) CategoricalLabel(org.jpmml.converter.CategoricalLabel) ArrayList(java.util.ArrayList) List(java.util.List) ClassifierNode(org.dmg.pmml.tree.ClassifierNode)

Example 4 with ClassifierNode

use of org.dmg.pmml.tree.ClassifierNode in project jpmml-r by jpmml.

the class RangerConverter method encodeProbabilityForest.

/**
 * Encodes a probability forest as a PMML {@link MiningModel} whose member trees are
 * combined with {@code Segmentation.MultipleModelMethod.AVERAGE}.
 *
 * <p>Each terminal node is turned into a {@link ClassifierNode} with one
 * {@link ScoreDistribution} per entry of the forest's "levels" element; the node score is the
 * level with the largest class count, the earliest such level winning ties.</p>
 *
 * @param forest the R forest object; its "levels" element is read here
 * @param schema supplies the categorical label
 * @return the mining model with a probability output attached
 */
private MiningModel encodeProbabilityForest(RGenericVector forest, Schema schema) {
    RStringVector levels = forest.getStringElement("levels");

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    ScoreEncoder scoreEncoder = new ScoreEncoder() {

        @Override
        public Node encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            // A non-zero split value is rejected
            if (splitValue.doubleValue() != 0d) {
                throw new IllegalArgumentException();
            }
            // There must be exactly one class count per level
            if (terminalClassCount == null || terminalClassCount.size() != levels.size()) {
                throw new IllegalArgumentException();
            }

            Node classifierNode = new ClassifierNode(node);
            List<ScoreDistribution> distributions = classifierNode.getScoreDistributions();

            Number best = null;
            for (int level = 0; level < terminalClassCount.size(); level++) {
                String value = levels.getValue(level);
                Number probability = terminalClassCount.getValue(level);

                // Strictly-greater comparison: the first of several equal maxima keeps the score
                boolean isNewBest = (best == null) || ((Comparable) best).compareTo(probability) < 0;
                if (isNewBest) {
                    classifierNode.setScore(value);
                    best = probability;
                }

                distributions.add(new ScoreDistribution(value, probability));
            }
            return classifierNode;
        }
    };

    List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

    return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels))
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
}
Also used : Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) ScoreDistribution(org.dmg.pmml.ScoreDistribution) TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) CategoricalLabel(org.jpmml.converter.CategoricalLabel) ClassifierNode(org.dmg.pmml.tree.ClassifierNode)

Example 5 with ClassifierNode

use of org.dmg.pmml.tree.ClassifierNode in project jpmml-r by jpmml.

the class BinaryTreeConverter method encodeClassificationScore.

/**
 * Wraps {@code node} in a {@link ClassifierNode} and fills in its score and score distributions.
 *
 * <p>One {@link ScoreDistribution} is added per label category, pairing the category value with
 * the probability at the same position. The node score is the category with the largest
 * probability; on ties the earliest category is kept.</p>
 *
 * @param node the node to copy into a classifier node
 * @param probabilities per-category probabilities; the size is checked against the label
 *        via {@code SchemaUtil.checkSize}
 * @param schema supplies the categorical label
 * @return the new classifier node
 */
private static Node encodeClassificationScore(Node node, RDoubleVector probabilities, Schema schema) {
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    SchemaUtil.checkSize(probabilities.size(), categoricalLabel);

    Node classifierNode = new ClassifierNode(node);
    List<ScoreDistribution> distributions = classifierNode.getScoreDistributions();

    Double best = null;
    for (int category = 0; category < categoricalLabel.size(); category++) {
        Object value = categoricalLabel.getValue(category);
        Double probability = probabilities.getValue(category);

        // Strictly-greater comparison: the first of several equal maxima keeps the score
        boolean isNewBest = (best == null) || best.compareTo(probability) < 0;
        if (isNewBest) {
            classifierNode.setScore(value);
            best = probability;
        }

        distributions.add(new ScoreDistribution(value, probability));
    }

    return classifierNode;
}
Also used : ScoreDistribution(org.dmg.pmml.ScoreDistribution) CategoricalLabel(org.jpmml.converter.CategoricalLabel) ClassifierNode(org.dmg.pmml.tree.ClassifierNode)

Aggregations

ClassifierNode (org.dmg.pmml.tree.ClassifierNode): 6; Node (org.dmg.pmml.tree.Node): 5; CategoricalLabel (org.jpmml.converter.CategoricalLabel): 5; ScoreDistribution (org.dmg.pmml.ScoreDistribution): 4; LeafNode (org.dmg.pmml.tree.LeafNode): 4; BranchNode (org.dmg.pmml.tree.BranchNode): 3; TreeModel (org.dmg.pmml.tree.TreeModel): 3; ArrayList (java.util.ArrayList): 2; List (java.util.List): 2; HashMap (java.util.HashMap): 1; DecisionTreeClassificationModel (org.apache.spark.ml.classification.DecisionTreeClassificationModel): 1; DecisionTreeRegressionModel (org.apache.spark.ml.regression.DecisionTreeRegressionModel): 1; DecisionTreeModel (org.apache.spark.ml.tree.DecisionTreeModel): 1; ImpurityCalculator (org.apache.spark.mllib.tree.impurity.ImpurityCalculator): 1; Predicate (org.dmg.pmml.Predicate): 1; SimplePredicate (org.dmg.pmml.SimplePredicate): 1; Visitor (org.dmg.pmml.Visitor): 1; MiningModel (org.dmg.pmml.mining.MiningModel): 1; CountingBranchNode (org.dmg.pmml.tree.CountingBranchNode): 1; CountingLeafNode (org.dmg.pmml.tree.CountingLeafNode): 1