Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class GBMConverter, method encodeTreeModel.
/**
 * Builds a multi-split {@link TreeModel} for a single gbm tree.
 *
 * <p>The root node has id "1" and an always-true predicate; {@code encodeNode} then
 * populates it recursively, starting at position 0 of {@code tree}.
 *
 * @param miningFunction the mining function of the resulting tree model
 * @param tree the gbm tree description passed through to {@code encodeNode}
 * @param c_splits the categorical split table passed through to {@code encodeNode}
 * @param schema the schema whose label defines the mining schema
 * @return the encoded tree model with a MULTI_SPLIT split characteristic
 */
private TreeModel encodeTreeModel(MiningFunction miningFunction, RGenericVector tree, RGenericVector c_splits, Schema schema) {
	Node rootNode = new Node()
		.setId("1")
		.setPredicate(new True());

	encodeNode(rootNode, 0, tree, c_splits, schema);

	return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class IForestConverter, method encodeTreeModel.
/**
 * Builds a binary-split regression {@link TreeModel} for the isolation tree at
 * column {@code index} of the forest's per-tree matrices.
 *
 * <p>Each per-tree attribute (node status, node size, daughters, split attribute,
 * split point) is stored as a Fortran-style matrix with {@code nrnodes} rows and
 * {@code ntree} columns; only column {@code index} is extracted and handed to
 * {@code encodeNode}, which starts at position 0 with depth 0.
 *
 * @param trees the {@code trees} element of the fitted isolation forest
 * @param index zero-based index of the tree to encode
 * @param schema the schema whose features and label are used
 * @return the encoded tree model
 */
private TreeModel encodeTreeModel(RGenericVector trees, int index, Schema schema) {
	RIntegerVector nrnodes = (RIntegerVector) trees.getValue("nrnodes");
	RIntegerVector ntree = (RIntegerVector) trees.getValue("ntree");
	RIntegerVector nodeStatusMatrix = (RIntegerVector) trees.getValue("nodeStatus");
	RIntegerVector leftDaughterMatrix = (RIntegerVector) trees.getValue("lDaughter");
	RIntegerVector rightDaughterMatrix = (RIntegerVector) trees.getValue("rDaughter");
	RIntegerVector splitAttMatrix = (RIntegerVector) trees.getValue("splitAtt");
	RDoubleVector splitPointMatrix = (RDoubleVector) trees.getValue("splitPoint");
	RIntegerVector nodeSizeMatrix = (RIntegerVector) trees.getValue("nSam");

	int nodeCount = nrnodes.asScalar();
	int treeCount = ntree.asScalar();

	// Slice out the column that belongs to the requested tree
	List<Integer> statusColumn = FortranMatrixUtil.getColumn(nodeStatusMatrix.getValues(), nodeCount, treeCount, index);
	List<Integer> sizeColumn = FortranMatrixUtil.getColumn(nodeSizeMatrix.getValues(), nodeCount, treeCount, index);
	List<Integer> leftColumn = FortranMatrixUtil.getColumn(leftDaughterMatrix.getValues(), nodeCount, treeCount, index);
	List<Integer> rightColumn = FortranMatrixUtil.getColumn(rightDaughterMatrix.getValues(), nodeCount, treeCount, index);
	List<Integer> splitAttColumn = FortranMatrixUtil.getColumn(splitAttMatrix.getValues(), nodeCount, treeCount, index);
	List<Double> splitPointColumn = FortranMatrixUtil.getColumn(splitPointMatrix.getValues(), nodeCount, treeCount, index);

	Node rootNode = new Node().setPredicate(new True());
	encodeNode(rootNode, 0, 0, statusColumn, sizeColumn, leftColumn, rightColumn, splitAttColumn, splitPointColumn, schema);

	return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class IForestConverter, method encodeNode.
/**
 * Recursively encodes the isolation-tree node at zero-based position {@code index}
 * of the per-tree columns into {@code node}.
 *
 * <p>The node id is set to {@code index + 1}.
 * <ul>
 *   <li>Status {@code -3} marks an interior node. It splits on continuous feature
 *   {@code splitAtt - 1} at the split value. The left child takes {@code x < value}
 *   and the right child takes {@code x >= value}. Daughter indices are 1-based.</li>
 *   <li>Status {@code -1} marks a terminal node. Its score is
 *   {@code depth + avgPathLength(size)}.</li>
 * </ul>
 *
 * @param node the PMML node to populate
 * @param index zero-based position of this node in the per-tree columns
 * @param depth depth of this node (the caller in this class starts the root at 0)
 * @param nodeStatus per-node status codes
 * @param nodeSize per-node sample counts, used for terminal scores
 * @param leftDaughter per-node 1-based index of the left child
 * @param rightDaughter per-node 1-based index of the right child
 * @param splitAtt per-node 1-based index of the split feature in {@code schema}
 * @param splitValue per-node split threshold
 * @param schema the schema supplying the (continuous) features
 * @throws IllegalArgumentException if a node status is neither -3 nor -1
 */
private void encodeNode(Node node, int index, int depth, List<Integer> nodeStatus, List<Integer> nodeSize, List<Integer> leftDaughter, List<Integer> rightDaughter, List<Integer> splitAtt, List<Double> splitValue, Schema schema) {
	int status = nodeStatus.get(index);

	node.setId(String.valueOf(index + 1));

	if (status == -3) {
		// Interior node
		int att = splitAtt.get(index);

		ContinuousFeature feature = (ContinuousFeature) schema.getFeature(att - 1);

		String value = ValueUtil.formatValue(splitValue.get(index));

		Predicate leftPredicate = createSimplePredicate(feature, SimplePredicate.Operator.LESS_THAN, value);
		Node leftChild = new Node().setPredicate(leftPredicate);

		// Daughter indices are 1-based; convert to 0-based positions
		int leftIndex = (leftDaughter.get(index) - 1);
		encodeNode(leftChild, leftIndex, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);

		Predicate rightPredicate = createSimplePredicate(feature, SimplePredicate.Operator.GREATER_OR_EQUAL, value);
		Node rightChild = new Node().setPredicate(rightPredicate);

		int rightIndex = (rightDaughter.get(index) - 1);
		encodeNode(rightChild, rightIndex, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);

		node.addNodes(leftChild, rightChild);
	} else if (status == -1) {
		// Terminal node: the node size is only needed here
		int size = nodeSize.get(index);

		node.setScore(ValueUtil.formatValue(depth + avgPathLength(size)));
	} else {
		throw new IllegalArgumentException("Node " + (index + 1) + " has unsupported status " + status + " (expected -3 for an interior node or -1 for a terminal node)");
	}
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class RandomForestConverter, method encodeNode.
/**
 * Recursively encodes the randomForest tree node at zero-based row {@code i} into {@code node}.
 *
 * <p>A node whose {@code bestvar} entry is 0 is a leaf. It is scored by passing its
 * {@code nodepred} entry through {@code scoreEncoder}. Any other node splits on feature
 * {@code bestvar - 1} of {@code schema}, and the predicate depends on the feature type:
 * <ul>
 *   <li>Boolean features require a split value of exactly 0.5. Left matches the first
 *   boolean value and right matches the second.</li>
 *   <li>Categorical features are split by {@code selectValues} on the split value.</li>
 *   <li>All other features are treated as continuous. Left takes {@code x <= split}
 *   and right takes {@code x > split}.</li>
 * </ul>
 * Children get ids equal to their 1-based daughter index. A daughter index of 0 means
 * the child is absent.
 *
 * @throws IllegalArgumentException if a boolean feature is split at a value other than 0.5
 */
private <P extends Number> void encodeNode(Node node, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, Schema schema) {
	int var = ValueUtil.asInt(bestvar.get(i));

	// Terminal node
	if (var == 0) {
		P prediction = nodepred.get(i);

		node.setScore(scoreEncoder.encode(prediction));

		return;
	}

	Feature feature = schema.getFeature(var - 1);

	Double split = xbestsplit.get(i);

	Predicate leftPredicate;
	Predicate rightPredicate;

	if (feature instanceof BooleanFeature) {
		BooleanFeature booleanFeature = (BooleanFeature) feature;

		// Left/right are mapped to boolean values 0/1 below, which is only valid for a 0.5 threshold
		if (split != 0.5d) {
			throw new IllegalArgumentException("Expected split value 0.5 for boolean feature at index " + (var - 1) + " (node " + (i + 1) + "), got " + split);
		}

		leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
		rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
	} else if (feature instanceof CategoricalFeature) {
		CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

		List<String> values = categoricalFeature.getValues();

		leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, split, true));
		rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, split, false));
	} else {
		ContinuousFeature continuousFeature = feature.toContinuousFeature();

		String value = ValueUtil.formatValue(split);

		leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
		rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
	}

	// Daughter indices are 1-based; 0 means "no child"
	int left = ValueUtil.asInt(leftDaughter.get(i));
	if (left != 0) {
		Node leftChild = new Node().setId(String.valueOf(left)).setPredicate(leftPredicate);

		encodeNode(leftChild, left - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);

		node.addNodes(leftChild);
	}

	int right = ValueUtil.asInt(rightDaughter.get(i));
	if (right != 0) {
		Node rightChild = new Node().setId(String.valueOf(right)).setPredicate(rightPredicate);

		encodeNode(rightChild, right - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);

		node.addNodes(rightChild);
	}
}
Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
Class RangerConverter, method encodeClassification.
/**
 * Encodes a ranger classification forest as a majority-vote {@link MiningModel}.
 *
 * <p>Each terminal node is scored with the class label at position {@code splitValue - 1}
 * of the forest's {@code levels} vector. In other words, the split value is used as a
 * 1-based level index.
 *
 * @param ranger the ranger model object; its {@code forest} element is read
 * @param schema the schema the trees are encoded against
 * @return a classification mining model with one majority-vote segment per encoded tree
 * @throws IllegalArgumentException (raised by the score encoder) if a terminal node
 *         is given non-null class counts
 */
private MiningModel encodeClassification(RGenericVector ranger, Schema schema) {
	RGenericVector forest = (RGenericVector) ranger.getValue("forest");

	final RStringVector levels = (RStringVector) forest.getValue("levels");

	ScoreEncoder scoreEncoder = new ScoreEncoder() {

		@Override
		public void encode(Node node, Number splitValue, RNumberVector<?> terminalClassCount) {
			int index = ValueUtil.asInt(splitValue);

			// This encoder only handles terminal nodes that carry a single class index
			if (terminalClassCount != null) {
				throw new IllegalArgumentException("Expected no terminal class counts for a classification tree node, but got " + terminalClassCount);
			}

			// Convert the 1-based level index to a 0-based position
			node.setScore(levels.getValue(index - 1));
		}
	};

	List<TreeModel> treeModels = encodeForest(forest, MiningFunction.CLASSIFICATION, scoreEncoder, schema);

	MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()))
		.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels));

	return miningModel;
}
Aggregations