Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class GBMConverter, method encodeNode.
/**
 * Recursively encodes row {@code i} of a gbm tree into the given PMML {@link Node}.
 *
 * <p>The tree is read positionally:
 * <ul>
 * <li>element 0: split variable indexes;</li>
 * <li>element 1: split code / predictions;</li>
 * <li>elements 2-4: left, right and missing child row indexes;</li>
 * <li>element 7: node predictions.</li>
 * </ul>
 * A split variable of {@code -1} marks a terminal row, whose prediction becomes the node score.
 *
 * <p>Children are appended in the order missing, left, right. Each child gets the PMML id
 * {@code rowIndex + 1}, and a child index of {@code -1} means "no such child".
 *
 * @param node the PMML node to populate (score and/or children)
 * @param i the zero-based row index of this node within the tree vectors
 * @param tree the gbm tree as a generic R list
 * @param c_splits the categorical split table; a categorical row's split value is used as an index into it
 * @param schema the schema that resolves split variable indexes to features
 */
private void encodeNode(Node node, int i, RGenericVector tree, RGenericVector c_splits, Schema schema) {
	RIntegerVector splitVar = (RIntegerVector) tree.getValue(0);
	RDoubleVector splitCodePred = (RDoubleVector) tree.getValue(1);
	RIntegerVector leftNode = (RIntegerVector) tree.getValue(2);
	RIntegerVector rightNode = (RIntegerVector) tree.getValue(3);
	RIntegerVector missingNode = (RIntegerVector) tree.getValue(4);
	RDoubleVector prediction = (RDoubleVector) tree.getValue(7);

	int featureIndex = splitVar.getValue(i);

	// Terminal row: emit the prediction as the score and stop recursing
	if (featureIndex == -1) {
		Double leafValue = prediction.getValue(i);
		node.setScore(ValueUtil.formatValue(leafValue));
		return;
	}

	Feature feature = schema.getFeature(featureIndex);

	Predicate missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);

	Double split = splitCodePred.getValue(i);

	Predicate leftPredicate;
	Predicate rightPredicate;

	if (feature instanceof CategoricalFeature) {
		CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

		List<String> categories = categoricalFeature.getValues();

		// For categorical rows the split value is a row index into c_splits
		RIntegerVector c_split = (RIntegerVector) c_splits.getValue(ValueUtil.asInt(split));
		List<Integer> directions = c_split.getValues();

		leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(categories, directions, true));
		rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(categories, directions, false));
	} else {
		ContinuousFeature continuousFeature = feature.toContinuousFeature();

		String threshold = ValueUtil.formatValue(split);

		leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, threshold);
		rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, threshold);
	}

	// Parallel arrays: child row index and the predicate guarding that child
	Integer[] childRows = {missingNode.getValue(i), leftNode.getValue(i), rightNode.getValue(i)};
	Predicate[] childPredicates = {missingPredicate, leftPredicate, rightPredicate};

	for (int k = 0; k < childRows.length; k++) {
		int childRow = childRows[k];

		if (childRow == -1) {
			continue;
		}

		Node child = new Node().setId(String.valueOf(childRow + 1)).setPredicate(childPredicates[k]);

		encodeNode(child, childRow, tree, c_splits, schema);

		node.addNodes(child);
	}
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class BinaryTreeConverter, method encodeTreeModel.
/**
 * Builds a binary-split PMML {@link TreeModel} from an R tree.
 *
 * <p>The root node is always selected (its predicate is {@code True}). Its score and
 * children are filled in by {@code encodeNode}.
 *
 * @param tree the root of the R tree as a generic R list
 * @param schema the schema providing the label and features
 * @return a tree model using this converter's mining function
 */
