Usage of `org.dmg.pmml.tree.BranchNode` in project jpmml-r (by jpmml).
Class `IForestConverter`, method `encodeNode`.
/**
 * Recursively encodes the isolation tree node at {@code index} as a PMML {@link Node}.
 *
 * <p>The node lists are parallel and indexed by 0-based node position. Child references in
 * {@code leftDaughter}/{@code rightDaughter} and attribute references in {@code splitAtt} are
 * 1-based. An interior node (status {@code -3}) becomes a {@link BranchNode} whose left child
 * receives {@code feature < split} and whose right child receives {@code feature >= split}.
 * A terminal node (status {@code -1}) becomes a {@link LeafNode} scored
 * {@code depth + avgPathLength(size)}.
 *
 * @param index 0-based position of the node in the parallel node lists
 * @param predicate predicate that routes records into this node
 * @param depth depth of this node; children are encoded at {@code depth + 1}
 * @param nodeStatus per-node status code ({@code -3} interior, {@code -1} terminal)
 * @param nodeSize per-node size, used to score terminal nodes
 * @param leftDaughter 1-based position of each interior node's left child
 * @param rightDaughter 1-based position of each interior node's right child
 * @param splitAtt 1-based feature index of each interior node's split
 * @param splitValue split threshold of each interior node
 * @param schema schema providing the (continuous) features
 * @return the encoded node, with id {@code index + 1}
 * @throws IllegalArgumentException if the node status is neither {@code -3} nor {@code -1}
 */
private Node encodeNode(int index, Predicate predicate, int depth, List<Integer> nodeStatus, List<Integer> nodeSize, List<Integer> leftDaughter, List<Integer> rightDaughter, List<Integer> splitAtt, List<Double> splitValue, Schema schema) {
    Integer id = Integer.valueOf(index + 1);

    int status = nodeStatus.get(index);
    int size = nodeSize.get(index);

    if (status == -3) {
        // Interior node: binary split on a continuous feature
        int att = splitAtt.get(index);
        ContinuousFeature feature = (ContinuousFeature) schema.getFeature(att - 1);
        Double value = splitValue.get(index);

        Predicate leftPredicate = createSimplePredicate(feature, SimplePredicate.Operator.LESS_THAN, value);
        Predicate rightPredicate = createSimplePredicate(feature, SimplePredicate.Operator.GREATER_OR_EQUAL, value);

        // Daughter references are 1-based, the node lists are 0-based
        Node leftChild = encodeNode(leftDaughter.get(index) - 1, leftPredicate, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);
        Node rightChild = encodeNode(rightDaughter.get(index) - 1, rightPredicate, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);

        return new BranchNode(null, predicate).setId(id).addNodes(leftChild, rightChild);
    } else if (status == -1) {
        // Terminal node: depth reached so far plus avgPathLength(size)
        // (presumably the expected remaining path length for 'size' records; verify against avgPathLength)
        return new LeafNode(depth + avgPathLength(size), predicate).setId(id);
    }

    throw new IllegalArgumentException("Unsupported status " + status + " for node " + id + " (expected -3 for an interior node or -1 for a terminal node)");
}
Usage of `org.dmg.pmml.tree.BranchNode` in project jpmml-r (by jpmml).
Class `PartyConverter`, method `encodeNode`.
/**
 * Recursively encodes a partykit {@code partynode} (and its kids) as a PMML {@link Node}.
 *
 * <p>For a factor response, every node becomes a {@link ClassifierNode}. Its score is the
 * node's factor value, and it carries one {@link ScoreDistribution} per target category, taken
 * from row {@code id - 1} of {@code prob}. Otherwise, a node without kids becomes a
 * {@link LeafNode} and a node with kids becomes a {@link BranchNode}, scored with element
 * {@code id - 1} of {@code response}.
 *
 * <p>Supported splits:
 * <ul>
 *   <li>a binary split on a continuous feature ({@code breaks} of length 1, no {@code index});
 *   <li>a multiway split on a categorical feature ({@code index} given, no {@code breaks},
 *       {@code right = TRUE}).
 * </ul>
 *
 * @param partyNode the R {@code partynode} list (elements {@code id}, {@code split},
 *     {@code kids}, {@code surrogates})
 * @param predicate predicate that routes records into this node
 * @param response per-node responses, indexed by node id minus one
 * @param prob per-node class probabilities (read only for a factor response)
 * @param schema schema providing the label and the features
 * @return the encoded node, with its id taken from {@code partyNode$id}
 * @throws IllegalArgumentException if the node uses surrogate splits or an unsupported split
 */
