Use of org.dmg.pmml.tree.ClassifierNode in project jpmml-sparkml by jpmml.
Class TreeModelUtil, method encodeDecisionTree:
/**
 * Encodes a fitted Spark ML decision tree as a PMML {@code TreeModel}.
 *
 * <p>Regression trees keep each leaf node and set its score to the leaf's prediction.
 * Classification trees replace each leaf with a {@code ClassifierNode} that keeps only the leaf's
 * predicate, scores the class label found at the leaf's prediction index, uses the leaf's
 * impurity-statistics count as its record count, and receives score distributions built from the
 * impurity statistics. Finally, the tree is compacted with {@code TreeModelCompactor} when the
 * converter's compact option (queried with a default of {@code Boolean.TRUE}) is true.
 *
 * @param converter supplies the {@code HasTreeOptions.OPTION_COMPACT} option
 * @param model the fitted tree; must be a {@code DecisionTreeClassificationModel} or a
 *     {@code DecisionTreeRegressionModel}
 * @param predicateManager passed through to {@code encodeTreeModel}
 * @param scoreDistributionManager attaches score distributions to classification leaves
 * @param schema the model schema; for classification its label must be a {@code CategoricalLabel}
 * @return the encoded tree model
 * @throws IllegalArgumentException if {@code model} is of neither supported type
 */
private static <M extends Model<M> & DecisionTreeModel> TreeModel encodeDecisionTree(ModelConverter<?> converter, M model, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema){
    final TreeModel treeModel;

    if(model instanceof DecisionTreeClassificationModel){
        // Resolved once, before encoding starts, and shared by every leaf the encoder visits
        final CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        ScoreEncoder classificationEncoder = new ScoreEncoder(){

            @Override
            public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode){
                // Only the predicate of the incoming node is carried over to the replacement
                Node classifierNode = new ClassifierNode(null, node.getPredicate());

                int labelIndex = ValueUtil.asInt(leafNode.prediction());
                classifierNode.setScore(categoricalLabel.getValue(labelIndex));

                ImpurityCalculator leafStats = leafNode.impurityStats();
                classifierNode.setRecordCount(ValueUtil.narrow(leafStats.count()));

                scoreDistributionManager.addScoreDistributions(classifierNode, categoricalLabel.getValues(), leafStats.stats());

                return classifierNode;
            }
        };

        treeModel = encodeTreeModel(MiningFunction.CLASSIFICATION, classificationEncoder, model, predicateManager, schema);
    } else if(model instanceof DecisionTreeRegressionModel){
        ScoreEncoder regressionEncoder = new ScoreEncoder(){

            @Override
            public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode){
                node.setScore(leafNode.prediction());

                return node;
            }
        };

        treeModel = encodeTreeModel(MiningFunction.REGRESSION, regressionEncoder, model, predicateManager, schema);
    } else {
        throw new IllegalArgumentException();
    }

    // A null option value disables compaction, exactly like an explicit FALSE
    Boolean compact = (Boolean)converter.getOption(HasTreeOptions.OPTION_COMPACT, Boolean.TRUE);
    if(Boolean.TRUE.equals(compact)){
        Visitor compactor = new TreeModelCompactor();
        compactor.applyTo(treeModel);
    }

    return treeModel;
}
Use of org.dmg.pmml.tree.ClassifierNode in project jpmml-sklearn by jpmml.
Class DummyClassifier, method encodeModel:
/**
 * Encodes this fitted {@code DummyClassifier} as a single-node classification {@code TreeModel}.
 *
 * <p>The root node always matches ({@code True} predicate), predicts one fixed class label, and
 * carries score distributions built from a fixed probability vector. Supported strategies:
 * <ul>
 *   <li>{@code "constant"}: the class equal to the configured constant, with probability 1;</li>
 *   <li>{@code "most_frequent"}: the class at {@code indexOfMax(classPrior)}, with probability 1;</li>
 *   <li>{@code "prior"}: the class at {@code indexOfMax(classPrior)}, with the class priors used
 *       as the probabilities.</li>
 * </ul>
 *
 * @param schema the model schema; its label must be a {@code CategoricalLabel}
 * @return the encoded tree model, with a probability output created from the label
 * @throws IllegalArgumentException if the strategy is not one of the above, or if the
 *     {@code "constant"} value is not found among the classes
 */
