Search in sources :

Example 6 with ClassifierNode

use of org.dmg.pmml.tree.ClassifierNode in project jpmml-sparkml by jpmml.

The method encodeDecisionTree of the class TreeModelUtil.

/**
 * Converts a Spark ML decision tree model into a PMML {@link TreeModel}.
 *
 * <p>For a {@link DecisionTreeRegressionModel} every leaf keeps its node and receives the leaf
 * prediction as score. For a {@link DecisionTreeClassificationModel} every leaf is replaced by a
 * {@link ClassifierNode} that reuses the original predicate and carries the predicted label value,
 * the record count and the score distributions taken from the leaf's impurity statistics.</p>
 *
 * <p>The resulting tree is passed through a {@link TreeModelCompactor} when the converter option
 * {@code HasTreeOptions.OPTION_COMPACT} (requested with {@code Boolean.TRUE} as default) is true.</p>
 *
 * @param converter the converter whose options control compaction.
 * @param model the Spark ML decision tree model to encode.
 * @param predicateManager forwarded to {@code encodeTreeModel}.
 * @param scoreDistributionManager used to attach score distributions to classification leaves.
 * @param schema the schema; its label must be a {@link CategoricalLabel} for classification models.
 * @return the encoded tree model.
 * @throws IllegalArgumentException if the model is neither a regression nor a classification tree.
 */
private static <M extends Model<M> & DecisionTreeModel> TreeModel encodeDecisionTree(ModelConverter<?> converter, M model, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
    TreeModel result;

    if (model instanceof DecisionTreeRegressionModel) {
        ScoreEncoder regressionEncoder = new ScoreEncoder() {

            @Override
            public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode) {
                // Regression leaves only need their prediction; the node is updated in place
                node.setScore(leafNode.prediction());
                return node;
            }
        };

        result = encodeTreeModel(MiningFunction.REGRESSION, regressionEncoder, model, predicateManager, schema);
    } else if (model instanceof DecisionTreeClassificationModel) {
        ScoreEncoder classificationEncoder = new ScoreEncoder() {

            private CategoricalLabel label = (CategoricalLabel) schema.getLabel();

            @Override
            public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode) {
                // The incoming node only donates its predicate; a fresh classifier node is returned instead
                Node classifierNode = new ClassifierNode(null, node.getPredicate());

                // The prediction is the position of the winning category within the label
                int categoryIndex = ValueUtil.asInt(leafNode.prediction());
                classifierNode.setScore(this.label.getValue(categoryIndex));

                ImpurityCalculator impurityStats = leafNode.impurityStats();
                classifierNode.setRecordCount(ValueUtil.narrow(impurityStats.count()));
                scoreDistributionManager.addScoreDistributions(classifierNode, this.label.getValues(), impurityStats.stats());

                return classifierNode;
            }
        };

        result = encodeTreeModel(MiningFunction.CLASSIFICATION, classificationEncoder, model, predicateManager, schema);
    } else {
        throw new IllegalArgumentException();
    }

    Boolean compact = (Boolean) converter.getOption(HasTreeOptions.OPTION_COMPACT, Boolean.TRUE);
    // A null option value is treated the same as FALSE
    if (Boolean.TRUE.equals(compact)) {
        Visitor compactor = new TreeModelCompactor();
        compactor.applyTo(result);
    }

    return result;
}
Also used : DecisionTreeRegressionModel(org.apache.spark.ml.regression.DecisionTreeRegressionModel) Visitor(org.dmg.pmml.Visitor) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) DecisionTreeClassificationModel(org.apache.spark.ml.classification.DecisionTreeClassificationModel) TreeModelCompactor(org.jpmml.sparkml.visitors.TreeModelCompactor) TreeModel(org.dmg.pmml.tree.TreeModel) DecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel) ImpurityCalculator(org.apache.spark.mllib.tree.impurity.ImpurityCalculator) CategoricalLabel(org.jpmml.converter.CategoricalLabel) LeafNode(org.dmg.pmml.tree.LeafNode) ClassifierNode(org.dmg.pmml.tree.ClassifierNode)

Example 7 with ClassifierNode

use of org.dmg.pmml.tree.ClassifierNode in project jpmml-sklearn by jpmml.

The method encodeModel of the class DummyClassifier.

/**
 * Encodes this dummy classifier as a classification {@link TreeModel} that consists of a single
 * root {@link ClassifierNode} with an always-true predicate.
 *
 * <p>The root score and its score distributions depend on the strategy:</p>
 * <ul>
 *   <li>{@code "constant"} - the class equal to the configured constant, with probability one;</li>
 *   <li>{@code "most_frequent"} - the class at {@code indexOfMax(classPrior)}, with probability one;</li>
 *   <li>{@code "prior"} - the class at {@code indexOfMax(classPrior)}, with the class priors as probabilities.</li>
 * </ul>
 *
 * @param schema the schema; its label is cast to {@link CategoricalLabel}.
 * @return the tree model, with a probability output attached.
 * @throws IllegalArgumentException if the constant is not among the classes, or the strategy is not
 *         one of the three listed above.
 */
