Use of org.dmg.pmml.tree.Node in project jpmml-r by jpmml.
In class RandomForestConverter, method encodeTreeModel:
/**
 * Builds a binary-split PMML {@code TreeModel} from the parallel per-node lists of a
 * randomForest tree.
 *
 * <p>A root node with id {@code "1"} and an always-true predicate is created first.
 * It is then populated by {@code encodeNode}, which starts at list position 0.
 *
 * @param miningFunction the mining function of the resulting model
 * @param scoreEncoder passed through to {@code encodeNode} to encode node predictions
 * @param leftDaughter per-node left child references (passed through to {@code encodeNode})
 * @param rightDaughter per-node right child references (passed through to {@code encodeNode})
 * @param nodepred per-node predictions (passed through to {@code encodeNode})
 * @param bestvar per-node split variables (passed through to {@code encodeNode})
 * @param xbestsplit per-node split values (passed through to {@code encodeNode})
 * @param schema supplies the label used to build the mining schema
 * @return the encoded tree model
 */
private <P extends Number> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<P> nodepred, List<? extends Number> bestvar, List<Double> xbestsplit, Schema schema) {
    Node rootNode = new Node()
        .setId("1")
        .setPredicate(new True());

    // Recursively fills in the tree below the root, beginning at position 0 of the lists
    encodeNode(rootNode, 0, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, schema);

    return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), rootNode)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
}
Use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.
In class TreeModelUtil, method encodeTreeModel:
/**
 * Converts a Spark ML tree node into a PMML {@code TreeModel} with binary splits.
 *
 * <p>The encoded root node's predicate is forced to {@code True}. When the value of
 * {@code TreeModelOptions.COMPACT} parses as {@code true}, the model is post-processed
 * in place by a {@code TreeModelCompactor}.
 *
 * @param node the Spark ML tree node to encode
 * @param predicateManager passed through to {@code encodeNode}
 * @param miningFunction the mining function of the resulting model
 * @param schema supplies the label used to build the mining schema
 * @return the encoded (and possibly compacted) tree model
 */
public static TreeModel encodeTreeModel(org.apache.spark.ml.tree.Node node, PredicateManager predicateManager, MiningFunction miningFunction, Schema schema) {
    Node root = encodeNode(node, predicateManager, Collections.<FieldName, Set<String>>emptyMap(), miningFunction, schema)
        .setPredicate(new True());

    TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
        .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

    // Boolean.parseBoolean(null) yields false, so an unset value disables compaction.
    // NOTE(review): if TreeModelOptions.COMPACT holds an option *name* rather than its
    // value, this condition can never be true -- confirm against its declaration.
    if (Boolean.parseBoolean(TreeModelOptions.COMPACT)) {
        Visitor compactor = new TreeModelCompactor();
        compactor.applyTo(treeModel);
    }

    return treeModel;
}
Use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.
In class TreeModelCompactor, method handleNodePop:
/**
 * Merges a node whose predicate is {@code True} into its parent node.
 *
 * <p>Nodes with any other predicate, and the root node (no parent), are left alone.
 * The parent must not already carry a score.
 * <ul>
 *   <li>REGRESSION: the node's score is moved onto the parent. The node is removed from
 *       the parent's children, and its own children (if any) are appended to them.</li>
 *   <li>CLASSIFICATION: only a node that has children is merged. It is removed from the
 *       parent and its children are appended; a leaf is kept as-is.</li>
 * </ul>
 *
 * @param node the node being left by the visitor
 * @throws IllegalArgumentException if the parent already has a score, if the node
 *         cannot be found among the parent's children, or if the mining function is
 *         neither REGRESSION nor CLASSIFICATION
 */
