Use of org.dmg.pmml.Predicate in project jpmml-r by jpmml.
Class IForestConverter, method encodeNode.
/**
 * Recursively populates {@code node} and its descendants from the parallel, per-node lists.
 *
 * <p>The node id is set to {@code index + 1}. A status of {@code -3} denotes an interior node:
 * two children are created, the left one guarded by {@code feature < splitValue} and the right
 * one by {@code feature >= splitValue}, and each is encoded at {@code depth + 1}. A status of
 * {@code -1} denotes a terminal node, whose score is {@code depth + avgPathLength(size)}.
 *
 * @param node the PMML node to populate
 * @param index zero-based position of this node in the lists
 * @param depth depth of this node; children are encoded with {@code depth + 1}
 * @param nodeStatus per-node status codes ({@code -3} interior, {@code -1} terminal)
 * @param nodeSize per-node sizes, passed to {@code avgPathLength} for terminal nodes
 * @param leftDaughter per-node one-based position of the left child
 * @param rightDaughter per-node one-based position of the right child
 * @param splitAtt per-node one-based position of the split feature in {@code schema}
 * @param splitValue per-node split thresholds
 * @param schema the schema supplying the features
 * @throws IllegalArgumentException if the node status is neither {@code -3} nor {@code -1}
 */
private void encodeNode(Node node, int index, int depth, List<Integer> nodeStatus, List<Integer> nodeSize, List<Integer> leftDaughter, List<Integer> rightDaughter, List<Integer> splitAtt, List<Double> splitValue, Schema schema) {
    int status = nodeStatus.get(index);
    int size = nodeSize.get(index);

    node.setId(String.valueOf(index + 1));

    // Interior node
    if (status == -3) {
        int att = splitAtt.get(index);

        // Split attributes are one-based, schema features are zero-based
        ContinuousFeature feature = (ContinuousFeature) schema.getFeature(att - 1);

        String value = ValueUtil.formatValue(splitValue.get(index));

        Predicate leftPredicate = createSimplePredicate(feature, SimplePredicate.Operator.LESS_THAN, value);

        Node leftChild = new Node().setPredicate(leftPredicate);

        // Daughter references are one-based, list positions are zero-based
        int leftIndex = (leftDaughter.get(index) - 1);

        encodeNode(leftChild, leftIndex, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);

        Predicate rightPredicate = createSimplePredicate(feature, SimplePredicate.Operator.GREATER_OR_EQUAL, value);

        Node rightChild = new Node().setPredicate(rightPredicate);

        int rightIndex = (rightDaughter.get(index) - 1);

        encodeNode(rightChild, rightIndex, depth + 1, nodeStatus, nodeSize, leftDaughter, rightDaughter, splitAtt, splitValue, schema);

        node.addNodes(leftChild, rightChild);
    } else if (status == -1) {
        // Terminal node
        node.setScore(ValueUtil.formatValue(depth + avgPathLength(size)));
    } else {
        throw new IllegalArgumentException("Unexpected status " + status + " for node " + (index + 1) + "; expected -3 (interior) or -1 (terminal)");
    }
}
Use of org.dmg.pmml.Predicate in project jpmml-r by jpmml.
Class RandomForestConverter, method encodeNode.
/**
 * Recursively populates {@code node} and its descendants from the parallel, per-node lists.
 *
 * <p>A node whose {@code bestvar} entry is {@code 0} is a leaf: its score is the encoded
 * {@code nodepred} entry and no children are added. Otherwise the split feature is looked up
 * at {@code bestvar - 1} and a pair of predicates is built:
 * <ul>
 *   <li>boolean feature: equality against {@code getValue(0)} (left) and {@code getValue(1)}
 *       (right); the split value must be exactly {@code 0.5};</li>
 *   <li>categorical feature: set predicates over the values chosen by {@code selectValues};</li>
 *   <li>any other feature: {@code <=} (left) and {@code >} (right) on its continuous form.</li>
 * </ul>
 * A child is only created when its daughter entry is non-zero; the child id is that one-based
 * daughter value.
 *
 * @param node the PMML node to populate
 * @param i zero-based position of this node in the lists
 * @param scoreEncoder converts a leaf prediction to a PMML score
 * @param leftDaughter per-node one-based position of the left child, {@code 0} if absent
 * @param rightDaughter per-node one-based position of the right child, {@code 0} if absent
 * @param bestvar per-node one-based position of the split feature, {@code 0} for leaves
 * @param xbestsplit per-node split values
 * @param nodepred per-node predictions
 * @param schema the schema supplying the features
 * @throws IllegalArgumentException if a boolean feature is split at a value other than 0.5
 */