private Node encodeNode(RGenericVector partyNode, Predicate predicate, RVector<?> response, RDoubleVector prob, Schema schema) {
    RIntegerVector id = partyNode.getIntegerElement("id");
    RGenericVector split = partyNode.getGenericElement("split");
    RGenericVector kids = partyNode.getGenericElement("kids");
    RGenericVector surrogates = partyNode.getGenericElement("surrogates");

    if (surrogates != null) {
        throw new IllegalArgumentException("Surrogate splits are not supported");
    }

    Node result;
    if (response instanceof RFactorVector) {
        // Classification: interior and terminal nodes alike carry a score distribution
        result = new ClassifierNode(null, predicate);
    } else if (kids == null) {
        result = new LeafNode(null, predicate);
    } else {
        result = new BranchNode(null, predicate);
    }

    int nodeId = id.asScalar();
    // Node ids are 1-based; per-node data lives at position (id - 1)
    int rowIndex = nodeId - 1;

    result.setId(Integer.valueOf(nodeId));

    if (response instanceof RFactorVector) {
        RFactorVector factor = (RFactorVector) response;
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        result.setScore(factor.getFactorValue(rowIndex));

        // prob is read as a matrix of response.size() rows x one column per target category
        List<Double> probabilities = FortranMatrixUtil.getRow(prob.getValues(), response.size(), categoricalLabel.size(), rowIndex);

        List<ScoreDistribution> scoreDistributions = result.getScoreDistributions();
        for (int i = 0; i < categoricalLabel.size(); i++) {
            Object value = categoricalLabel.getValue(i);
            Double probability = probabilities.get(i);

            scoreDistributions.add(new ScoreDistribution(value, probability));
        }
    } else {
        result.setScore(response.getValue(rowIndex));
    }

    if (kids == null) {
        return result;
    }

    List<? extends Feature> features = schema.getFeatures();

    RIntegerVector varid = split.getIntegerElement("varid");
    RDoubleVector breaks = split.getDoubleElement("breaks");
    RIntegerVector index = split.getIntegerElement("index");
    RBooleanVector right = split.getBooleanElement("right");

    // varid is 1-based
    Feature feature = features.get(varid.asScalar() - 1);

    if (breaks != null && index == null) {
        ContinuousFeature continuousFeature = (ContinuousFeature) feature;

        if (kids.size() != 2) {
            throw new IllegalArgumentException("Expected exactly 2 kids for a continuous split at node " + nodeId + ", got " + kids.size());
        }

        if (breaks.size() != 1) {
            throw new IllegalArgumentException("Expected exactly 1 break for a continuous split at node " + nodeId + ", got " + breaks.size());
        }

        Double value = breaks.asScalar();

        Predicate leftPredicate;
        Predicate rightPredicate;

        // right = TRUE: the interval is closed on the right, so the break value goes left
        if (right.asScalar()) {
            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        } else {
            leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, value);
            rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, value);
        }

        Node leftChild = encodeNode(kids.getGenericValue(0), leftPredicate, response, prob, schema);
        Node rightChild = encodeNode(kids.getGenericValue(1), rightPredicate, response, prob, schema);

        result.addNodes(leftChild, rightChild);
    } else if (breaks == null && index != null) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

        if (kids.size() < 2) {
            throw new IllegalArgumentException("Expected at least 2 kids for a categorical split at node " + nodeId + ", got " + kids.size());
        }

        // The flag is the same for every kid, so it is validated once up front
        if (!right.asScalar()) {
            throw new IllegalArgumentException("Categorical split with right = FALSE is not supported (node " + nodeId + ")");
        }

        List<?> values = categoricalFeature.getValues();
        for (int i = 0; i < kids.size(); i++) {
            // Kid numbers in 'index' are 1-based, hence i + 1
            Predicate childPredicate = createPredicate(categoricalFeature, selectValues(values, index, i + 1));

            Node child = encodeNode(kids.getGenericValue(i), childPredicate, response, prob, schema);

            result.addNodes(child);
        }
    } else {
        throw new IllegalArgumentException("Unsupported split at node " + nodeId + ": expected either 'breaks' (continuous) or 'index' (categorical), but not both or neither");
    }

    return result;
}
Usage of `org.dmg.pmml.tree.BranchNode` in project jpmml-r (by jpmml).
Class `GBMConverter`, method `encodeNode`.
/**
 * Recursively encodes row {@code i} of a gbm tree as a PMML {@link Node}.
 *
 * <p>A row whose split variable is {@code -1} is terminal and becomes a {@link LeafNode}
 * scored with the row's prediction. Any other row becomes a {@link BranchNode} with up to
 * three children, in this order: missing, left, right. A child index of {@code -1} means the
 * child is absent.
 *
 * <p>{@code flagManager} records whether the split feature is already known to be missing
 * on the current path. When it is known missing, only the missing child is emitted. When it
 * is known present, only the left and right children are emitted. When it is unknown, all
 * three are emitted, each under a forked flag manager. {@code categoryManager} narrows the
 * categorical levels still reachable on each path.
 *
 * @param i 0-based row index in the tree
 * @param predicate predicate that routes records into this node
 * @param tree the gbm tree columns, read by position
 * @param c_splits categorical split definitions, indexed by the split code of a categorical row
 * @param flagManager per-path missing-value flags
 * @param categoryManager per-path reachable categorical levels
 * @param schema schema providing the features
 * @return the encoded node, with id {@code i + 1}
 */