@Override
public TreeModel encodeModel(Schema schema){
    List<?> classes = getClasses();
    List<? extends Number> classPrior = getClassPrior();
    Object constant = getConstant();
    String strategy = getStrategy();

    ClassDictUtil.checkSize(classes, classPrior);

    CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

    int index;
    double[] probabilities;

    switch(strategy){
        case "constant":
            {
                index = classes.indexOf(constant);
                if(index < 0){
                    // List.indexOf relies on equals(), so a constant whose type differs from the class
                    // labels (e.g. Integer vs. String) is "not found" as well; report value and type
                    String constantType = (constant != null) ? constant.getClass().getName() : "null";

                    throw new IllegalArgumentException("Constant value " + constant + " (" + constantType + ") is not among classes " + classes);
                }

                // One-hot: the constant class is predicted with certainty
                probabilities = new double[classes.size()];
                probabilities[index] = 1d;
            }
            break;
        case "most_frequent":
            {
                index = indexOfMax(classPrior);

                // One-hot: the selected class is predicted with certainty
                probabilities = new double[classes.size()];
                probabilities[index] = 1d;
            }
            break;
        case "prior":
            {
                index = indexOfMax(classPrior);

                // The class priors themselves serve as the probability vector
                probabilities = Doubles.toArray(classPrior);
            }
            break;
        default:
            throw new IllegalArgumentException(strategy);
    }

    Node root = new ClassifierNode(categoricalLabel.getValue(index), True.INSTANCE);

    ScoreDistributionManager scoreDistributionManager = new ScoreDistributionManager();
    scoreDistributionManager.addScoreDistributions(root, categoricalLabel.getValues(), probabilities);

    TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), root)
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

    return treeModel;
}
Use of org.dmg.pmml.tree.ClassifierNode in project jpmml-sklearn by jpmml.
Class TreeUtil, method encodeNode:
/**
 * Recursively encodes the tree node at {@code index}, together with its whole subtree, as a PMML node.
 *
 * <p>The tree is given as parallel per-node arrays. A node is a split when
 * {@code features[index] >= 0} and a leaf otherwise. Split predicates depend on the feature kind:
 * <ul>
 *   <li>{@code BinaryFeature}: left child is {@code != value}, right child is {@code == value}; the
 *       threshold must lie in [0, 1];</li>
 *   <li>{@code ThresholdFeature} (only when {@code numeric} is false): each side receives the
 *       category values whose {@code toSplitValue} is {@code <=} (left) or {@code >} (right) the
 *       threshold, filtered by the category manager, which is then forked per side;</li>
 *   <li>any other feature: a continuous {@code <=} (left) / {@code >} (right) threshold comparison.</li>
 * </ul>
 *
 * <p>Classification splits become score-less {@code ClassifierNode}s. Classification leaves
 * score the class with the largest record count (ties go to the lowest class index), with the
 * summed counts as record count and score distributions from the counts. Regression splits and
 * leaves score {@code values[index]}.
 *
 * @param index index of the node in the parallel arrays
 * @param predicate predicate guarding this node
 * @param miningFunction {@code CLASSIFICATION} or {@code REGRESSION}; anything else is rejected
 * @param numeric when true, {@code ThresholdFeature}s are split like continuous features
 * @param leftChildren left-child index per node
 * @param rightChildren right-child index per node
 * @param features split-feature index per node (negative for leaves)
 * @param thresholds split threshold per node
 * @param values regression: one value per node; classification: per-class record counts, read
 *     per node via {@code getRow} (presumably a flattened node-by-class matrix; TODO confirm)
 * @param categoryManager category bookkeeping for the current path
 * @param predicateManager creates the split predicates
 * @param scoreDistributionManager attaches score distributions to classification leaves
 * @param schema supplies the features and, for classification, the {@code CategoricalLabel}
 * @return the encoded node
 * @throws IllegalArgumentException if a binary-feature threshold lies outside [0, 1], or if the
 *     mining function is neither classification nor regression
 */