private <P extends Number> void encodeNode(Node node, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, Schema schema) {
    int var = ValueUtil.asInt(bestvar.get(i));

    // Leaf node: score it and stop recursing
    if (var == 0) {
        P prediction = nodepred.get(i);

        node.setScore(scoreEncoder.encode(prediction));

        return;
    }

    Predicate leftPredicate;
    Predicate rightPredicate;

    Feature feature = schema.getFeature(var - 1);

    Double split = xbestsplit.get(i);

    if (feature instanceof BooleanFeature) {
        BooleanFeature booleanFeature = (BooleanFeature) feature;

        if (split != 0.5d) {
            throw new IllegalArgumentException("Expected split value 0.5 for a boolean feature at node " + (i + 1) + ", got " + split);
        }

        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
    } else if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

        List<String> values = categoricalFeature.getValues();

        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, split, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, split, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        String value = ValueUtil.formatValue(split);

        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }

    int left = ValueUtil.asInt(leftDaughter.get(i));
    if (left != 0) {
        Node leftChild = new Node().setId(String.valueOf(left)).setPredicate(leftPredicate);

        encodeNode(leftChild, left - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);

        node.addNodes(leftChild);
    }

    int right = ValueUtil.asInt(rightDaughter.get(i));
    if (right != 0) {
        Node rightChild = new Node().setId(String.valueOf(right)).setPredicate(rightPredicate);

        encodeNode(rightChild, right - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);

        node.addNodes(rightChild);
    }
}
Use of org.dmg.pmml.Predicate in project jpmml-r by jpmml.
Class RangerConverter, method encodeNode.
/**
 * Recursively populates {@code node} and its descendants from the per-node vectors.
 *
 * <p>A node whose left and right child ids are both {@code 0} is treated as a leaf and is
 * handed to {@code scoreEncoder} together with its split value and, when
 * {@code terminalClassCounts} is non-null, its class-count vector. Any other node gets two
 * children. For a categorical feature the value list is cut at {@code floor(splitValue)}:
 * the values before the cut go left, the remainder go right. For any other feature the
 * children are guarded by {@code <=} (left) and {@code >} (right) on its continuous form.
 *
 * @param node the PMML node to populate
 * @param index position of this node in the vectors; child ids are used as positions directly
 * @param scoreEncoder scores leaf nodes
 * @param leftChildIDs per-node left child id
 * @param rightChildIDs per-node right child id
 * @param splitVarIDs per-node one-based position of the split feature in {@code schema}
 * @param splitValues per-node split values
 * @param terminalClassCounts per-node class counts, may be {@code null}
 * @param schema the schema supplying the features
 */
