Use of org.dmg.pmml.tree.CountingLeafNode in the project jpmml-r by jpmml.
Class RPartConverter, method encodeNode.
/**
 * Recursively encodes the tree node identified by {@code rowName}, and everything below it,
 * as a PMML {@code Node}.
 *
 * <p>A node whose split variable resolves to {@code RPartConverter.INDEX_LEAF} becomes a
 * {@code CountingLeafNode}. Any other node becomes a {@code CountingBranchNode} with two
 * children, whose row names are {@code 2 * rowName} and {@code 2 * rowName + 1}.
 *
 * <p>When {@code useSurrogate > 0} and the node has surrogate splits, each child predicate is
 * wrapped in a {@code SURROGATE} compound predicate: the primary split first, then the surrogate
 * splits in the order they appear in {@code splits}. When {@code useSurrogate == 2}, the
 * {@code n} values of the two children are compared; the child with the larger value is passed
 * to {@code makeDefault} and placed last among the branch's children.
 *
 * @param predicate the predicate attached to the node being encoded
 * @param rowName the row name of the node being encoded; also used as the PMML node id
 * @param rowNames the vector in which row names are looked up (via {@code getIndex}) to obtain offsets
 * @param var the per-node split variable vector
 * @param n the per-node count vector, read only when {@code useSurrogate == 2}
 * @param splitInfo per-node triples: {@code [0]} the primary split offset, {@code [1]} the number
 *        of competing splits, {@code [2]} the number of surrogate splits
 * @param splits the split table; its row names (dimension 0) identify the surrogate features
 * @param csplit passed through to {@code encodePredicates}
 * @param scoreEncoder the encoder applied to every node produced by this method
 * @param schema supplies the list of features
 * @return the encoded node, as returned by {@code scoreEncoder}
 */
private Node encodeNode(Predicate predicate, int rowName, RIntegerVector rowNames, RVector<?> var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, ScoreEncoder scoreEncoder, Schema schema) {
    final int offset = getIndex(rowNames, rowName);
    final Integer id = Integer.valueOf(rowName);

    final List<? extends Feature> features = schema.getFeatures();
    final int splitVar = getFeatureIndex(var, offset, features);

    // Terminal node: nothing to split on, so emit a leaf and stop recursing.
    if (splitVar == RPartConverter.INDEX_LEAF) {
        Node leaf = new CountingLeafNode(null, predicate).setId(id);
        return scoreEncoder.encode(leaf, offset);
    }

    final int leftRowName = rowName * 2;
    final int rightRowName = leftRowName + 1;

    // Sign of (left n - right n); stays 0 and is never consulted unless useSurrogate == 2.
    int childCountOrder = 0;
    if (this.useSurrogate == 2) {
        int leftOffset = getIndex(rowNames, leftRowName);
        int rightOffset = getIndex(rowNames, rightRowName);
        childCountOrder = Double.compare(n.getValue(leftOffset), n.getValue(rightOffset));
    }

    // splitVar is 1-based with respect to the feature list.
    final Feature primaryFeature = features.get(splitVar - 1);

    final int[] nodeSplitInfo = splitInfo[offset];
    final int splitOffset = nodeSplitInfo[0];
    final int splitNumCompete = nodeSplitInfo[1];
    final int splitNumSurrogate = nodeSplitInfo[2];

    final List<Predicate> primaryPredicates = encodePredicates(primaryFeature, splitOffset, splits, csplit);
    Predicate leftPredicate = primaryPredicates.get(0);
    Predicate rightPredicate = primaryPredicates.get(1);

    if (this.useSurrogate > 0 && splitNumSurrogate > 0) {
        CompoundPredicate leftSurrogates = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null).addPredicates(leftPredicate);
        CompoundPredicate rightSurrogates = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, null).addPredicates(rightPredicate);

        RStringVector splitRowNames = splits.dimnames(0);

        // Surrogate rows follow the primary split row and its competing split rows.
        final int firstSurrogateOffset = (splitOffset + 1) + splitNumCompete;
        for (int k = 0; k < splitNumSurrogate; k++) {
            int surrogateOffset = firstSurrogateOffset + k;

            Feature surrogateFeature = getFeature(splitRowNames.getValue(surrogateOffset));
            List<Predicate> surrogatePredicates = encodePredicates(surrogateFeature, surrogateOffset, splits, csplit);

            leftSurrogates.addPredicates(surrogatePredicates.get(0));
            rightSurrogates.addPredicates(surrogatePredicates.get(1));
        }

        leftPredicate = leftSurrogates;
        rightPredicate = rightSurrogates;
    }

    final Node leftChild = encodeNode(leftPredicate, leftRowName, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);
    final Node rightChild = encodeNode(rightPredicate, rightRowName, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);

    // Order in which the children are attached to the branch node.
    Node firstChild = leftChild;
    Node secondChild = rightChild;

    if (this.useSurrogate == 2) {
        if (childCountOrder < 0) {
            // Right child has the larger n: it becomes the default and is already last.
            makeDefault(rightChild);
        } else if (childCountOrder > 0) {
            // Left child has the larger n: make it the default and move it to the last position.
            makeDefault(leftChild);
            firstChild = rightChild;
            secondChild = leftChild;
        }
    }

    Node branch = new CountingBranchNode(null, predicate).setId(id).addNodes(firstChild, secondChild);
    return scoreEncoder.encode(branch, offset);
}
Aggregations