Use of org.jpmml.converter.CategoryManager in project jpmml-r by jpmml.
The class GBMConverter, method encodeNode.
/**
 * Recursively encodes the GBM tree node stored at row {@code i} (together with its subtree)
 * as a PMML {@link Node}, assigning it the 1-based id {@code i + 1}.
 *
 * <p>A row whose split variable is {@code -1} becomes a {@link LeafNode} carrying the row's
 * prediction. Any other row becomes a {@link BranchNode} whose children are appended in the
 * order: missing-value child, left child, right child. Children whose index is {@code -1} are
 * omitted.
 *
 * @param i zero-based row index of the node within the per-node tree vectors
 * @param predicate the predicate that routes records into this node
 * @param tree per-node columns of a single GBM tree; this method reads columns 0 (split var),
 *        1 (split code / prediction), 2 (left node), 3 (right node), 4 (missing node) and 7 (prediction)
 * @param c_splits categorical split table, indexed by the integer value of a categorical split code
 *        (presumably GBM's {@code c.splits} list; TODO confirm against caller)
 * @param flagManager per-feature "is missing" flags established by ancestor nodes on this path
 * @param categoryManager per-feature category restrictions established by ancestor categorical splits
 * @param schema supplies the feature referenced by each split variable index
 * @return the encoded node
 */
private Node encodeNode(int i, Predicate predicate, RGenericVector tree, RGenericVector c_splits, FlagManager flagManager, CategoryManager categoryManager, Schema schema) {
Integer id = Integer.valueOf(i + 1);
RIntegerVector splitVar = tree.getIntegerValue(0);
RDoubleVector splitCodePred = tree.getDoubleValue(1);
RIntegerVector leftNode = tree.getIntegerValue(2);
RIntegerVector rightNode = tree.getIntegerValue(3);
RIntegerVector missingNode = tree.getIntegerValue(4);
RDoubleVector prediction = tree.getDoubleValue(7);
Integer var = splitVar.getValue(i);
// A split variable of -1 marks a terminal row: emit a leaf with the stored prediction.
if (var == -1) {
Double value = prediction.getValue(i);
Node result = new LeafNode(value, predicate).setId(id);
return result;
}
// null = missingness of this feature not yet decided on the current path;
// TRUE/FALSE = an ancestor already routed this path through the missing/non-missing branch.
Boolean isMissing;
FlagManager missingFlagManager = flagManager;
FlagManager nonMissingFlagManager = flagManager;
Predicate missingPredicate;
Feature feature = schema.getFeature(var);
{
String name = feature.getName();
isMissing = flagManager.getValue(name);
if (isMissing == null) {
// First split on this feature along the path: record the decision separately for each branch.
missingFlagManager = missingFlagManager.fork(name, Boolean.TRUE);
nonMissingFlagManager = nonMissingFlagManager.fork(name, Boolean.FALSE);
}
missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);
}
CategoryManager leftCategoryManager = categoryManager;
CategoryManager rightCategoryManager = categoryManager;
Predicate leftPredicate;
Predicate rightPredicate;
Double split = splitCodePred.getValue(i);
if (feature instanceof CategoricalFeature) {
CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
String name = categoricalFeature.getName();
List<?> values = categoricalFeature.getValues();
// For categorical splits the split code is an index into c_splits, not a threshold.
int index = ValueUtil.asInt(split);
RIntegerVector c_split = c_splits.getIntegerValue(index);
List<Integer> splitValues = c_split.getValues();
// Restrict both sides to categories still admitted by ancestor splits on this feature.
java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
List<Object> leftValues = selectValues(values, valueFilter, splitValues, true);
List<Object> rightValues = selectValues(values, valueFilter, splitValues, false);
leftCategoryManager = leftCategoryManager.fork(name, leftValues);
rightCategoryManager = rightCategoryManager.fork(name, rightValues);
leftPredicate = createPredicate(categoricalFeature, leftValues);
rightPredicate = createPredicate(categoricalFeature, rightValues);
} else {
// Continuous split: left takes values strictly below the threshold, right takes the rest.
ContinuousFeature continuousFeature = feature.toContinuousFeature();
leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, split);
rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, split);
}
Node result = new BranchNode(null, predicate).setId(id);
List<Node> nodes = result.getNodes();
// The missing-value child is only reachable if missingness is undecided or known to be TRUE.
// It inherits the parent's category restrictions unchanged.
Integer missing = missingNode.getValue(i);
if (missing != -1 && (isMissing == null || isMissing)) {
Node missingChild = encodeNode(missing, missingPredicate, tree, c_splits, missingFlagManager, categoryManager, schema);
nodes.add(missingChild);
}
// The left/right children are only reachable if the value is not known to be missing.
Integer left = leftNode.getValue(i);
if (left != -1 && (isMissing == null || !isMissing)) {
Node leftChild = encodeNode(left, leftPredicate, tree, c_splits, nonMissingFlagManager, leftCategoryManager, schema);
nodes.add(leftChild);
}
Integer right = rightNode.getValue(i);
if (right != -1 && (isMissing == null || !isMissing)) {
Node rightChild = encodeNode(right, rightPredicate, tree, c_splits, nonMissingFlagManager, rightCategoryManager, schema);
nodes.add(rightChild);
}
return result;
}
Use of org.jpmml.converter.CategoryManager in project jpmml-r by jpmml.
The class GBMConverter, method encodeTreeModel.
/**
 * Builds a multi-split PMML {@link TreeModel} from a single GBM tree.
 *
 * <p>Encoding starts at row 0 with an always-true predicate. The flag and category
 * managers start out empty.
 *
 * @param miningFunction the mining function of the resulting model
 * @param tree per-node columns of the GBM tree
 * @param c_splits categorical split table of the GBM model
 * @param schema the schema providing the label and the features
 * @return the encoded tree model
 */
