use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
the class GBMConverter method encodeNode.
/**
 * Recursively fills {@code node} from row {@code i} of a column-oriented tree table.
 *
 * <p>The columns read from {@code tree} are: 0 (split variable index), 1 (split code or leaf
 * value column used as the split), 2 (left child row), 3 (right child row), 4 (missing-value
 * child row) and 7 (prediction). A value of {@code -1} in the split variable column marks a
 * leaf; a value of {@code -1} in a child column means that child is absent.
 *
 * @param node the PMML node to populate; it receives either a score or child nodes
 * @param i the row of the tree table that describes {@code node}
 * @param tree the tree table, addressed by column position
 * @param c_splits the categorical split definitions, addressed by the split code of a row
 * @param schema the schema that resolves split variable indexes to features
 */
private void encodeNode(Node node, int i, RGenericVector tree, RGenericVector c_splits, Schema schema) {
    RIntegerVector splitVars = (RIntegerVector) tree.getValue(0);
    RDoubleVector splitCodes = (RDoubleVector) tree.getValue(1);
    RIntegerVector leftRows = (RIntegerVector) tree.getValue(2);
    RIntegerVector rightRows = (RIntegerVector) tree.getValue(3);
    RIntegerVector missingRows = (RIntegerVector) tree.getValue(4);
    RDoubleVector predictions = (RDoubleVector) tree.getValue(7);

    Integer featureIndex = splitVars.getValue(i);

    // Leaf row: record the prediction as the score and stop descending.
    if (featureIndex == -1) {
        Double leafValue = predictions.getValue(i);
        node.setScore(ValueUtil.formatValue(leafValue));
        return;
    }

    Feature feature = schema.getFeature(featureIndex);
    Predicate whenMissing = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);
    Double splitCode = splitCodes.getValue(i);

    Predicate goLeft;
    Predicate goRight;
    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categorical = (CategoricalFeature) feature;
        List<String> categories = categorical.getValues();

        // For categorical features the split code is an index into c_splits.
        int splitIndex = ValueUtil.asInt(splitCode);
        RIntegerVector categorySplit = (RIntegerVector) c_splits.getValue(splitIndex);
        List<Integer> directions = categorySplit.getValues();

        goLeft = createSimpleSetPredicate(categorical, selectValues(categories, directions, true));
        goRight = createSimpleSetPredicate(categorical, selectValues(categories, directions, false));
    } else {
        ContinuousFeature continuous = feature.toContinuousFeature();

        // For continuous features the split code is the threshold itself.
        String threshold = ValueUtil.formatValue(splitCode);

        goLeft = createSimplePredicate(continuous, SimplePredicate.Operator.LESS_THAN, threshold);
        goRight = createSimplePredicate(continuous, SimplePredicate.Operator.GREATER_OR_EQUAL, threshold);
    }

    // Children are attached in the order: missing, left, right.
    RIntegerVector[] childRows = {missingRows, leftRows, rightRows};
    Predicate[] childPredicates = {whenMissing, goLeft, goRight};
    for (int k = 0; k < childRows.length; k++) {
        Integer childRow = childRows[k].getValue(i);
        if (childRow == -1) {
            continue;
        }

        // Node identifiers are the one-based form of the zero-based row index.
        Node child = new Node().setId(String.valueOf(childRow + 1)).setPredicate(childPredicates[k]);
        encodeNode(child, childRow, tree, c_splits, schema);
        node.addNodes(child);
    }
}
use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
the class BinaryTreeConverter method encodeNode.
/**
 * Recursively fills {@code node} from a nested binary tree structure.
 *
 * <p>The node identifier is taken from the {@code "nodeID"} element. A terminal element is
 * handed to {@code encodeScore}; a non-terminal element yields exactly two children, built from
 * the {@code "left"} and {@code "right"} elements and guarded by predicates derived from the
 * primary split ({@code "psplit"}).
 *
 * @param node the PMML node to populate
 * @param tree the tree element describing {@code node}
 * @param schema the schema that resolves feature indexes to features
 * @throws IllegalArgumentException if the {@code "ssplits"} element is not empty, or if the
 *         split variable name has no entry in {@code featureIndexes}
 */