private static Node encodeNode(int index, Predicate predicate, MiningFunction miningFunction, boolean numeric, int[] leftChildren, int[] rightChildren, int[] features, double[] thresholds, double[] values, CategoryManager categoryManager, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema){
    Integer id = Integer.valueOf(index);

    int featureIndex = features[index];

    // A non-leaf (binary split) node
    if(featureIndex >= 0){
        Feature feature = schema.getFeature(featureIndex);

        double threshold = thresholds[index];

        // Only ThresholdFeature splits narrow the category bookkeeping; other splits share it
        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        Predicate leftPredicate;
        Predicate rightPredicate;

        if(feature instanceof BinaryFeature){
            BinaryFeature binaryFeature = (BinaryFeature)feature;

            if(threshold < 0 || threshold > 1){
                throw new IllegalArgumentException("Binary feature " + binaryFeature.getName() + " has split threshold " + threshold + ", expected a value in [0, 1]");
            }

            Object value = binaryFeature.getValue();

            leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, value);
        } else if(feature instanceof ThresholdFeature && !numeric){
            ThresholdFeature thresholdFeature = (ThresholdFeature)feature;

            String name = thresholdFeature.getName();
            Object missingValue = thresholdFeature.getMissingValue();

            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            // When the missing value is not NaN, NaN category values are excluded from both sides as well
            if(!ValueUtil.isNaN(missingValue)){
                valueFilter = valueFilter.and(value -> !ValueUtil.isNaN(value));
            }

            // Mirror-image partitions of the category values around the threshold
            List<Object> leftValues = thresholdFeature.getValues((Number value) -> toSplitValue(value) <= threshold).stream()
                .filter(valueFilter)
                .collect(Collectors.toList());
            List<Object> rightValues = thresholdFeature.getValues((Number value) -> toSplitValue(value) > threshold).stream()
                .filter(valueFilter)
                .collect(Collectors.toList());

            leftCategoryManager = leftCategoryManager.fork(name, leftValues);
            rightCategoryManager = rightCategoryManager.fork(name, rightValues);

            leftPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, leftValues, missingValue, predicateManager);
            rightPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, rightValues, missingValue, predicateManager);
        } else {
            ContinuousFeature continuousFeature = toContinuousFeature(feature);

            Double value = threshold;

            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        }

        int leftIndex = leftChildren[index];
        int rightIndex = rightChildren[index];

        Node leftChild = encodeNode(leftIndex, leftPredicate, miningFunction, numeric, leftChildren, rightChildren, features, thresholds, values, leftCategoryManager, predicateManager, scoreDistributionManager, schema);
        Node rightChild = encodeNode(rightIndex, rightPredicate, miningFunction, numeric, leftChildren, rightChildren, features, thresholds, values, rightCategoryManager, predicateManager, scoreDistributionManager, schema);

        Node result;

        if(miningFunction == MiningFunction.CLASSIFICATION){
            result = new ClassifierNode(null, predicate);
        } else if(miningFunction == MiningFunction.REGRESSION){
            double value = values[index];

            result = new BranchNode(value, predicate);
        } else {
            throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }

        result.setId(id)
            .addNodes(leftChild, rightChild);

        return result;
    } else

    // A leaf node
    {
        Node result;

        if(miningFunction == MiningFunction.CLASSIFICATION){
            CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

            double[] recordCounts = getRow(values, leftChildren.length, categoricalLabel.size(), index);

            double totalRecordCount = 0d;

            Object score = null;
            double scoreRecordCount = -Double.MAX_VALUE;

            for(int i = 0; i < recordCounts.length; i++){
                double recordCount = recordCounts[i];

                totalRecordCount += recordCount;

                // Strict '>' keeps the first (lowest-index) class among ties
                if(recordCount > scoreRecordCount){
                    score = categoricalLabel.getValue(i);

                    scoreRecordCount = recordCount;
                }
            }

            result = new ClassifierNode(score, predicate)
                .setId(id)
                .setRecordCount(ValueUtil.narrow(totalRecordCount));

            scoreDistributionManager.addScoreDistributions(result, categoricalLabel.getValues(), recordCounts);
        } else if(miningFunction == MiningFunction.REGRESSION){
            double value = values[index];

            result = new LeafNode(value, predicate)
                .setId(id);
        } else {
            throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }

        return result;
    }
}
Aggregations