private TreeModel encodeTreeModel(MiningFunction miningFunction, RGenericVector tree, RGenericVector c_splits, Schema schema) {
	FlagManager rootFlags = new FlagManager();
	CategoryManager rootCategories = new CategoryManager();

	Node rootNode = encodeNode(0, True.INSTANCE, tree, c_splits, rootFlags, rootCategories, schema);

	return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);
}
Use of org.jpmml.converter.CategoryManager in project jpmml-r by jpmml.
The class RangerConverter, method encodeNode.
/**
 * Recursively encodes the ranger tree node at {@code index} (and its subtree) as a PMML {@link Node}.
 *
 * <p>A node whose left and right child ids are both 0 is a leaf. It is handed to
 * {@code scoreEncoder} together with the node's split value and, when available, its terminal
 * class counts. Every other node becomes a binary {@link BranchNode}:
 * <ul>
 *   <li>categorical feature: the level list is cut at {@code floor(splitValue)}. The lower part
 *       goes left, the rest goes right. Both sides are filtered by {@code categoryManager};</li>
 *   <li>other features: {@code <= splitValue} goes left, {@code > splitValue} goes right.</li>
 * </ul>
 *
 * @param predicate the predicate that routes records into this node
 * @param index zero-based node index within the per-node vectors
 * @param scoreEncoder encoder responsible for populating leaf nodes
 * @param leftChildIDs left child index per node
 * @param rightChildIDs right child index per node
 * @param splitVarIDs split variable index per node
 * @param splitValues split value per node
 * @param terminalClassCounts per-node terminal class counts, or {@code null}
 * @param categoryManager category restrictions established by ancestor categorical splits
 * @param schema supplies the features referenced by split variable index
 * @return the encoded node
 */