private void encodeNode(Node node, int index, ScoreEncoder scoreEncoder, RNumberVector<?> leftChildIDs, RNumberVector<?> rightChildIDs, RNumberVector<?> splitVarIDs, RNumberVector<?> splitValues, RGenericVector terminalClassCounts, Schema schema) {
    int leftChildIndex = ValueUtil.asInt(leftChildIDs.getValue(index));
    int rightChildIndex = ValueUtil.asInt(rightChildIDs.getValue(index));

    Number threshold = splitValues.getValue(index);

    RNumberVector<?> classCounts = null;
    if (terminalClassCounts != null) {
        classCounts = (RNumberVector<?>) terminalClassCounts.getValue(index);
    }

    boolean hasChild = (leftChildIndex != 0 || rightChildIndex != 0);
    if (!hasChild) {
        scoreEncoder.encode(node, threshold, classCounts);

        return;
    }

    int featureId = ValueUtil.asInt(splitVarIDs.getValue(index));

    // Split variable ids are one-based, schema features are zero-based
    Feature splitFeature = schema.getFeature(featureId - 1);

    Predicate lowerPredicate;
    Predicate upperPredicate;

    if (splitFeature instanceof CategoricalFeature) {
        CategoricalFeature categorical = (CategoricalFeature) splitFeature;

        int cut = ValueUtil.asInt(Math.floor(threshold.doubleValue()));

        List<String> levels = categorical.getValues();

        lowerPredicate = createSimpleSetPredicate(categorical, levels.subList(0, cut));
        upperPredicate = createSimpleSetPredicate(categorical, levels.subList(cut, levels.size()));
    } else {
        ContinuousFeature continuous = splitFeature.toContinuousFeature();

        String formattedThreshold = ValueUtil.formatValue(threshold);

        lowerPredicate = createSimplePredicate(continuous, SimplePredicate.Operator.LESS_OR_EQUAL, formattedThreshold);
        upperPredicate = createSimplePredicate(continuous, SimplePredicate.Operator.GREATER_THAN, formattedThreshold);
    }

    Node lowerChild = new Node();
    lowerChild.setPredicate(lowerPredicate);
    encodeNode(lowerChild, leftChildIndex, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, schema);

    Node upperChild = new Node();
    upperChild.setPredicate(upperPredicate);
    encodeNode(upperChild, rightChildIndex, scoreEncoder, leftChildIDs, rightChildIDs, splitVarIDs, splitValues, terminalClassCounts, schema);

    node.addNodes(lowerChild, upperChild);
}
Use of org.dmg.pmml.Predicate in project jpmml-sparkml by jpmml.
Class TreeModelUtil, method encodeNode.
/**
 * Recursively converts a Spark ML tree node into a PMML {@code Node}.
 *
 * <p>For an {@code InternalNode} a left/right predicate pair is derived from its split:
 * <ul>
 *   <li>continuous split on a boolean feature: equality against {@code getValue(0)} (left)
 *       and {@code getValue(1)} (right); the threshold must be exactly {@code 0.5};</li>
 *   <li>continuous split on any other feature: {@code <=} (left) and {@code >} (right);</li>
 *   <li>categorical split on a binary feature: {@code EQUAL}/{@code NOT_EQUAL} against the
 *       feature value, oriented by which side matches the {@code TRUE}/{@code FALSE}
 *       category arrays;</li>
 *   <li>categorical split on a categorical feature: set predicates over the left and right
 *       categories, restricted to values still allowed by {@code parentFieldValues}; the
 *       restricted sets are passed down to the respective subtrees.</li>
 * </ul>
 * For a {@code LeafNode} the score is the formatted prediction (regression) or the label
 * value at the predicted index, plus record count and per-class score distributions
 * (classification).
 *
 * @param node the Spark ML node to convert
 * @param predicateManager factory for the PMML predicates
 * @param parentFieldValues per-field categorical values still reachable at this node; not modified
 * @param miningFunction selects regression or classification leaf encoding
 * @param schema the schema supplying features and label
 * @return a new PMML node; its predicate is assigned by the caller
 * @throws IllegalArgumentException if the node, split or feature type is not supported, or
 *     the split is inconsistent with the feature
 * @throws UnsupportedOperationException if the mining function is neither regression nor
 *     classification
 */