private Node encodeNode(int i, Predicate predicate, RGenericVector tree, RGenericVector c_splits, FlagManager flagManager, CategoryManager categoryManager, Schema schema) {
    // Tree columns, by position
    RIntegerVector splitVarColumn = tree.getIntegerValue(0);
    RDoubleVector splitCodePredColumn = tree.getDoubleValue(1);
    RIntegerVector leftNodeColumn = tree.getIntegerValue(2);
    RIntegerVector rightNodeColumn = tree.getIntegerValue(3);
    RIntegerVector missingNodeColumn = tree.getIntegerValue(4);
    RDoubleVector predictionColumn = tree.getDoubleValue(7);

    Integer nodeId = Integer.valueOf(i + 1);

    Integer var = splitVarColumn.getValue(i);
    if (var == -1) {
        return new LeafNode(predictionColumn.getValue(i), predicate).setId(nodeId);
    }

    Feature feature = schema.getFeature(var);
    String featureName = feature.getName();

    // null means "not yet decided on this path"
    Boolean isMissing = flagManager.getValue(featureName);

    FlagManager missingFlagManager = flagManager;
    FlagManager nonMissingFlagManager = flagManager;
    if (isMissing == null) {
        missingFlagManager = flagManager.fork(featureName, Boolean.TRUE);
        nonMissingFlagManager = flagManager.fork(featureName, Boolean.FALSE);
    }

    Predicate missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);

    Double split = splitCodePredColumn.getValue(i);

    CategoryManager leftCategoryManager;
    CategoryManager rightCategoryManager;
    Predicate leftPredicate;
    Predicate rightPredicate;

    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<?> levels = categoricalFeature.getValues();

        // For a categorical row, the split code selects an entry of c_splits
        RIntegerVector categorySplit = c_splits.getIntegerValue(ValueUtil.asInt(split));
        List<Integer> splitCodes = categorySplit.getValues();

        java.util.function.Predicate<Object> levelFilter = categoryManager.getValueFilter(featureName);

        List<Object> leftLevels = selectValues(levels, levelFilter, splitCodes, true);
        List<Object> rightLevels = selectValues(levels, levelFilter, splitCodes, false);

        leftCategoryManager = categoryManager.fork(featureName, leftLevels);
        rightCategoryManager = categoryManager.fork(featureName, rightLevels);

        leftPredicate = createPredicate(categoricalFeature, leftLevels);
        rightPredicate = createPredicate(categoricalFeature, rightLevels);
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        leftCategoryManager = categoryManager;
        rightCategoryManager = categoryManager;

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, split);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, split);
    }

    Node result = new BranchNode(null, predicate).setId(nodeId);
    List<Node> children = result.getNodes();

    boolean visitMissing = (isMissing == null || isMissing);
    boolean visitNonMissing = (isMissing == null || !isMissing);

    int missingIndex = missingNodeColumn.getValue(i);
    if (missingIndex != -1 && visitMissing) {
        children.add(encodeNode(missingIndex, missingPredicate, tree, c_splits, missingFlagManager, categoryManager, schema));
    }

    int leftIndex = leftNodeColumn.getValue(i);
    if (leftIndex != -1 && visitNonMissing) {
        children.add(encodeNode(leftIndex, leftPredicate, tree, c_splits, nonMissingFlagManager, leftCategoryManager, schema));
    }

    int rightIndex = rightNodeColumn.getValue(i);
    if (rightIndex != -1 && visitNonMissing) {
        children.add(encodeNode(rightIndex, rightPredicate, tree, c_splits, nonMissingFlagManager, rightCategoryManager, schema));
    }

    return result;
}
Usage of `org.dmg.pmml.tree.BranchNode` in project jpmml-r (by jpmml).
Class `RangerConverter`, method `encodeNode`.
/**
 * Recursively encodes node {@code index} of a ranger tree as a PMML {@link Node}.
 *
 * <p>A node whose left and right child IDs are both {@code 0} is terminal. It is returned as
 * a {@link LeafNode} completed by {@code scoreEncoder}, using the node's split value and, when
 * {@code terminalClassCounts} is present, the node's class counts. Any other node becomes a
 * {@link BranchNode} with a left and a right child.
 *
 * <p>For a categorical feature, the floor of the split value is the number of leading levels
 * sent left; the remaining levels go right. For any other feature, {@code <= split} goes left
 * and {@code > split} goes right.
 *
 * @param predicate predicate that routes records into this node
 * @param index node position in the per-node vectors
 * @param scoreEncoder completes terminal nodes
 * @param leftChildIDs left child position of each node
 * @param rightChildIDs right child position of each node
 * @param splitVarIDs split variable ID of each node
 * @param splitValues split value of each node (handed to {@code scoreEncoder} for terminal nodes)
 * @param terminalClassCounts per-node class counts, or {@code null}
 * @param categoryManager per-path reachable categorical levels
 * @param schema schema providing the features
 * @return the encoded node
 */