private Node encodeNode(Predicate predicate, int index, ScoreEncoder scoreEncoder, RNumberVector<?> leftChildIDs, RNumberVector<?> rightChildIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, CategoryManager categoryManager, Schema schema) {
	int leftChildIndex = ValueUtil.asInt(leftChildIDs.getValue(index));
	int rightChildIndex = ValueUtil.asInt(rightChildIDs.getValue(index));
	Number threshold = splitValues.getValue(index);

	RNumberVector<?> classCounts = null;
	if (terminalClassCounts != null) {
		classCounts = terminalClassCounts.getNumericValue(index);
	}

	boolean isLeaf = (leftChildIndex == 0) && (rightChildIndex == 0);
	if (isLeaf) {
		return scoreEncoder.encode(new LeafNode(null, predicate), threshold, classCounts);
	}

	int featureIndex = ValueUtil.asInt(splitVarIDs.getValue(index));
	// Split variable ids count the dependent variable when it is present; skip over it.
	if (this.hasDependentVar) {
		featureIndex -= 1;
	}
	Feature feature = schema.getFeature(featureIndex);

	CategoryManager leftCategories;
	CategoryManager rightCategories;
	Predicate leftCondition;
	Predicate rightCondition;

	if (feature instanceof CategoricalFeature) {
		CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

		int cutoff = ValueUtil.asInt(Math.floor(threshold.doubleValue()));
		String fieldName = categoricalFeature.getName();
		List<?> levels = categoricalFeature.getValues();

		java.util.function.Predicate<Object> admitted = categoryManager.getValueFilter(fieldName);
		List<Object> lowerLevels = filterValues(levels.subList(0, cutoff), admitted);
		List<Object> upperLevels = filterValues(levels.subList(cutoff, levels.size()), admitted);

		leftCategories = categoryManager.fork(fieldName, lowerLevels);
		rightCategories = categoryManager.fork(fieldName, upperLevels);

		leftCondition = createPredicate(categoricalFeature, lowerLevels);
		rightCondition = createPredicate(categoricalFeature, upperLevels);
	} else {
		ContinuousFeature continuousFeature = feature.toContinuousFeature();

		leftCategories = categoryManager;
		rightCategories = categoryManager;

		leftCondition = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, threshold);
		rightCondition = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, threshold);
	}

	Node leftNode = encodeNode(leftCondition, leftChildIndex, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, leftCategories, schema);
	Node rightNode = encodeNode(rightCondition, rightChildIndex, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, rightCategories, schema);

	return new BranchNode(null, predicate).addNodes(leftNode, rightNode);
}
Use of org.jpmml.converter.CategoryManager in project jpmml-r by jpmml.
The class RangerConverter, method encodeTreeModel.
/**
 * Builds a binary-split PMML {@link TreeModel} from a single ranger tree.
 *
 * <p>{@code childNodeIDs} holds two numeric vectors: element 0 gives the left child per node and
 * element 1 gives the right child per node. Encoding starts at node 0 with an always-true
 * predicate and an empty category manager.
 *
 * @param miningFunction the mining function of the resulting model
 * @param scoreEncoder encoder responsible for populating leaf nodes
 * @param childNodeIDs the left/right child id vectors
 * @param splitVarIDs split variable index per node
 * @param splitValues split value per node
 * @param terminalClassCounts per-node terminal class counts, or {@code null}
 * @param schema the schema providing the label and the features
 * @return the encoded tree model
 */
private TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder scoreEncoder, RGenericVector childNodeIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, Schema schema) {
	RNumberVector<?> lefts = childNodeIDs.getNumericValue(0);
	RNumberVector<?> rights = childNodeIDs.getNumericValue(1);

	CategoryManager rootCategories = new CategoryManager();
	Node rootNode = encodeNode(True.INSTANCE, 0, scoreEncoder, lefts, rights, splitVarIDs, splitValues, terminalClassCounts, rootCategories, schema);

	return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Use of org.jpmml.converter.CategoryManager in project jpmml-sparkml by jpmml.
The class TreeModelUtil, method encodeNode.
/**
 * Recursively converts a Spark ML decision tree node (and its subtree) into a PMML {@link Node}.
 *
 * <p>Leaf nodes are delegated to {@code scoreEncoder}. Internal nodes become a binary
 * {@link BranchNode} whose children carry predicates derived from the node's {@link Split}:
 * <ul>
 *   <li>continuous split on a {@link BooleanFeature}: the threshold must be 0.5. The left/right
 *       child tests equality with the feature's first/second value;</li>
 *   <li>continuous split on any other feature: {@code <= threshold} goes left and
 *       {@code > threshold} goes right. For INTEGER features the threshold is floored;</li>
 *   <li>categorical split on a {@link BinaryFeature}: EQUAL / NOT_EQUAL against the feature's
 *       value, depending on which side holds the {@code TRUE} categories;</li>
 *   <li>categorical split on a {@link CategoricalFeature}: membership in the left/right category
 *       subsets. Both subsets are narrowed by {@code categoryManager}, which is forked for each child.</li>
 * </ul>
 *
 * @param predicate the predicate that routes records into this node
 * @param scoreEncoder encoder responsible for populating leaf nodes
 * @param sparkNode the Spark ML node to convert
 * @param predicateManager factory used to create the child predicates
 * @param categoryManager category restrictions established by ancestor categorical splits
 * @param schema supplies the feature referenced by each split's feature index
 * @return the encoded node
 * @throws IllegalArgumentException if the node, split or feature type is unsupported, or if the
 *         split is inconsistent with the feature it refers to
 */
private static Node encodeNode(Predicate predicate, ScoreEncoder scoreEncoder, org.apache.spark.ml.tree.Node sparkNode, PredicateManager predicateManager, CategoryManager categoryManager, Schema schema) {

	if (sparkNode instanceof org.apache.spark.ml.tree.LeafNode) {
		org.apache.spark.ml.tree.LeafNode leafNode = (org.apache.spark.ml.tree.LeafNode) sparkNode;

		Node result = new LeafNode(null, predicate);

		return scoreEncoder.encode(result, leafNode);
	} else if (sparkNode instanceof org.apache.spark.ml.tree.InternalNode) {
		org.apache.spark.ml.tree.InternalNode internalNode = (org.apache.spark.ml.tree.InternalNode) sparkNode;

		CategoryManager leftCategoryManager = categoryManager;
		CategoryManager rightCategoryManager = categoryManager;

		Predicate leftPredicate;
		Predicate rightPredicate;

		Split split = internalNode.split();

		Feature feature = schema.getFeature(split.featureIndex());

		if (split instanceof ContinuousSplit) {
			ContinuousSplit continuousSplit = (ContinuousSplit) split;

			Double threshold = continuousSplit.threshold();

			if (feature instanceof BooleanFeature) {
				BooleanFeature booleanFeature = (BooleanFeature) feature;

				// Spark encodes boolean columns as 0/1, so the only meaningful threshold is 0.5.
				if (threshold != 0.5d) {
					throw new IllegalArgumentException("Invalid split threshold value " + threshold + " for a boolean feature");
				}

				leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
				rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
			} else {
				ContinuousFeature continuousFeature = feature.toContinuousFeature();

				DataType dataType = continuousFeature.getDataType();
				switch(dataType) {
					case INTEGER:
						// For integer values, "x <= t" is equivalent to "x <= floor(t)".
						threshold = Math.floor(threshold);
						break;
					default:
						break;
				}

				leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, threshold);
				rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, threshold);
			}
		} else if (split instanceof CategoricalSplit) {
			CategoricalSplit categoricalSplit = (CategoricalSplit) split;

			double[] leftCategories = categoricalSplit.leftCategories();
			double[] rightCategories = categoricalSplit.rightCategories();

			if (feature instanceof BinaryFeature) {
				BinaryFeature binaryFeature = (BinaryFeature) feature;

				SimplePredicate.Operator leftOperator;
				SimplePredicate.Operator rightOperator;

				if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
					leftOperator = SimplePredicate.Operator.EQUAL;
					rightOperator = SimplePredicate.Operator.NOT_EQUAL;
				} else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
					leftOperator = SimplePredicate.Operator.NOT_EQUAL;
					rightOperator = SimplePredicate.Operator.EQUAL;
				} else {
					throw new IllegalArgumentException("Invalid split categories " + Arrays.toString(leftCategories) + " (left) and " + Arrays.toString(rightCategories) + " (right) for a binary feature");
				}

				Object value = binaryFeature.getValue();

				leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
				rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
			} else if (feature instanceof CategoricalFeature) {
				CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

				FieldName name = categoricalFeature.getName();

				List<?> values = categoricalFeature.getValues();

				// Every category of the feature must be assigned to exactly one side of the split.
				int splitCategoryCount = leftCategories.length + rightCategories.length;
				if (values.size() != splitCategoryCount) {
					throw new IllegalArgumentException("Expected " + values.size() + " categories for feature " + name.getValue() + ", got " + splitCategoryCount + " split categories");
				}

				java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

				List<Object> leftValues = selectValues(values, leftCategories, valueFilter);
				List<Object> rightValues = selectValues(values, rightCategories, valueFilter);

				leftCategoryManager = categoryManager.fork(name, leftValues);
				rightCategoryManager = categoryManager.fork(name, rightValues);

				leftPredicate = predicateManager.createPredicate(categoricalFeature, leftValues);
				rightPredicate = predicateManager.createPredicate(categoricalFeature, rightValues);
			} else {
				String featureType = (feature != null ? feature.getClass().getName() : "null");

				throw new IllegalArgumentException("Expected a binary or categorical feature for a categorical split, got " + featureType);
			}
		} else {
			throw new IllegalArgumentException("Expected a continuous or categorical split, got " + split.getClass().getName());
		}

		Node leftChild = encodeNode(leftPredicate, scoreEncoder, internalNode.leftChild(), predicateManager, leftCategoryManager, schema);
		Node rightChild = encodeNode(rightPredicate, scoreEncoder, internalNode.rightChild(), predicateManager, rightCategoryManager, schema);

		Node result = new BranchNode(null, predicate).addNodes(leftChild, rightChild);

		return result;
	} else {
		String nodeType = (sparkNode != null ? sparkNode.getClass().getName() : "null");

		throw new IllegalArgumentException("Expected a leaf or internal node, got " + nodeType);
	}
}
Aggregations