public static Node encodeNode(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentFieldValues, MiningFunction miningFunction, Schema schema) {

    if (node instanceof InternalNode) {
        InternalNode internalNode = (InternalNode) node;

        // Shared with the parent unless a categorical split narrows them below
        Map<FieldName, Set<String>> leftFieldValues = parentFieldValues;
        Map<FieldName, Set<String>> rightFieldValues = parentFieldValues;

        Predicate leftPredicate;
        Predicate rightPredicate;

        Split split = internalNode.split();

        Feature feature = schema.getFeature(split.featureIndex());

        if (split instanceof ContinuousSplit) {
            ContinuousSplit continuousSplit = (ContinuousSplit) split;

            double threshold = continuousSplit.threshold();

            if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;

                if (threshold != 0.5d) {
                    throw new IllegalArgumentException("Expected threshold 0.5 for a boolean feature, got " + threshold);
                }

                leftPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
                rightPredicate = predicateManager.createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();

                String value = ValueUtil.formatValue(threshold);

                leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
                rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
            }
        } else if (split instanceof CategoricalSplit) {
            CategoricalSplit categoricalSplit = (CategoricalSplit) split;

            double[] leftCategories = categoricalSplit.leftCategories();
            double[] rightCategories = categoricalSplit.rightCategories();

            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;

                SimplePredicate.Operator leftOperator;
                SimplePredicate.Operator rightOperator;

                if (Arrays.equals(TRUE, leftCategories) && Arrays.equals(FALSE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.EQUAL;
                    rightOperator = SimplePredicate.Operator.NOT_EQUAL;
                } else if (Arrays.equals(FALSE, leftCategories) && Arrays.equals(TRUE, rightCategories)) {
                    leftOperator = SimplePredicate.Operator.NOT_EQUAL;
                    rightOperator = SimplePredicate.Operator.EQUAL;
                } else {
                    throw new IllegalArgumentException("Unexpected categories for a binary feature: left=" + Arrays.toString(leftCategories) + ", right=" + Arrays.toString(rightCategories));
                }

                String value = ValueUtil.formatValue(binaryFeature.getValue());

                leftPredicate = predicateManager.createSimplePredicate(binaryFeature, leftOperator, value);
                rightPredicate = predicateManager.createSimplePredicate(binaryFeature, rightOperator, value);
            } else if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

                FieldName name = categoricalFeature.getName();

                List<String> values = categoricalFeature.getValues();
                if (values.size() != (leftCategories.length + rightCategories.length)) {
                    throw new IllegalArgumentException("Expected " + values.size() + " categories for field " + name + ", got " + (leftCategories.length + rightCategories.length));
                }

                final Set<String> parentValues = parentFieldValues.get(name);

                // Accepts every value when no ancestor has split on this field yet
                com.google.common.base.Predicate<String> valueFilter = new com.google.common.base.Predicate<String>() {

                    @Override
                    public boolean apply(String value) {

                        if (parentValues != null) {
                            return parentValues.contains(value);
                        }

                        return true;
                    }
                };

                List<String> leftValues = selectValues(values, leftCategories, valueFilter);
                List<String> rightValues = selectValues(values, rightCategories, valueFilter);

                // Copy before narrowing, so that the parent's map is left untouched
                leftFieldValues = new HashMap<>(parentFieldValues);
                leftFieldValues.put(name, new HashSet<>(leftValues));

                rightFieldValues = new HashMap<>(parentFieldValues);
                rightFieldValues.put(name, new HashSet<>(rightValues));

                leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftValues);
                rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightValues);
            } else {
                throw new IllegalArgumentException("Unsupported feature for a categorical split: " + feature);
            }
        } else {
            throw new IllegalArgumentException("Unsupported split type " + split.getClass().getName());
        }

        Node result = new Node();

        Node leftChild = encodeNode(internalNode.leftChild(), predicateManager, leftFieldValues, miningFunction, schema).setPredicate(leftPredicate);
        Node rightChild = encodeNode(internalNode.rightChild(), predicateManager, rightFieldValues, miningFunction, schema).setPredicate(rightPredicate);

        result.addNodes(leftChild, rightChild);

        return result;
    } else if (node instanceof LeafNode) {
        Node result = new Node();

        switch (miningFunction) {
            case REGRESSION:
                {
                    String score = ValueUtil.formatValue(node.prediction());

                    result.setScore(score);
                }
                break;
            case CLASSIFICATION:
                {
                    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

                    // The prediction is the position of the predicted class among the label values
                    int index = ValueUtil.asInt(node.prediction());

                    result.setScore(categoricalLabel.getValue(index));

                    ImpurityCalculator impurityCalculator = node.impurityStats();

                    result.setRecordCount((double) impurityCalculator.count());

                    double[] stats = impurityCalculator.stats();
                    for (int i = 0; i < stats.length; i++) {
                        ScoreDistribution scoreDistribution = new ScoreDistribution(categoricalLabel.getValue(i), stats[i]);

                        result.addScoreDistributions(scoreDistribution);
                    }
                }
                break;
            default:
                throw new UnsupportedOperationException("Unsupported mining function " + miningFunction);
        }

        return result;
    } else {
        throw new IllegalArgumentException("Unsupported node type " + (node != null ? node.getClass().getName() : "null"));
    }
}
Use of org.dmg.pmml.Predicate in project jpmml-r by jpmml.
Class GBMConverter, method encodeNode.
/**
 * Recursively populates {@code node} and its descendants from a single tree.
 *
 * <p>The tree is read as a list of parallel vectors: element 0 holds the split variable,
 * element 1 the split code or value, elements 2, 3 and 4 the left, right and missing child
 * positions, and element 7 the prediction. A split variable of {@code -1} marks a leaf,
 * which is scored with the formatted prediction. Otherwise up to three children are added,
 * in the order missing, left, right; a child position of {@code -1} means the child is
 * absent. Child ids are the child position plus one.
 *
 * <p>For a categorical feature the split value is used as a position into {@code c_splits},
 * and the left/right value sets are chosen by {@code selectValues}. For any other feature
 * the children are guarded by {@code <} (left) and {@code >=} (right) on its continuous form.
 *
 * @param node the PMML node to populate
 * @param i zero-based position of this node in the tree vectors
 * @param tree the tree, as a list of parallel per-node vectors
 * @param c_splits the categorical split definitions
 * @param schema the schema supplying the features
 */
