Search in sources:

Example 1 with ScoreDistribution

Use of org.dmg.pmml.ScoreDistribution in the project jpmml-sparkml by jpmml.

From the class TreeModelUtil, method encodeNode.

/**
 * Converts a Spark ML decision tree node, together with its whole subtree, into a PMML {@link Node}.
 *
 * <p>Internal nodes: a left and a right predicate are derived from the node's split, both children
 * are encoded recursively, each child receives its predicate, and a fresh node holding the two
 * children is returned.
 * <ul>
 *   <li>Continuous split on a {@code BooleanFeature}: the threshold must be exactly {@code 0.5};
 *       the left branch tests equality with {@code getValue(0)}, the right branch with
 *       {@code getValue(1)}.</li>
 *   <li>Continuous split on any other feature: the feature is converted with
 *       {@code toContinuousFeature()}; left is {@code <= threshold}, right is {@code > threshold}.</li>
 *   <li>Categorical split on a {@code BinaryFeature}: the left/right category arrays must match the
 *       {@code TRUE}/{@code FALSE} constants (declared outside this method) in one of the two
 *       possible arrangements; the branches become {@code EQUAL}/{@code NOT_EQUAL} tests against
 *       the feature's value.</li>
 *   <li>Categorical split on a {@code CategoricalFeature}: the categories are mapped to feature
 *       values via {@code selectValues} (declared outside this method), restricted to the values
 *       still recorded for this field in {@code parentFieldValues}; the narrowed value sets are
 *       handed down to the respective child.</li>
 * </ul>
 *
 * <p>Leaf nodes: for {@code REGRESSION} the score is the formatted prediction; for
 * {@code CLASSIFICATION} the score is the label value at the predicted index, the record count is
 * taken from the impurity statistics, and one {@link ScoreDistribution} is added per statistics entry.
 *
 * @param node the Spark ML tree node to convert
 * @param predicateManager factory used to create the branch predicates
 * @param parentFieldValues per-field value sets inherited from the ancestors; never modified here
 *        (copies are made before narrowing)
 * @param miningFunction selects how leaf nodes are scored
 * @param schema source of the split features and of the label
 * @return the PMML node representing {@code node} and its descendants
 * @throws IllegalArgumentException if the node, split or feature type is not one of the handled
 *         kinds, if a boolean split threshold is not {@code 0.5}, if a binary split does not match
 *         the {@code TRUE}/{@code FALSE} arrangement, or if the number of categorical feature values
 *         differs from the total number of split categories
 * @throws UnsupportedOperationException if a leaf is reached with a mining function other than
 *         {@code REGRESSION} or {@code CLASSIFICATION}
 */
