Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class RPartConverter, method encodeClassification.
/**
 * Builds a classification {@link TreeModel} from the components of a fitted tree.
 *
 * <p>The {@code yval2} element of {@code frame} is read column by column
 * (presumably a column-major matrix, given the use of {@code FortranMatrixUtil} -
 * confirm): column 0 supplies the one-based index of the predicted category of
 * every node, and columns {@code 1..categories.size()} supply the per-category
 * record counts, which are only read when score distributions are enabled.</p>
 *
 * @param frame vector holding the {@code yval2} element
 * @param rowNames node row names; its size is used as the matrix row count
 * @param var passed through to {@code encodeNode}
 * @param n per-node record counts
 * @param splitInfo passed through to {@code encodeNode}
 * @param splits passed through to {@code encodeNode}
 * @param csplit passed through to {@code encodeNode}
 * @param schema schema whose label must be a {@link CategoricalLabel}
 * @return the configured classification tree model
 */
private TreeModel encodeClassification(RGenericVector frame, RIntegerVector rowNames, RVector<?> var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, Schema schema) {
    RDoubleVector yval2 = frame.getDoubleElement("yval2");

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    List<?> categories = categoricalLabel.getValues();

    boolean hasScoreDistribution = hasScoreDistribution();

    ScoreEncoder scoreEncoder = new ScoreEncoder() {

        // One-based category index of the predicted class, one entry per node row.
        private List<Integer> classIndexes = null;

        // One list per category; each list holds that category's record count per node row.
        private List<List<? extends Number>> categoryCounts = null;

        {
            int rowCount = rowNames.size();
            int columnCount = (2 * categories.size()) + 2;

            List<Integer> classColumn = ValueUtil.asIntegers(FortranMatrixUtil.getColumn(yval2.getValues(), rowCount, columnCount, 0));
            this.classIndexes = new ArrayList<>(classColumn);

            if (hasScoreDistribution) {
                this.categoryCounts = new ArrayList<>();

                for (int category = 0; category < categories.size(); category++) {
                    List<? extends Number> countColumn = FortranMatrixUtil.getColumn(yval2.getValues(), rowCount, columnCount, category + 1);
                    this.categoryCounts.add(new ArrayList<>(countColumn));
                }
            }
        }

        @Override
        public Node encode(Node node, int offset) {
            int classIndex = this.classIndexes.get(offset) - 1;

            Object score = categories.get(classIndex);
            Integer recordCount = n.getValue(offset);
            node.setScore(score).setRecordCount(recordCount);

            if (!hasScoreDistribution) {
                return node;
            }

            // Score distributions require a node type that can hold them.
            node = new ClassifierNode(node);

            List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();
            for (int category = 0; category < categories.size(); category++) {
                List<? extends Number> countColumn = this.categoryCounts.get(category);

                ScoreDistribution scoreDistribution = new ScoreDistribution()
                    .setValue(categories.get(category))
                    .setRecordCount(countColumn.get(offset));
                scoreDistributions.add(scoreDistribution);
            }

            return node;
        }
    };

    Node root = encodeNode(True.INSTANCE, 1, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);

    TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()), root);

    // Probability outputs are only added when nodes carry score distributions.
    if (hasScoreDistribution) {
        treeModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    }

    return configureTreeModel(treeModel);
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class RPartConverter, method encodeRegression.
/**
 * Builds a regression {@link TreeModel} from the components of a fitted tree.
 *
 * <p>Each node is scored with the value found at its offset in the {@code yval}
 * element of {@code frame}, and its record count is taken from {@code n} at the
 * same offset.</p>
 *
 * @param frame vector holding the numeric {@code yval} element
 * @param rowNames passed through to {@code encodeNode}
 * @param var passed through to {@code encodeNode}
 * @param n per-node record counts
 * @param splitInfo passed through to {@code encodeNode}
 * @param splits passed through to {@code encodeNode}
 * @param csplit passed through to {@code encodeNode}
 * @param schema schema supplying the label for the mining schema
 * @return the configured regression tree model
 */