private void encodeNode(Node node, RGenericVector tree, Schema schema) {
    RIntegerVector nodeId = (RIntegerVector) tree.getValue("nodeID");
    RBooleanVector terminal = (RBooleanVector) tree.getValue("terminal");
    RGenericVector psplit = (RGenericVector) tree.getValue("psplit");
    RGenericVector ssplits = (RGenericVector) tree.getValue("ssplits");
    RDoubleVector prediction = (RDoubleVector) tree.getValue("prediction");
    RGenericVector left = (RGenericVector) tree.getValue("left");
    RGenericVector right = (RGenericVector) tree.getValue("right");

    node.setId(String.valueOf(nodeId.asScalar()));

    if ((Boolean.TRUE).equals(terminal.asScalar())) {
        // The return value is intentionally not kept: nothing reads it before returning.
        encodeScore(node, prediction, schema);
        return;
    }

    RNumberVector<?> splitpoint = (RNumberVector<?>) psplit.getValue("splitpoint");
    RStringVector variableName = (RStringVector) psplit.getValue("variableName");

    if (ssplits.size() > 0) {
        throw new IllegalArgumentException("Expected an empty 'ssplits' element, got " + ssplits.size() + " element(s)");
    }

    FieldName name = FieldName.create(variableName.asScalar());
    Integer index = this.featureIndexes.get(name);
    if (index == null) {
        throw new IllegalArgumentException("No feature index is registered for split variable " + name);
    }

    Feature feature = schema.getFeature(index);

    Predicate leftPredicate;
    Predicate rightPredicate;
    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<String> values = categoricalFeature.getValues();

        // NOTE(review): assumes a categorical split point is stored as integers - confirm against the R object.
        @SuppressWarnings("unchecked")
        List<Integer> splitValues = (List<Integer>) splitpoint.getValues();

        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        String value = ValueUtil.formatValue((Double) splitpoint.asScalar());

        // Left branch is inclusive of the split point; right branch is strictly greater.
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }

    Node leftChild = new Node().setPredicate(leftPredicate);
    encodeNode(leftChild, left, schema);

    Node rightChild = new Node().setPredicate(rightPredicate);
    encodeNode(rightChild, right, schema);

    node.addNodes(leftChild, rightChild);
}
use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.
the class TreeModelUtil method encodeTreeModel.
/**
 * Builds a binary-split PMML {@link TreeModel} from a Spark ML decision tree node.
 *
 * <p>The encoded root node is given a {@code True} predicate. When
 * {@code TreeModelOptions.COMPACT} parses as {@code true}, a {@code TreeModelCompactor} is
 * applied to the finished model before it is returned.
 *
 * @param node the root of the Spark ML tree
 * @param predicateManager passed through to the node encoder
 * @param miningFunction the mining function of the resulting model
 * @param schema the schema supplying the label for the mining schema
 * @return the encoded tree model
 */
public static TreeModel encodeTreeModel(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, MiningFunction miningFunction, Schema schema) {
    Node encodedRoot = encodeNode(node, predicateManager, Collections.<FieldName, Set<String>>emptyMap(), miningFunction, schema);
    Node root = encodedRoot.setPredicate(new True());

    TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

    // parseBoolean treats null as false, so no separate null check is needed.
    if (Boolean.parseBoolean(TreeModelOptions.COMPACT)) {
        Visitor compactor = new TreeModelCompactor();
        compactor.applyTo(treeModel);
    }

    return treeModel;
}
use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.
the class TreeModelCompactor method handleNodePop.
/**
 * Post-visit hook that folds a {@code True}-predicate node into its parent.
 *
 * <p>Nodes whose predicate is not {@code True}, and nodes without a parent, are left alone.
 * For regression, the node's score is moved to the parent, the node is removed from the
 * parent's children and any children of the node are appended to the parent. For
 * classification, the node is only removed (and its children appended to the parent) when it
 * has children; a childless node is kept as is.
 *
 * @param node the node being popped
 * @throws IllegalArgumentException if the parent already has a score, if the node is not found
 *         among the parent's children, or if the mining function is neither regression nor
 *         classification
 */
private void handleNodePop(Node node) {
    String score = node.getScore();
    Predicate predicate = node.getPredicate();

    if (!(predicate instanceof True)) {
        return;
    }

    Node parent = getParentNode();
    if (parent == null) {
        return;
    }

    if (parent.getScore() != null) {
        throw new IllegalArgumentException();
    }

    if ((MiningFunction.REGRESSION).equals(this.miningFunction)) {
        // The score is moved up before the node is detached.
        parent.setScore(score);

        List<Node> siblings = parent.getNodes();
        if (!siblings.remove(node)) {
            throw new IllegalArgumentException();
        }

        if (node.hasNodes()) {
            siblings.addAll(node.getNodes());
        }

        return;
    }

    if (!(MiningFunction.CLASSIFICATION).equals(this.miningFunction)) {
        throw new IllegalArgumentException();
    }

    // Classification: a childless node stays in place.
    if (!node.hasNodes()) {
        return;
    }

    List<Node> siblings = parent.getNodes();
    if (!siblings.remove(node)) {
        throw new IllegalArgumentException();
    }

    siblings.addAll(node.getNodes());
}
use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.
the class TreeModelCompactor method handleNodePush.
/**
 * Pre-visit hook that validates a node and rewrites the predicate of its second child.
 *
 * <p>A node must not have an identifier. A node with children must have exactly two children
 * and no score; a node without children must have a score.
 *
 * <p>For a node with children, the pair of child predicates is inspected and, depending on
 * their types, the second child's predicate is replaced (and the children are possibly
 * swapped first):
 * <ul>
 *   <li>simple / simple: both must reference the same field. For an EQUAL / EQUAL pair the
 *       second predicate becomes {@code True} only if {@code isCategoricalField} accepts the
 *       field; otherwise both predicates are left untouched. For any other pair the values must
 *       match and the operators must be NOT_EQUAL / EQUAL (children are swapped),
 *       EQUAL / NOT_EQUAL, or LESS_OR_EQUAL / GREATER_THAN; the second predicate then becomes
 *       {@code True}.</li>
 *   <li>simple / set, set / simple, set / set: both must reference the same field, the simple
 *       predicate must use EQUAL and the set predicate(s) must use IS_IN. The second child's
 *       predicate is replaced with the result of {@code addCategoricalField}; in the
 *       set / simple case the children are swapped first.</li>
 * </ul>
 *
 * @param node the node being pushed
 * @throws IllegalArgumentException if any of the structural or predicate expectations above
 *         is violated
 */