public static Node encodeNode(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentFieldValues, MiningFunction miningFunction, Schema schema) {
    if (node instanceof InternalNode) {
        InternalNode branch = (InternalNode) node;
        Split split = branch.split();
        Feature feature = schema.getFeature(split.featureIndex());

        // Unless a categorical split narrows them below, both children inherit the parent's value sets.
        Map<FieldName, Set<String>> leftFieldValues = parentFieldValues;
        Map<FieldName, Set<String>> rightFieldValues = parentFieldValues;
        Predicate leftPredicate;
        Predicate rightPredicate;

        if (split instanceof ContinuousSplit) {
            double threshold = ((ContinuousSplit) split).threshold();
            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;
                // A boolean feature can only be split halfway between its two values.
                if (threshold != 0.5d) {
                    throw new IllegalArgumentException();
                }
                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                String formattedThreshold = ValueUtil.formatValue(threshold);
                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, formattedThreshold);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, formattedThreshold);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;
            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;
                // Exactly one of the two arrangements must hold: TRUE goes left, or TRUE goes right.
                boolean trueGoesLeft = Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories);
                boolean trueGoesRight = !trueGoesLeft && Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories);
                if (!trueGoesLeft && !trueGoesRight) {
                    throw new IllegalArgumentException();
                }
                SimplePredicate.Operator leftOperator = trueGoesLeft ? SimplePredicate.Operator.EQUAL : SimplePredicate.Operator.NOT_EQUAL;
                SimplePredicate.Operator rightOperator = trueGoesLeft ? SimplePredicate.Operator.NOT_EQUAL : SimplePredicate.Operator.EQUAL;
                String featureValue = ValueUtil.formatValue(binaryFeature.getValue());
                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, featureValue);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, featureValue);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
                FieldName fieldName = categoricalFeature.getName();
                List<String> allValues = categoricalFeature.getValues();
                // Every feature value must be assigned to exactly one side of the split.
                if (allValues.size() != (leftCategories.length + rightCategories.length)) {
                    throw new IllegalArgumentException();
                }
                final Set<String> inheritedValues = parentFieldValues.get(fieldName);
                // Accepts everything when no ancestor has restricted this field yet.
                com.google.common.base.Predicate<String> stillReachable = new com.google.common.base.Predicate<String>() {

                    @Override
                    public boolean apply(String candidate) {
                        return (inheritedValues == null) || inheritedValues.contains(candidate);
                    }
                };
                List<String> leftValues = selectValues(allValues, leftCategories, stillReachable);
                List<String> rightValues = selectValues(allValues, rightCategories, stillReachable);
                // Copy before narrowing so that the caller's map stays untouched.
                leftFieldValues = new HashMap<>(parentFieldValues);
                leftFieldValues.put(fieldName, new HashSet<>(leftValues));
                rightFieldValues = new HashMap<>(parentFieldValues);
                rightFieldValues.put(fieldName, new HashSet<>(rightValues));
                leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftValues);
                rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException();
            }
        } else {
            throw new IllegalArgumentException();
        }

        Node parent = new Node();
        Node leftSubtree = encodeNode(branch.leftChild(), predicateManager, leftFieldValues, miningFunction, schema);
        Node leftChild = leftSubtree.setPredicate(leftPredicate);
        Node rightSubtree = encodeNode(branch.rightChild(), predicateManager, rightFieldValues, miningFunction, schema);
        Node rightChild = rightSubtree.setPredicate(rightPredicate);
        parent.addNodes(leftChild, rightChild);
        return parent;
    }

    if (node instanceof LeafNode) {
        LeafNode leaf = (LeafNode) node;
        Node terminal = new Node();
        switch (miningFunction) {
            case REGRESSION:
                terminal.setScore(ValueUtil.formatValue(leaf.prediction()));
                break;
            case CLASSIFICATION:
                {
                    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
                    int predictedIndex = ValueUtil.asInt(leaf.prediction());
                    terminal.setScore(categoricalLabel.getValue(predictedIndex));
                    ImpurityCalculator impurityCalculator = leaf.impurityStats();
                    terminal.setRecordCount((double) impurityCalculator.count());
                    // One distribution entry per statistics slot, paired with the label value at the same index.
                    int classIndex = 0;
                    for (double classStat : impurityCalculator.stats()) {
                        terminal.addScoreDistributions(new ScoreDistribution(categoricalLabel.getValue(classIndex), classStat));
                        classIndex++;
                    }
                }
                break;
            default:
                throw new UnsupportedOperationException();
        }
        return terminal;
    }

    throw new IllegalArgumentException();
}
Also used: HashSet(java.util.HashSet) Set(java.util.Set) InternalNode(org.apache.spark.ml.tree.InternalNode) Node(org.dmg.pmml.tree.Node) LeafNode(org.apache.spark.ml.tree.LeafNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) BooleanFeature(org.jpmml.converter.BooleanFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) ContinuousSplit(org.apache.spark.ml.tree.ContinuousSplit) FieldName(org.dmg.pmml.FieldName) ScoreDistribution(org.dmg.pmml.ScoreDistribution) ImpurityCalculator(org.apache.spark.mllib.tree.impurity.ImpurityCalculator) CategoricalSplit(org.apache.spark.ml.tree.CategoricalSplit) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Split(org.apache.spark.ml.tree.Split)

Example 2 with ScoreDistribution

Use of org.dmg.pmml.ScoreDistribution in the project jpmml-r by jpmml.

From the class BinaryTreeConverter, method encodeClassificationScore.

/**
 * Fills a node with class probabilities and sets its score to the most probable class.
 *
 * <p>One {@link ScoreDistribution} is added per label value, in label order. The node's score is
 * the label value with the highest probability; on ties the earliest value wins, because a later
 * value only replaces the current best when it is strictly greater.
 *
 * @param node the node to populate; it is modified in place
 * @param probabilities per-class probabilities, indexed like the label values
 * @param schema source of the categorical label
 * @return the same {@code node} instance that was passed in
 * @throws IllegalArgumentException if the number of probabilities differs from the number of label values
 */
private static Node encodeClassificationScore(Node node, RDoubleVector probabilities, Schema schema) {
    CategoricalLabel label = (CategoricalLabel) schema.getLabel();
    if (label.size() != probabilities.size()) {
        throw new IllegalArgumentException();
    }

    Double bestProbability = null;
    for (int classIndex = 0; classIndex < label.size(); classIndex++) {
        String classValue = label.getValue(classIndex);
        Double classProbability = probabilities.getValue(classIndex);

        // The first class always becomes the initial best; afterwards only a strictly larger one replaces it.
        boolean isNewBest = (bestProbability == null) || (bestProbability.compareTo(classProbability) < 0);
        if (isNewBest) {
            node.setScore(classValue);
            bestProbability = classProbability;
        }

        node.addScoreDistributions(new ScoreDistribution(classValue, classProbability));
    }
    return node;
}
Also used : ScoreDistribution(org.dmg.pmml.ScoreDistribution) CategoricalLabel(org.jpmml.converter.CategoricalLabel)