private TreeModel encodeTreeModel(RGenericVector tree, Schema schema) {
	Node rootNode = new Node()
		.setPredicate(new True());

	encodeNode(rootNode, tree, schema);

	return new TreeModel(this.miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class BinaryTreeConverter, method encodeNode.
/**
 * Recursively encodes one R tree node into the given PMML {@link Node}.
 *
 * <p>The tree node is an R list read by element name:
 * <ul>
 * <li>{@code nodeID}: becomes the PMML node id;</li>
 * <li>{@code terminal}: when TRUE, the node is scored via {@code encodeScore} using {@code prediction};</li>
 * <li>{@code psplit}: otherwise, its {@code variableName} and {@code splitpoint} define the split;</li>
 * <li>{@code ssplits}: must be empty;</li>
 * <li>{@code left} and {@code right}: encoded recursively as the two children.</li>
 * </ul>
 *
 * <p>Categorical splits produce set predicates via {@code selectValues}. Continuous splits send
 * {@code x <= splitpoint} left and {@code x > splitpoint} right.
 *
 * @param node the PMML node to populate
 * @param tree the R tree node
 * @param schema the schema resolving feature indexes to features
 * @throws IllegalArgumentException if {@code ssplits} is non-empty, or if the split variable has no
 *         entry in {@code this.featureIndexes}
 */
private void encodeNode(Node node, RGenericVector tree, Schema schema) {
	RIntegerVector nodeId = (RIntegerVector) tree.getValue("nodeID");
	RBooleanVector terminal = (RBooleanVector) tree.getValue("terminal");
	RGenericVector psplit = (RGenericVector) tree.getValue("psplit");
	RGenericVector ssplits = (RGenericVector) tree.getValue("ssplits");
	RDoubleVector prediction = (RDoubleVector) tree.getValue("prediction");
	RGenericVector left = (RGenericVector) tree.getValue("left");
	RGenericVector right = (RGenericVector) tree.getValue("right");

	node.setId(String.valueOf(nodeId.asScalar()));

	if ((Boolean.TRUE).equals(terminal.asScalar())) {
		// NOTE(review): the original code assigned the return value back to the local parameter `node`,
		// which had no effect because the method returns immediately. If encodeScore ever returns a
		// different Node instance, that instance is NOT attached to the parent; confirm it mutates in place.
		encodeScore(node, prediction, schema);

		return;
	}

	RNumberVector<?> splitpoint = (RNumberVector<?>) psplit.getValue("splitpoint");
	RStringVector variableName = (RStringVector) psplit.getValue("variableName");

	if (ssplits.size() > 0) {
		throw new IllegalArgumentException("Surrogate splits are not supported (node " + nodeId.asScalar() + ")");
	}

	String variable = variableName.asScalar();

	FieldName name = FieldName.create(variable);

	Integer index = this.featureIndexes.get(name);
	if (index == null) {
		throw new IllegalArgumentException("Split variable " + variable + " is not a known feature (node " + nodeId.asScalar() + ")");
	}

	Feature feature = schema.getFeature(index);

	Predicate leftPredicate;
	Predicate rightPredicate;

	if (feature instanceof CategoricalFeature) {
		CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

		List<String> values = categoricalFeature.getValues();

		// NOTE(review): assumes categorical split points are stored as integers; the cast is unchecked.
		@SuppressWarnings("unchecked")
		List<Integer> splitValues = (List<Integer>) splitpoint.getValues();

		leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, true));
		rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, false));
	} else {
		ContinuousFeature continuousFeature = feature.toContinuousFeature();

		String value = ValueUtil.formatValue((Double) splitpoint.asScalar());

		leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
		rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
	}

	Node leftChild = new Node().setPredicate(leftPredicate);
	encodeNode(leftChild, left, schema);

	Node rightChild = new Node().setPredicate(rightPredicate);
	encodeNode(rightChild, right, schema);

	node.addNodes(leftChild, rightChild);
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class RangerConverter, method encodeRegression.
/**
 * Encodes a ranger regression forest as a PMML {@link MiningModel} that averages its trees.
 *
 * <p>Each terminal node is scored with the formatted {@code splitValue} handed to the score
 * encoder. The class-count argument is ignored for regression.
 *
 * @param ranger the ranger model object; its {@code forest} element is encoded
 * @param schema the schema providing the label and features
 * @return a regression mining model with an AVERAGE segmentation over the encoded trees
 */
private MiningModel encodeRegression(RGenericVector ranger, Schema schema) {
	RGenericVector forest = (RGenericVector) ranger.getValue("forest");

	ScoreEncoder leafScoreEncoder = new ScoreEncoder() {

		@Override
		public void encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
			// Regression leaves store their predicted value in splitValue
			node.setScore(ValueUtil.formatValue(splitValue));
		}
	};

	List<TreeModel> treeModels = encodeForest(forest, MiningFunction.REGRESSION, leafScoreEncoder, schema);

	return new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()))
		.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels));
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class RangerConverter, method encodeProbabilityForest.
/**
 * Encodes a ranger probability forest as a PMML classification {@link MiningModel} that averages its trees.
 *
 * <p>Each terminal node receives one {@link ScoreDistribution} per entry of the forest's
 * {@code levels}, with the corresponding {@code terminalClassCount} value as the probability. Its
 * score is the first level with the highest value; ties keep the earliest level.
 *
 * @param ranger the ranger model object; its {@code forest} element is encoded
 * @param schema the schema; its label must be a {@link CategoricalLabel}
 * @return a classification mining model with probability output fields
 * @throws IllegalArgumentException (from the score encoder) if a leaf has a non-zero {@code splitValue},
 *         or if its class counts are missing or do not match the number of levels
 */
private MiningModel encodeProbabilityForest(RGenericVector ranger, Schema schema) {
	RGenericVector forest = (RGenericVector) ranger.getValue("forest");

	final RStringVector levels = (RStringVector) forest.getValue("levels");

	CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

	ScoreEncoder probabilityScoreEncoder = new ScoreEncoder() {

		@Override
		public void encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
			// Probability leaves must have a zero split value...
			if (splitValue.doubleValue() != 0d) {
				throw new IllegalArgumentException();
			}

			// ...and exactly one class count per level
			if (terminalClassCount == null || terminalClassCount.size() != levels.size()) {
				throw new IllegalArgumentException();
			}

			Double bestProbability = null;

			for (int level = 0; level < terminalClassCount.size(); level++) {
				String category = levels.getValue(level);
				Double probability = ValueUtil.asDouble(terminalClassCount.getValue(level));

				// Strictly-greater test: on ties the earlier level keeps the score
				boolean isNewBest = (bestProbability == null) || (bestProbability.compareTo(probability) < 0);
				if (isNewBest) {
					node.setScore(category);

					bestProbability = probability;
				}

				node.addScoreDistributions(new ScoreDistribution(category, probability));
			}
		}
	};

	List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, probabilityScoreEncoder, schema);

	return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel))
		.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels))
		.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
}
Aggregations