@Override
public TreeModel encodeModel(Schema schema) {
    List<?> classes = getClasses();
    List<? extends Number> classPrior = getClassPrior();
    Object constant = getConstant();
    String strategy = getStrategy();

    ClassDictUtil.checkSize(classes, classPrior);

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    // Position of the predicted class, and the per-class probabilities reported for it
    int predictedIndex;
    double[] distribution;

    switch (strategy) {
        case "constant":
            predictedIndex = classes.indexOf(constant);
            if (predictedIndex < 0) {
                throw new IllegalArgumentException();
            }
            distribution = new double[classes.size()];
            distribution[predictedIndex] = 1d;
            break;
        case "most_frequent":
            predictedIndex = indexOfMax(classPrior);
            distribution = new double[classes.size()];
            distribution[predictedIndex] = 1d;
            break;
        case "prior":
            predictedIndex = indexOfMax(classPrior);
            distribution = Doubles.toArray(classPrior);
            break;
        default:
            throw new IllegalArgumentException(strategy);
    }

    Node rootNode = new ClassifierNode(categoricalLabel.getValue(predictedIndex), True.INSTANCE);

    ScoreDistributionManager distributionManager = new ScoreDistributionManager();
    distributionManager.addScoreDistributions(rootNode, categoricalLabel.getValues(), distribution);

    return new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), rootNode)
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
}
Also used : ScoreDistributionManager(org.jpmml.converter.ScoreDistributionManager) TreeModel(org.dmg.pmml.tree.TreeModel) CategoricalLabel(org.jpmml.converter.CategoricalLabel) Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) ClassifierNode(org.dmg.pmml.tree.ClassifierNode)

Example 8 with ClassifierNode

use of org.dmg.pmml.tree.ClassifierNode in project jpmml-sklearn by jpmml.

The method encodeNode of the class TreeUtil.

/**
 * Recursively encodes the tree node stored at position {@code index} of the parallel tree arrays,
 * together with its whole subtree, as a PMML {@link Node}.
 *
 * <p>A node whose entry in {@code features} is negative is a leaf; any other node is a binary split
 * whose children are found via {@code leftChildren[index]} and {@code rightChildren[index]}. The
 * predicates of the two children are derived from the split feature:</p>
 * <ul>
 *   <li>{@link BinaryFeature} - "not equal" / "equal" to the feature value (threshold must be in [0, 1]);</li>
 *   <li>{@link ThresholdFeature}, when {@code numeric} is false - value sets partitioned by the threshold
 *       and restricted by the category manager's value filter;</li>
 *   <li>anything else - "less or equal" / "greater than" the threshold on a continuous feature.</li>
 * </ul>
 *
 * @param index the position of the node within the tree arrays; also used as node identifier.
 * @param predicate the predicate to attach to the node being encoded.
 * @param miningFunction either {@code CLASSIFICATION} or {@code REGRESSION}.
 * @param numeric when true, threshold features are handled like ordinary continuous features.
 * @return the encoded node.
 * @throws IllegalArgumentException if the mining function is unsupported, or a binary feature is
 *         split at a threshold outside [0, 1].
 */