private TreeModel encodeRegression(RGenericVector frame, RIntegerVector rowNames, RVector<?> var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, Schema schema) {
    RNumberVector<?> yval = frame.getNumericElement("yval");

    ScoreEncoder scoreEncoder = new ScoreEncoder() {

        @Override
        public Node encode(Node node, int offset) {
            Number prediction = yval.getValue(offset);
            Number count = n.getValue(offset);

            node.setScore(prediction).setRecordCount(count);

            return node;
        }
    };

    // The tree is encoded first; the mining schema is derived afterwards.
    Node root = encodeNode(True.INSTANCE, 1, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);

    return configureTreeModel(new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), root));
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class RandomForestConverter, method encodeTreeModel.
/**
 * Encodes a single tree of the forest as a PMML {@link TreeModel}.
 *
 * <p>The tree is built recursively starting at index 0 with an always-true
 * root predicate. The resulting model uses the {@code NULL_PREDICTION} missing
 * value strategy and is marked as a binary-split tree. When the converter's
 * {@code compact} flag is set, a {@code RandomForestCompactor} visitor is
 * applied to the model before it is returned.</p>
 *
 * @param miningFunction mining function of the resulting model
 * @param scoreEncoder converts a node prediction into a node score
 * @param leftDaughter per-node left child references
 * @param rightDaughter per-node right child references
 * @param nodepred per-node predictions
 * @param bestvar per-node split variable references
 * @param xbestsplit per-node split values
 * @param schema schema supplying the features and the label
 * @return the encoded tree model
 */