private void handleNodePop(Node node) {
    if (!(node.getPredicate() instanceof True)) {
        return;
    }

    Node parent = getParentNode();
    if (parent == null) {
        // The root node has nothing to be merged into
        return;
    }

    if (parent.getScore() != null) {
        throw new IllegalArgumentException();
    }

    if ((MiningFunction.REGRESSION).equals(this.miningFunction)) {
        parent.setScore(node.getScore());
    } else if ((MiningFunction.CLASSIFICATION).equals(this.miningFunction)) {
        // A classification leaf keeps its own node; only internal nodes are flattened
        if (!node.hasNodes()) {
            return;
        }
    } else {
        throw new IllegalArgumentException();
    }

    List<Node> siblings = parent.getNodes();
    if (!siblings.remove(node)) {
        throw new IllegalArgumentException();
    }

    // Lift the grandchildren into the parent, after the remaining siblings
    if (node.hasNodes()) {
        siblings.addAll(node.getNodes());
    }
}
Use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.
In class TreeModelCompactor, method handleNodePush:
/**
 * Validates a tree node and normalizes the child predicates of a binary split.
 *
 * <p>Every node must be without an id. A leaf must carry a score. An internal node
 * must have exactly two children and no score of its own.
 *
 * <p>For an internal node, the pair of child predicates is inspected:
 * <ul>
 *   <li>SimplePredicate / SimplePredicate on the same field. For EQUAL/EQUAL, the second
 *       child becomes {@code True} only if the field is categorical according to
 *       {@code isCategoricalField}; otherwise both children are left untouched. All
 *       other pairs must test the same value. Supported orders are NOT_EQUAL/EQUAL
 *       (swapped first), EQUAL/NOT_EQUAL and LESS_OR_EQUAL/GREATER_THAN. In those
 *       cases the second child becomes {@code True}.</li>
 *   <li>SimplePredicate EQUAL / SimpleSetPredicate IS_IN on the same field: the
 *       second child's predicate is replaced via {@code addCategoricalField}.</li>
 *   <li>SimpleSetPredicate IS_IN / SimplePredicate EQUAL on the same field: the children
 *       are swapped, then the (new) second child's predicate is replaced via
 *       {@code addCategoricalField}.</li>
 *   <li>SimpleSetPredicate IS_IN / SimpleSetPredicate IS_IN on the same field: the
 *       second child's predicate is replaced via {@code addCategoricalField}.</li>
 * </ul>
 *
 * @param node the node being entered by the visitor
 * @throws IllegalArgumentException if the node or its child predicates do not match
 *         one of the supported patterns
 */