private Node encodeNode(Predicate predicate, int index, ScoreEncoder scoreEncoder, RNumberVector<?> leftChildIDs, RNumberVector<?> rightChildIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, CategoryManager categoryManager, Schema schema) {
    int leftIndex = ValueUtil.asInt(leftChildIDs.getValue(index));
    int rightIndex = ValueUtil.asInt(rightChildIDs.getValue(index));
    Number splitValue = splitValues.getValue(index);

    RNumberVector<?> terminalClassCount = null;
    if (terminalClassCounts != null) {
        terminalClassCount = terminalClassCounts.getNumericValue(index);
    }

    boolean isTerminal = (leftIndex == 0 && rightIndex == 0);
    if (isTerminal) {
        Node leaf = new LeafNode(null, predicate);

        return scoreEncoder.encode(leaf, splitValue, terminalClassCount);
    }

    int splitVarIndex = ValueUtil.asInt(splitVarIDs.getValue(index));
    // Split variable IDs are shifted down by one when this.hasDependentVar is set
    int featureIndex = this.hasDependentVar ? (splitVarIndex - 1) : splitVarIndex;
    Feature feature = schema.getFeature(featureIndex);

    CategoryManager leftCategoryManager;
    CategoryManager rightCategoryManager;
    Predicate leftPredicate;
    Predicate rightPredicate;

    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        String featureName = categoricalFeature.getName();
        List<?> levels = categoricalFeature.getValues();

        // Levels [0, cut) go left, levels [cut, size) go right
        int cut = ValueUtil.asInt(Math.floor(splitValue.doubleValue()));

        java.util.function.Predicate<Object> levelFilter = categoryManager.getValueFilter(featureName);

        List<Object> leftLevels = filterValues(levels.subList(0, cut), levelFilter);
        List<Object> rightLevels = filterValues(levels.subList(cut, levels.size()), levelFilter);

        leftCategoryManager = categoryManager.fork(featureName, leftLevels);
        rightCategoryManager = categoryManager.fork(featureName, rightLevels);

        leftPredicate = createPredicate(categoricalFeature, leftLevels);
        rightPredicate = createPredicate(categoricalFeature, rightLevels);
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        leftCategoryManager = categoryManager;
        rightCategoryManager = categoryManager;

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, splitValue);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, splitValue);
    }

    Node leftChild = encodeNode(leftPredicate, leftIndex, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, leftCategoryManager, schema);
    Node rightChild = encodeNode(rightPredicate, rightIndex, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, rightCategoryManager, schema);

    return new BranchNode(null, predicate).addNodes(leftChild, rightChild);
}
Usage of `org.dmg.pmml.tree.BranchNode` in project jpmml-r (by jpmml).
Class `RandomForestConverter`, method `encodeNode`.
/**
 * Recursively encodes node {@code i} (0-based) of a randomForest tree as a PMML {@link Node}.
 *
 * <p>A node whose {@code bestvar} is {@code 0} is terminal and becomes a {@link LeafNode}
 * scored with {@code scoreEncoder.encode(nodepred[i])}. Any other node becomes a
 * {@link BranchNode} that splits on feature {@code bestvar - 1}:
 * <ul>
 *   <li>boolean feature: the split must be {@code 0.5}; value 0 goes left, value 1 goes right;
 *   <li>categorical feature: levels are partitioned by {@code selectValues} on the split code;
 *   <li>otherwise: {@code <= split} goes left, {@code > split} goes right.
 * </ul>
 * Daughter references are 1-based; a daughter of {@code 0} is omitted.
 *
 * @param predicate predicate that routes records into this node
 * @param i 0-based node position
 * @param scoreEncoder converts a raw node prediction into the leaf score
 * @param leftDaughter 1-based left daughter of each node
 * @param rightDaughter 1-based right daughter of each node
 * @param bestvar 1-based split variable of each node ({@code 0} for terminal nodes)
 * @param xbestsplit split value / split code of each node
 * @param nodepred prediction of each node
 * @param categoryManager per-path reachable categorical levels
 * @param schema schema providing the features
 * @return the encoded node, with id {@code i + 1}
 * @throws IllegalArgumentException if a boolean feature is split anywhere but {@code 0.5}
 */