private void handleNodePush(Node node) {
String id = node.getId();
String score = node.getScore();
// Nodes are expected to arrive without an identifier.
if (id != null) {
throw new IllegalArgumentException();
}
if (node.hasNodes()) {
List<Node> children = node.getNodes();
// A branch node must be strictly binary and must not carry a score of its own.
if (children.size() != 2 || score != null) {
throw new IllegalArgumentException();
}
Node firstChild = children.get(0);
Node secondChild = children.get(1);
Predicate firstPredicate = firstChild.getPredicate();
Predicate secondPredicate = secondChild.getPredicate();
// The label lets the EQUAL / EQUAL case bail out of the whole if-else chain without touching either child.
predicate: if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimplePredicate) {
SimplePredicate firstSimplePredicate = (SimplePredicate) firstPredicate;
SimplePredicate secondSimplePredicate = (SimplePredicate) secondPredicate;
SimplePredicate.Operator firstOperator = firstSimplePredicate.getOperator();
SimplePredicate.Operator secondOperator = secondSimplePredicate.getOperator();
// Both children must split on the same field.
if (!(firstSimplePredicate.getField()).equals(secondSimplePredicate.getField())) {
throw new IllegalArgumentException();
}
if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
// Values are not compared in this case; the rewrite only happens for fields that isCategoricalField accepts.
if (!isCategoricalField(firstSimplePredicate.getField())) {
break predicate;
}
secondChild.setPredicate(new True());
} else {
// Every remaining operator pair must test against the same value.
if (!(firstSimplePredicate.getValue()).equals(secondSimplePredicate.getValue())) {
throw new IllegalArgumentException();
}
if ((SimplePredicate.Operator.NOT_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
// NOTE(review): swapChildren presumably exchanges the two list positions - confirm. The locals are re-read so that secondChild refers to the node now at index 1.
swapChildren(children);
firstChild = children.get(0);
secondChild = children.get(1);
} else if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.NOT_EQUAL).equals(secondOperator)) {
// Accepted pair; no reordering is performed.
} else if ((SimplePredicate.Operator.LESS_OR_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.GREATER_THAN).equals(secondOperator)) {
// Accepted pair; no reordering is performed.
} else {
throw new IllegalArgumentException();
}
// The child at index 1 becomes the catch-all branch.
secondChild.setPredicate(new True());
}
} else if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimpleSetPredicate) {
SimplePredicate simplePredicate = (SimplePredicate) firstPredicate;
SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) secondPredicate;
// Same field, EQUAL on the simple side and IS_IN on the set side are all required.
if (!(simplePredicate.getField()).equals(simpleSetPredicate.getField()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator())) {
throw new IllegalArgumentException();
}
// NOTE(review): the semantics of addCategoricalField are not visible here; its result replaces the second child's predicate.
secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
} else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimplePredicate) {
SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) firstPredicate;
SimplePredicate simplePredicate = (SimplePredicate) secondPredicate;
// Mirror image of the previous case: same field, IS_IN on the set side and EQUAL on the simple side.
if (!(simpleSetPredicate.getField()).equals(simplePredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator())) {
throw new IllegalArgumentException();
}
// Swap, then re-read the locals so that secondChild refers to the node now at index 1.
swapChildren(children);
firstChild = children.get(0);
secondChild = children.get(1);
secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
} else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimpleSetPredicate) {
SimpleSetPredicate firstSimpleSetPredicate = (SimpleSetPredicate) firstPredicate;
SimpleSetPredicate secondSimpleSetPredicate = (SimpleSetPredicate) secondPredicate;
// Same field and IS_IN on both sides are required.
if (!(firstSimpleSetPredicate.getField()).equals(secondSimpleSetPredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(firstSimpleSetPredicate.getBooleanOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(secondSimpleSetPredicate.getBooleanOperator())) {
throw new IllegalArgumentException();
}
secondChild.setPredicate(addCategoricalField(secondSimpleSetPredicate));
} else {
// Any other combination of predicate types is rejected.
throw new IllegalArgumentException();
}
} else {
// A leaf node must carry a score.
if (score == null) {
throw new IllegalArgumentException();
}
}
}
Aggregations