Example 3 with ScoreDistribution

Use of org.dmg.pmml.ScoreDistribution in the project jpmml-r by jpmml.

From the class RangerConverter, method encodeProbabilityForest.

/**
 * Encodes a ranger probability forest as a classification {@link MiningModel}.
 *
 * <p>Reads the {@code "forest"} element of the ranger object and its {@code "levels"}, encodes
 * every tree through {@code encodeForest} with a score encoder that records per-level class
 * values, and combines the trees into a segmentation using the {@code AVERAGE} method with a
 * probability output.
 *
 * <p>The score encoder sets each terminal node's score to the level with the largest class value
 * (the earliest level wins ties) and adds one {@link ScoreDistribution} per level. It throws
 * {@link IllegalArgumentException} when the split value is not zero, or when the class counts are
 * missing or differ in size from the levels.
 *
 * @param ranger the ranger model object
 * @param schema source of the categorical label
 * @return the mining model that wraps the encoded trees
 */
private MiningModel encodeProbabilityForest(RGenericVector ranger, Schema schema) {
    RGenericVector forest = (RGenericVector) ranger.getValue("forest");
    final RStringVector levels = (RStringVector) forest.getValue("levels");
    CategoricalLabel label = (CategoricalLabel) schema.getLabel();

    ScoreEncoder probabilityEncoder = new ScoreEncoder() {

        @Override
        public void encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            // Only nodes whose split value is zero are accepted here.
            if (splitValue.doubleValue() != 0d) {
                throw new IllegalArgumentException();
            }
            // There must be exactly one class value per level.
            if (terminalClassCount == null || terminalClassCount.size() != levels.size()) {
                throw new IllegalArgumentException();
            }

            Double bestProbability = null;
            for (int levelIndex = 0; levelIndex < terminalClassCount.size(); levelIndex++) {
                String level = levels.getValue(levelIndex);
                Double levelProbability = ValueUtil.asDouble(terminalClassCount.getValue(levelIndex));

                // Strict comparison: on ties the earliest level keeps the score.
                boolean isNewBest = (bestProbability == null) || (bestProbability.compareTo(levelProbability) < 0);
                if (isNewBest) {
                    node.setScore(level);
                    bestProbability = levelProbability;
                }

                node.addScoreDistributions(new ScoreDistribution(level, levelProbability));
            }
        }
    };

    List<TreeModel> trees = encodeForest(forest, MiningFunction.CLASSIFICATION, probabilityEncoder, schema);

    return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, trees))
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, label));
}
Also used : Node(org.dmg.pmml.tree.Node) ScoreDistribution(org.dmg.pmml.ScoreDistribution) TreeModel(org.dmg.pmml.tree.TreeModel) MiningModel(org.dmg.pmml.mining.MiningModel) CategoricalLabel(org.jpmml.converter.CategoricalLabel)

Aggregations

ScoreDistribution (org.dmg.pmml.ScoreDistribution): 3, CategoricalLabel (org.jpmml.converter.CategoricalLabel): 3, Node (org.dmg.pmml.tree.Node): 2, HashSet (java.util.HashSet): 1, Set (java.util.Set): 1, CategoricalSplit (org.apache.spark.ml.tree.CategoricalSplit): 1, ContinuousSplit (org.apache.spark.ml.tree.ContinuousSplit): 1, InternalNode (org.apache.spark.ml.tree.InternalNode): 1, LeafNode (org.apache.spark.ml.tree.LeafNode): 1, Split (org.apache.spark.ml.tree.Split): 1, ImpurityCalculator (org.apache.spark.mllib.tree.impurity.ImpurityCalculator): 1, FieldName (org.dmg.pmml.FieldName): 1, Predicate (org.dmg.pmml.Predicate): 1, SimplePredicate (org.dmg.pmml.SimplePredicate): 1, MiningModel (org.dmg.pmml.mining.MiningModel): 1, TreeModel (org.dmg.pmml.tree.TreeModel): 1, BinaryFeature (org.jpmml.converter.BinaryFeature): 1, BooleanFeature (org.jpmml.converter.BooleanFeature): 1, CategoricalFeature (org.jpmml.converter.CategoricalFeature): 1, ContinuousFeature (org.jpmml.converter.ContinuousFeature): 1