private static Node encodeNode(int index, Predicate predicate, MiningFunction miningFunction, boolean numeric, int[] leftChildren, int[] rightChildren, int[] features, double[] thresholds, double[] values, CategoryManager categoryManager, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
    Integer nodeId = Integer.valueOf(index);
    int splitFeatureIndex = features[index];

    // A leaf node
    if (splitFeatureIndex < 0) {
        if (miningFunction == MiningFunction.CLASSIFICATION) {
            CategoricalLabel label = (CategoricalLabel) schema.getLabel();
            // One entry per label category - presumably the per-class record counts of this node
            double[] classCounts = getRow(values, leftChildren.length, label.size(), index);

            double countSum = 0d;
            Object winningValue = null;
            double winningCount = -Double.MAX_VALUE;
            for (int classIndex = 0; classIndex < classCounts.length; classIndex++) {
                double classCount = classCounts[classIndex];
                countSum += classCount;
                // Strict comparison: among equal counts the first category wins
                if (classCount > winningCount) {
                    winningValue = label.getValue(classIndex);
                    winningCount = classCount;
                }
            }

            Node leaf = new ClassifierNode(winningValue, predicate).setId(nodeId).setRecordCount(ValueUtil.narrow(countSum));
            scoreDistributionManager.addScoreDistributions(leaf, label.getValues(), classCounts);
            return leaf;
        }
        if (miningFunction == MiningFunction.REGRESSION) {
            return new LeafNode(values[index], predicate).setId(nodeId);
        }
        throw new IllegalArgumentException();
    }

    // A non-leaf (binary split) node
    Feature feature = schema.getFeature(splitFeatureIndex);
    double threshold = thresholds[index];

    // Both children inherit the current category manager unless the split narrows the category sets
    CategoryManager leftCategoryManager = categoryManager;
    CategoryManager rightCategoryManager = categoryManager;

    Predicate leftPredicate;
    Predicate rightPredicate;

    if (feature instanceof BinaryFeature) {
        BinaryFeature binaryFeature = (BinaryFeature) feature;
        if (threshold < 0 || threshold > 1) {
            throw new IllegalArgumentException();
        }
        Object binaryValue = binaryFeature.getValue();
        leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, binaryValue);
        rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, binaryValue);
    } else if (feature instanceof ThresholdFeature && !numeric) {
        ThresholdFeature thresholdFeature = (ThresholdFeature) feature;
        String name = thresholdFeature.getName();
        Object missingValue = thresholdFeature.getMissingValue();

        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
        // When the missing value is not NaN, NaN is additionally filtered out of both value sets
        if (!ValueUtil.isNaN(missingValue)) {
            valueFilter = valueFilter.and(candidate -> !ValueUtil.isNaN(candidate));
        }

        List<Object> leftValues = thresholdFeature.getValues((Number number) -> toSplitValue(number) <= threshold)
            .stream()
            .filter(valueFilter)
            .collect(Collectors.toList());
        List<Object> rightValues = thresholdFeature.getValues((Number number) -> toSplitValue(number) > threshold)
            .stream()
            .filter(valueFilter)
            .collect(Collectors.toList());

        leftCategoryManager = categoryManager.fork(name, leftValues);
        rightCategoryManager = categoryManager.fork(name, rightValues);

        leftPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, leftValues, missingValue, predicateManager);
        rightPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, rightValues, missingValue, predicateManager);
    } else {
        ContinuousFeature continuousFeature = toContinuousFeature(feature);
        Double splitValue = Double.valueOf(threshold);
        leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, splitValue);
        rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, splitValue);
    }

    // The left subtree is encoded before the right one
    Node leftChild = encodeNode(leftChildren[index], leftPredicate, miningFunction, numeric, leftChildren, rightChildren, features, thresholds, values, leftCategoryManager, predicateManager, scoreDistributionManager, schema);
    Node rightChild = encodeNode(rightChildren[index], rightPredicate, miningFunction, numeric, leftChildren, rightChildren, features, thresholds, values, rightCategoryManager, predicateManager, scoreDistributionManager, schema);

    Node branch;
    if (miningFunction == MiningFunction.CLASSIFICATION) {
        // Classification branches carry no score of their own
        branch = new ClassifierNode(null, predicate);
    } else if (miningFunction == MiningFunction.REGRESSION) {
        branch = new BranchNode(values[index], predicate);
    } else {
        throw new IllegalArgumentException();
    }
    branch.setId(nodeId).addNodes(leftChild, rightChild);
    return branch;
}
Also used : Node(org.dmg.pmml.tree.Node) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) BranchNode(org.dmg.pmml.tree.BranchNode) LeafNode(org.dmg.pmml.tree.LeafNode) BinaryFeature(org.jpmml.converter.BinaryFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) ThresholdFeature(org.jpmml.converter.ThresholdFeature) Feature(org.jpmml.converter.Feature) BinaryFeature(org.jpmml.converter.BinaryFeature) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) BranchNode(org.dmg.pmml.tree.BranchNode) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalLabel(org.jpmml.converter.CategoricalLabel) LeafNode(org.dmg.pmml.tree.LeafNode) ThresholdFeature(org.jpmml.converter.ThresholdFeature) List(java.util.List) ArrayList(java.util.ArrayList) ClassifierNode(org.dmg.pmml.tree.ClassifierNode) CategoryManager(org.jpmml.converter.CategoryManager)

Aggregations

ClassifierNode (org.dmg.pmml.tree.ClassifierNode)8 Node (org.dmg.pmml.tree.Node)7 CategoricalLabel (org.jpmml.converter.CategoricalLabel)7 LeafNode (org.dmg.pmml.tree.LeafNode)5 ScoreDistribution (org.dmg.pmml.ScoreDistribution)4 BranchNode (org.dmg.pmml.tree.BranchNode)4 TreeModel (org.dmg.pmml.tree.TreeModel)4 ArrayList (java.util.ArrayList)3 List (java.util.List)3 Predicate (org.dmg.pmml.Predicate)2 SimplePredicate (org.dmg.pmml.SimplePredicate)2 ContinuousFeature (org.jpmml.converter.ContinuousFeature)2 Feature (org.jpmml.converter.Feature)2 HashMap (java.util.HashMap)1 DecisionTreeClassificationModel (org.apache.spark.ml.classification.DecisionTreeClassificationModel)1 DecisionTreeRegressionModel (org.apache.spark.ml.regression.DecisionTreeRegressionModel)1 DecisionTreeModel (org.apache.spark.ml.tree.DecisionTreeModel)1 ImpurityCalculator (org.apache.spark.mllib.tree.impurity.ImpurityCalculator)1 Visitor (org.dmg.pmml.Visitor)1 MiningModel (org.dmg.pmml.mining.MiningModel)1