private void encodeNode(Node node, int i, RGenericVector tree, RGenericVector c_splits, Schema schema) {
    RIntegerVector splitVar = (RIntegerVector) tree.getValue(0);
    RDoubleVector splitCodePred = (RDoubleVector) tree.getValue(1);
    RIntegerVector leftNode = (RIntegerVector) tree.getValue(2);
    RIntegerVector rightNode = (RIntegerVector) tree.getValue(3);
    RIntegerVector missingNode = (RIntegerVector) tree.getValue(4);
    RDoubleVector prediction = (RDoubleVector) tree.getValue(7);

    int featureIndex = splitVar.getValue(i);

    // Leaf node: score it and stop recursing
    if (featureIndex == -1) {
        Double leafValue = prediction.getValue(i);
        node.setScore(ValueUtil.formatValue(leafValue));
        return;
    }

    Feature splitFeature = schema.getFeature(featureIndex);

    Predicate whenMissing = createSimplePredicate(splitFeature, SimplePredicate.Operator.IS_MISSING, null);
    Predicate whenLeft;
    Predicate whenRight;

    Double splitCode = splitCodePred.getValue(i);

    if (splitFeature instanceof CategoricalFeature) {
        CategoricalFeature categorical = (CategoricalFeature) splitFeature;
        List<String> levels = categorical.getValues();

        // For categorical features the split code points at an entry of c_splits
        int splitPosition = ValueUtil.asInt(splitCode);
        RIntegerVector levelDirections = (RIntegerVector) c_splits.getValue(splitPosition);
        List<Integer> directions = levelDirections.getValues();

        whenLeft = createSimpleSetPredicate(categorical, selectValues(levels, directions, true));
        whenRight = createSimpleSetPredicate(categorical, selectValues(levels, directions, false));
    } else {
        ContinuousFeature continuous = splitFeature.toContinuousFeature();
        String threshold = ValueUtil.formatValue(splitCode);

        whenLeft = createSimplePredicate(continuous, SimplePredicate.Operator.LESS_THAN, threshold);
        whenRight = createSimplePredicate(continuous, SimplePredicate.Operator.GREATER_OR_EQUAL, threshold);
    }

    int missingIndex = missingNode.getValue(i);
    if (missingIndex != -1) {
        Node child = new Node();
        child.setId(String.valueOf(missingIndex + 1));
        child.setPredicate(whenMissing);
        encodeNode(child, missingIndex, tree, c_splits, schema);
        node.addNodes(child);
    }

    int leftIndex = leftNode.getValue(i);
    if (leftIndex != -1) {
        Node child = new Node();
        child.setId(String.valueOf(leftIndex + 1));
        child.setPredicate(whenLeft);
        encodeNode(child, leftIndex, tree, c_splits, schema);
        node.addNodes(child);
    }

    int rightIndex = rightNode.getValue(i);
    if (rightIndex != -1) {
        Node child = new Node();
        child.setId(String.valueOf(rightIndex + 1));
        child.setPredicate(whenRight);
        encodeNode(child, rightIndex, tree, c_splits, schema);
        node.addNodes(child);
    }
}
Aggregations