private <P extends Number> Node encodeNode(Predicate predicate, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, CategoryManager categoryManager, Schema schema) {
    Integer nodeId = Integer.valueOf(i + 1);

    int splitVar = ValueUtil.asInt(bestvar.get(i));
    if (splitVar == 0) {
        return new LeafNode(scoreEncoder.encode(nodepred.get(i)), predicate).setId(nodeId);
    }

    Feature feature = schema.getFeature(splitVar - 1);
    Double split = xbestsplit.get(i);

    Predicate leftPredicate;
    Predicate rightPredicate;
    CategoryManager leftCategoryManager = categoryManager;
    CategoryManager rightCategoryManager = categoryManager;

    if (feature instanceof BooleanFeature) {
        BooleanFeature booleanFeature = (BooleanFeature) feature;

        // Boolean splits are only accepted at 0.5
        if (split != 0.5d) {
            throw new IllegalArgumentException();
        }

        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
    } else if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        String featureName = categoricalFeature.getName();
        List<?> levels = categoricalFeature.getValues();

        java.util.function.Predicate<Object> levelFilter = categoryManager.getValueFilter(featureName);

        List<Object> leftLevels = selectValues(levels, levelFilter, split, true);
        List<Object> rightLevels = selectValues(levels, levelFilter, split, false);

        leftCategoryManager = categoryManager.fork(featureName, leftLevels);
        rightCategoryManager = categoryManager.fork(featureName, rightLevels);

        leftPredicate = createPredicate(categoricalFeature, leftLevels);
        rightPredicate = createPredicate(categoricalFeature, rightLevels);
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, split);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, split);
    }

    Node result = new BranchNode(null, predicate).setId(nodeId);
    List<Node> children = result.getNodes();

    int leftDaughterId = ValueUtil.asInt(leftDaughter.get(i));
    if (leftDaughterId != 0) {
        children.add(encodeNode(leftPredicate, leftDaughterId - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, leftCategoryManager, schema));
    }

    int rightDaughterId = ValueUtil.asInt(rightDaughter.get(i));
    if (rightDaughterId != 0) {
        children.add(encodeNode(rightPredicate, rightDaughterId - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, rightCategoryManager, schema));
    }

    return result;
}
Aggregations