private void handleNodePush(Node node) {
String id = node.getId();
String score = node.getScore();
// Nodes are expected not to carry an id at this point
if (id != null) {
throw new IllegalArgumentException();
}
if (node.hasNodes()) {
List<Node> children = node.getNodes();
// Only binary splits are supported, and a split node must not carry its own score
if (children.size() != 2 || score != null) {
throw new IllegalArgumentException();
}
Node firstChild = children.get(0);
Node secondChild = children.get(1);
Predicate firstPredicate = firstChild.getPredicate();
Predicate secondPredicate = secondChild.getPredicate();
// Labeled block: "break predicate" exits it and leaves both children untouched
predicate: if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimplePredicate) {
SimplePredicate firstSimplePredicate = (SimplePredicate) firstPredicate;
SimplePredicate secondSimplePredicate = (SimplePredicate) secondPredicate;
SimplePredicate.Operator firstOperator = firstSimplePredicate.getOperator();
SimplePredicate.Operator secondOperator = secondSimplePredicate.getOperator();
if (!(firstSimplePredicate.getField()).equals(secondSimplePredicate.getField())) {
throw new IllegalArgumentException();
}
// Two EQUAL tests: collapse the second one only when the field is known to be categorical
if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
if (!isCategoricalField(firstSimplePredicate.getField())) {
break predicate;
}
secondChild.setPredicate(new True());
} else {
// Otherwise both predicates must test the same value with complementary operators
if (!(firstSimplePredicate.getValue()).equals(secondSimplePredicate.getValue())) {
throw new IllegalArgumentException();
}
// Normalize NOT_EQUAL/EQUAL into EQUAL/NOT_EQUAL order
if ((SimplePredicate.Operator.NOT_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
swapChildren(children);
// firstChild is refreshed here but not read again below
firstChild = children.get(0);
secondChild = children.get(1);
} else if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.NOT_EQUAL).equals(secondOperator)) {
// Ignored
} else if ((SimplePredicate.Operator.LESS_OR_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.GREATER_THAN).equals(secondOperator)) {
// Ignored
} else {
throw new IllegalArgumentException();
}
// NOTE(review): replacing the second test with True presumably relies on PMML
// evaluating sibling nodes in order (first match wins) -- confirm
secondChild.setPredicate(new True());
}
} else if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimpleSetPredicate) {
// x == a versus x in {...}: the set-valued child is already second
SimplePredicate simplePredicate = (SimplePredicate) firstPredicate;
SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) secondPredicate;
if (!(simplePredicate.getField()).equals(simpleSetPredicate.getField()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator())) {
throw new IllegalArgumentException();
}
// NOTE(review): addCategoricalField is defined elsewhere; given how isCategoricalField
// reads this.categoricalFields, it presumably returns a True predicate mapped to the
// set predicate's field -- confirm
secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
} else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimplePredicate) {
// x in {...} versus x == a: swap so that the set-valued child becomes second
SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) firstPredicate;
SimplePredicate simplePredicate = (SimplePredicate) secondPredicate;
if (!(simpleSetPredicate.getField()).equals(simplePredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator())) {
throw new IllegalArgumentException();
}
swapChildren(children);
// firstChild is refreshed here but not read again below
firstChild = children.get(0);
secondChild = children.get(1);
secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
} else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimpleSetPredicate) {
// x in {...} versus x in {...}: rewrite the second set predicate
SimpleSetPredicate firstSimpleSetPredicate = (SimpleSetPredicate) firstPredicate;
SimpleSetPredicate secondSimpleSetPredicate = (SimpleSetPredicate) secondPredicate;
if (!(firstSimpleSetPredicate.getField()).equals(secondSimpleSetPredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(firstSimpleSetPredicate.getBooleanOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(secondSimpleSetPredicate.getBooleanOperator())) {
throw new IllegalArgumentException();
}
secondChild.setPredicate(addCategoricalField(secondSimpleSetPredicate));
} else {
throw new IllegalArgumentException();
}
} else {
// A leaf must carry a score
if (score == null) {
throw new IllegalArgumentException();
}
}
}
Use of org.dmg.pmml.tree.Node in project jpmml-sparkml by jpmml.
In class TreeModelCompactor, method isCategoricalField:
/**
 * Reports whether an enclosing {@code Node} establishes {@code name} as a categorical field.
 *
 * <p>The visitor's parent chain is walked in the order returned by {@code getParents()}.
 * A field counts as categorical when an enclosing node has either of these predicates:
 * <ul>
 *   <li>a {@code SimpleSetPredicate} on that field, or</li>
 *   <li>a {@code True} predicate that {@code this.categoricalFields} maps to that field.</li>
 * </ul>
 * The walk stops, returning {@code false}, at the first parent that is not a {@code Node}.
 *
 * @param name the field to look up
 * @return {@code true} if a matching enclosing node is found before leaving the node chain
 */
private boolean isCategoricalField(FieldName name) {
    for (PMMLObject ancestor : getParents()) {
        if (!(ancestor instanceof Node)) {
            // Reached the end of the chain of enclosing Node elements
            return false;
        }

        Predicate ancestorPredicate = ((Node) ancestor).getPredicate();

        boolean matches;
        if (ancestorPredicate instanceof SimpleSetPredicate) {
            matches = (name).equals(((SimpleSetPredicate) ancestorPredicate).getField());
        } else if (ancestorPredicate instanceof True) {
            matches = (name).equals(this.categoricalFields.get((True) ancestorPredicate));
        } else {
            matches = false;
        }

        if (matches) {
            return true;
        }
    }

    return false;
}
Aggregations