private <P extends Number> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<P> nodepred, List<? extends Number> bestvar, List<Double> xbestsplit, Schema schema) {
    Node root = encodeNode(True.INSTANCE, 0, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, new CategoryManager(), schema);

    TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
        .setMissingValueStrategy(TreeModel.MissingValueStrategy.NULL_PREDICTION)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

    if (this.compact) {
        Visitor visitor = new RandomForestCompactor();
        visitor.applyTo(treeModel);
    }

    return treeModel;
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class RandomForestConverter, method encodeNode.
/**
 * Recursively encodes the tree node stored at index {@code i}.
 *
 * <p>A node whose {@code bestvar} entry is 0 becomes a {@code LeafNode} scored
 * with the encoded {@code nodepred} entry. Any other node becomes a
 * {@code BranchNode}; its {@code bestvar} entry minus one selects the split
 * feature from the schema, and the kind of that feature decides the child
 * predicates:</p>
 * <ul>
 *   <li>boolean feature: equality with the feature's value 0 (left) and
 *       value 1 (right); the split value must be {@code 0.5}</li>
 *   <li>categorical feature: the values selected by {@code selectValues} for
 *       each side, with the category manager forked per side</li>
 *   <li>anything else: {@code <= split} (left) and {@code > split} (right) on
 *       the continuous form of the feature</li>
 * </ul>
 *
 * <p>Daughter entries equal to 0 add no child; other entries are decremented
 * by one before being used as the child's index.</p>
 *
 * @param predicate predicate attached to the node being created
 * @param i zero-based index of the node in the per-node lists
 * @param scoreEncoder converts a node prediction into a node score
 * @param leftDaughter per-node left child references
 * @param rightDaughter per-node right child references
 * @param bestvar per-node split variable references
 * @param xbestsplit per-node split values
 * @param nodepred per-node predictions
 * @param categoryManager category state in effect for this node
 * @param schema schema supplying the features
 * @return the encoded node, with its identifier set to {@code i + 1}
 * @throws IllegalArgumentException if a boolean feature is split on a value other than {@code 0.5}
 */
private <P extends Number> Node encodeNode(Predicate predicate, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, CategoryManager categoryManager, Schema schema) {
    Integer id = Integer.valueOf(i + 1);

    int splitVar = ValueUtil.asInt(bestvar.get(i));

    // No split variable: this is a terminal node.
    if (splitVar == 0) {
        P prediction = nodepred.get(i);

        return new LeafNode(scoreEncoder.encode(prediction), predicate).setId(id);
    }

    // Both sides inherit the current category state unless a categorical split forks it.
    CategoryManager leftCategoryManager = categoryManager;
    CategoryManager rightCategoryManager = categoryManager;

    Predicate leftPredicate;
    Predicate rightPredicate;

    Feature feature = schema.getFeature(splitVar - 1);
    Double split = xbestsplit.get(i);

    // The boolean check must precede the categorical check; do not reorder.
    if (feature instanceof BooleanFeature) {
        BooleanFeature booleanFeature = (BooleanFeature) feature;

        if (split != 0.5d) {
            throw new IllegalArgumentException();
        }

        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
    } else if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

        String name = categoricalFeature.getName();
        List<?> values = categoricalFeature.getValues();

        java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

        List<Object> leftValues = selectValues(values, valueFilter, split, true);
        List<Object> rightValues = selectValues(values, valueFilter, split, false);

        leftCategoryManager = categoryManager.fork(name, leftValues);
        rightCategoryManager = categoryManager.fork(name, rightValues);

        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, split);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, split);
    }

    Node branch = new BranchNode(null, predicate).setId(id);
    List<Node> children = branch.getNodes();

    int leftIndex = ValueUtil.asInt(leftDaughter.get(i));
    if (leftIndex != 0) {
        children.add(encodeNode(leftPredicate, leftIndex - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, leftCategoryManager, schema));
    }

    int rightIndex = ValueUtil.asInt(rightDaughter.get(i));
    if (rightIndex != 0) {
        children.add(encodeNode(rightPredicate, rightIndex - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, rightCategoryManager, schema));
    }

    return branch;
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class RangerConverter, method encodeProbabilityForest.
/**
 * Encodes a probability forest as a PMML {@link MiningModel} that averages
 * the member trees and exposes per-category probability outputs.
 *
 * <p>Every terminal node is converted to a {@code ClassifierNode} carrying one
 * {@code ScoreDistribution} per entry of the forest's {@code levels} element.
 * The node score is the level with the largest value; on ties the level that
 * comes first wins.</p>
 *
 * @param forest vector holding the {@code levels} element and the trees
 * @param schema schema whose label must be a {@link CategoricalLabel}
 * @return the encoded mining model
 * @throws IllegalArgumentException (raised by the score encoder) if a terminal
 *         node has a non-zero split value, or if its class counts are missing
 *         or do not match the number of levels
 */
private MiningModel encodeProbabilityForest(RGenericVector forest, Schema schema) {
    RStringVector levels = forest.getStringElement("levels");

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    ScoreEncoder scoreEncoder = new ScoreEncoder() {

        @Override
        public Node encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
            if (splitValue.doubleValue() != 0d || (terminalClassCount == null || terminalClassCount.size() != levels.size())) {
                throw new IllegalArgumentException();
            }

            node = new ClassifierNode(node);

            List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();

            Number maxProbability = null;

            for (int i = 0; i < terminalClassCount.size(); i++) {
                String value = levels.getValue(i);
                Number probability = terminalClassCount.getValue(i);

                // Compare numerically instead of through a raw Comparable cast, which is
                // unchecked and fails for Number types that are not mutually comparable.
                if (maxProbability == null || Double.compare(maxProbability.doubleValue(), probability.doubleValue()) < 0) {
                    node.setScore(value);

                    maxProbability = probability;
                }

                ScoreDistribution scoreDistribution = new ScoreDistribution(value, probability);
                scoreDistributions.add(scoreDistribution);
            }

            return node;
        }
    };

    List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

    MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels))
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

    return miningModel;
}
Aggregations