Usage of org.dmg.pmml.Predicate in the jpmml-r project (by jpmml):
the method encodeNode of class BinaryTreeConverter.
/**
 * Recursively encodes one node of an R binary tree into the given PMML {@link Node}.
 *
 * <p>The R node is a generic vector with the elements {@code nodeID}, {@code terminal},
 * {@code psplit}, {@code ssplits}, {@code prediction}, {@code left} and {@code right}.
 *
 * <p>The node id is copied from {@code nodeID}. Terminal nodes are handed to
 * {@code encodeScore}. A non-terminal node is encoded as a binary split on the feature
 * named by {@code psplit$variableName}:
 * <ul>
 *   <li>categorical feature: a set predicate per child, built from the category values
 *       selected by {@code psplit$splitpoint};</li>
 *   <li>otherwise: {@code <= splitpoint} for the left child and {@code > splitpoint} for
 *       the right child.</li>
 * </ul>
 * Both children are encoded recursively and appended to {@code node}.
 *
 * @param node the PMML node to populate
 * @param tree the R representation of the tree node
 * @param schema the schema used to resolve split variables to features
 * @throws IllegalArgumentException if the node has surrogate splits ({@code ssplits}), or if
 *     the split variable is not mapped to a feature
 */
private void encodeNode(Node node, RGenericVector tree, Schema schema) {
    RIntegerVector nodeId = (RIntegerVector) tree.getValue("nodeID");
    RBooleanVector terminal = (RBooleanVector) tree.getValue("terminal");
    RGenericVector psplit = (RGenericVector) tree.getValue("psplit");
    RGenericVector ssplits = (RGenericVector) tree.getValue("ssplits");
    RDoubleVector prediction = (RDoubleVector) tree.getValue("prediction");
    RGenericVector left = (RGenericVector) tree.getValue("left");
    RGenericVector right = (RGenericVector) tree.getValue("right");

    node.setId(String.valueOf(nodeId.asScalar()));

    if ((Boolean.TRUE).equals(terminal.asScalar())) {
        // The return value of encodeScore was previously assigned to the parameter and then
        // discarded. NOTE(review): this presumes encodeScore populates `node` in place; if it
        // returns a different Node instead, that result is lost here — TODO confirm.
        encodeScore(node, prediction, schema);
        return;
    }

    RNumberVector<?> splitpoint = (RNumberVector<?>) psplit.getValue("splitpoint");
    RStringVector variableName = (RStringVector) psplit.getValue("variableName");

    if (ssplits.size() > 0) {
        throw new IllegalArgumentException("Node " + nodeId.asScalar() + " has surrogate splits (ssplits), which are not supported");
    }

    FieldName name = FieldName.create(variableName.asScalar());

    Integer index = this.featureIndexes.get(name);
    if (index == null) {
        throw new IllegalArgumentException("Split variable " + variableName.asScalar() + " of node " + nodeId.asScalar() + " is not mapped to a feature");
    }

    Feature feature = schema.getFeature(index);

    Predicate leftPredicate;
    Predicate rightPredicate;

    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

        List<String> values = categoricalFeature.getValues();

        // The split point of a categorical split is a vector of integers; presumably one entry
        // per category level, interpreted by selectValues — TODO confirm the encoding.
        @SuppressWarnings("unchecked")
        List<Integer> splitValues = (List<Integer>) splitpoint.getValues();

        leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, true));
        rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();

        String value = ValueUtil.formatValue((Double) splitpoint.asScalar());

        // Left child takes values <= splitpoint, right child takes values > splitpoint.
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }

    Node leftChild = new Node().setPredicate(leftPredicate);
    encodeNode(leftChild, left, schema);

    Node rightChild = new Node().setPredicate(rightPredicate);
    encodeNode(rightChild, right, schema);

    node.addNodes(leftChild, rightChild);
}
Usage of org.dmg.pmml.Predicate in the jpmml-sparkml project (by jpmml):
the method handleNodePop of class TreeModelCompactor.
/**
 * Merges a node whose predicate is {@link True} into its parent, if it has a parent.
 *
 * <p>For a REGRESSION model, the node's score is moved onto the parent, and the node is removed
 * from the parent's children. Its own children, if any, are appended to the parent's children.
 * For a CLASSIFICATION model, only a node that has children is replaced by its children. A
 * childless node is kept as is.
 *
 * <p>Nodes with any other predicate, or without a parent node, are left untouched.
 *
 * @param node the node being left by the visitor (presumably invoked from the pop/exit hook —
 *     TODO confirm)
 * @throws IllegalArgumentException if the parent already has a score, if the node cannot be
 *     found among the parent's children, or if the mining function is neither REGRESSION nor
 *     CLASSIFICATION
 */
private void handleNodePop(Node node) {
    Predicate predicate = node.getPredicate();
    if (!(predicate instanceof True)) {
        return;
    }

    Node parentNode = getParentNode();
    if (parentNode == null) {
        return;
    }

    if (parentNode.getScore() != null) {
        throw new IllegalArgumentException();
    }

    boolean regression = (MiningFunction.REGRESSION).equals(this.miningFunction);
    boolean classification = (MiningFunction.CLASSIFICATION).equals(this.miningFunction);
    if (!regression && !classification) {
        throw new IllegalArgumentException();
    }

    if (regression) {
        // The regression score is hoisted onto the parent before the node is spliced out.
        parentNode.setScore(node.getScore());
    }

    boolean hasChildren = node.hasNodes();

    // Regression nodes are always spliced out; classification nodes only when they have children.
    if (regression || hasChildren) {
        List<Node> siblings = parentNode.getNodes();

        if (!siblings.remove(node)) {
            throw new IllegalArgumentException();
        }

        if (hasChildren) {
            siblings.addAll(node.getNodes());
        }
    }
}
Usage of org.dmg.pmml.Predicate in the jpmml-sparkml project (by jpmml):
the method handleNodePush of class TreeModelCompactor.
/**
 * Validates a node and normalizes its binary split (presumably invoked from the visitor's
 * push/enter hook — TODO confirm).
 *
 * The following requirements are enforced; each violation throws IllegalArgumentException:
 * - the node must not carry an id;
 * - a node with children must have exactly two children and no score;
 * - a node without children must have a score.
 *
 * For a two-way split, the second child's predicate is replaced, either by a True predicate or
 * by the value returned from addCategoricalField(...). Before that, the children are swapped
 * where needed, so that the child being replaced is always the second one.
 *
 * Supported predicate pairs (first child / second child):
 * - SimplePredicate / SimplePredicate on the same field:
 *   - EQUAL / EQUAL: rewritten only if isCategoricalField(field); otherwise left untouched;
 *   - NOT_EQUAL / EQUAL: children are swapped first;
 *   - EQUAL / NOT_EQUAL;
 *   - LESS_OR_EQUAL / GREATER_THAN.
 *   Except for EQUAL / EQUAL, both predicates must compare against the same value.
 * - SimplePredicate(EQUAL) / SimpleSetPredicate(IS_IN) on the same field;
 * - SimpleSetPredicate(IS_IN) / SimplePredicate(EQUAL) on the same field (children are swapped);
 * - SimpleSetPredicate(IS_IN) / SimpleSetPredicate(IS_IN) on the same field.
 * Any other combination is rejected.
 *
 * @param node the node being entered
 * @throws IllegalArgumentException if the node or its split does not match a supported shape
 */
private void handleNodePush(Node node) {
String id = node.getId();
String score = node.getScore();
// Nodes are expected to be anonymous at this stage.
if (id != null) {
throw new IllegalArgumentException();
}
if (node.hasNodes()) {
List<Node> children = node.getNodes();
// Only strictly binary, score-less internal nodes are supported.
if (children.size() != 2 || score != null) {
throw new IllegalArgumentException();
}
Node firstChild = children.get(0);
Node secondChild = children.get(1);
Predicate firstPredicate = firstChild.getPredicate();
Predicate secondPredicate = secondChild.getPredicate();
// Labeled so that the EQUAL/EQUAL non-categorical case can leave the split untouched via `break predicate`.
predicate: if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimplePredicate) {
SimplePredicate firstSimplePredicate = (SimplePredicate) firstPredicate;
SimplePredicate secondSimplePredicate = (SimplePredicate) secondPredicate;
SimplePredicate.Operator firstOperator = firstSimplePredicate.getOperator();
SimplePredicate.Operator secondOperator = secondSimplePredicate.getOperator();
// Both branches must test the same field.
if (!(firstSimplePredicate.getField()).equals(secondSimplePredicate.getField())) {
throw new IllegalArgumentException();
}
if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
// Two equality tests: only collapsible when an ancestor marks this field as categorical.
if (!isCategoricalField(firstSimplePredicate.getField())) {
break predicate;
}
secondChild.setPredicate(new True());
} else {
// Complementary tests must compare against the same value.
if (!(firstSimplePredicate.getValue()).equals(secondSimplePredicate.getValue())) {
throw new IllegalArgumentException();
}
if ((SimplePredicate.Operator.NOT_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.EQUAL).equals(secondOperator)) {
// Put the EQUAL branch first, so that the NOT_EQUAL branch becomes the second child.
swapChildren(children);
// NOTE(review): firstChild is not read after this point; only secondChild matters.
firstChild = children.get(0);
secondChild = children.get(1);
} else if ((SimplePredicate.Operator.EQUAL).equals(firstOperator) && (SimplePredicate.Operator.NOT_EQUAL).equals(secondOperator)) {
// Ignored
} else if ((SimplePredicate.Operator.LESS_OR_EQUAL).equals(firstOperator) && (SimplePredicate.Operator.GREATER_THAN).equals(secondOperator)) {
// Ignored
} else {
throw new IllegalArgumentException();
}
// The second branch is the complement of the first, so its test is replaced by True.
secondChild.setPredicate(new True());
}
} else if (firstPredicate instanceof SimplePredicate && secondPredicate instanceof SimpleSetPredicate) {
SimplePredicate simplePredicate = (SimplePredicate) firstPredicate;
SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) secondPredicate;
if (!(simplePredicate.getField()).equals(simpleSetPredicate.getField()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator())) {
throw new IllegalArgumentException();
}
// addCategoricalField supplies the replacement predicate (see isCategoricalField for how it is looked up later).
secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
} else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimplePredicate) {
SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate) firstPredicate;
SimplePredicate simplePredicate = (SimplePredicate) secondPredicate;
if (!(simpleSetPredicate.getField()).equals(simplePredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(simpleSetPredicate.getBooleanOperator()) || !(SimplePredicate.Operator.EQUAL).equals(simplePredicate.getOperator())) {
throw new IllegalArgumentException();
}
// Move the EQUAL branch first, so that the set-predicate branch is the one replaced.
swapChildren(children);
// NOTE(review): firstChild is not read after this point.
firstChild = children.get(0);
secondChild = children.get(1);
secondChild.setPredicate(addCategoricalField(simpleSetPredicate));
} else if (firstPredicate instanceof SimpleSetPredicate && secondPredicate instanceof SimpleSetPredicate) {
SimpleSetPredicate firstSimpleSetPredicate = (SimpleSetPredicate) firstPredicate;
SimpleSetPredicate secondSimpleSetPredicate = (SimpleSetPredicate) secondPredicate;
if (!(firstSimpleSetPredicate.getField()).equals(secondSimpleSetPredicate.getField()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(firstSimpleSetPredicate.getBooleanOperator()) || !(SimpleSetPredicate.BooleanOperator.IS_IN).equals(secondSimpleSetPredicate.getBooleanOperator())) {
throw new IllegalArgumentException();
}
secondChild.setPredicate(addCategoricalField(secondSimpleSetPredicate));
} else {
// Any other predicate combination is unsupported.
throw new IllegalArgumentException();
}
} else {
// Leaf nodes must carry a score.
if (score == null) {
throw new IllegalArgumentException();
}
}
}
Usage of org.dmg.pmml.Predicate in the jpmml-sparkml project (by jpmml):
the method isCategoricalField of class TreeModelCompactor.
/**
 * Tells whether an enclosing Node marks the given field as categorical.
 *
 * <p>The visitor's parent stack is walked for as long as its entries are {@link Node}s. The walk
 * stops, answering {@code false}, at the first entry that is not a Node. A Node matches when
 * either of these holds:
 * <ul>
 *   <li>its predicate is a {@link SimpleSetPredicate} on {@code name};</li>
 *   <li>its predicate is a {@link True} that {@code categoricalFields} maps to {@code name}.</li>
 * </ul>
 *
 * @param name the field to look up
 * @return {@code true} if a matching Node ancestor is found, {@code false} otherwise
 */
private boolean isCategoricalField(FieldName name) {
    for (PMMLObject ancestor : getParents()) {
        if (!(ancestor instanceof Node)) {
            return false;
        }

        Predicate predicate = ((Node) ancestor).getPredicate();

        boolean matches = false;
        if (predicate instanceof SimpleSetPredicate) {
            matches = name.equals(((SimpleSetPredicate) predicate).getField());
        } else if (predicate instanceof True) {
            // True predicates record the categorical field they replaced in categoricalFields.
            matches = name.equals(this.categoricalFields.get((True) predicate));
        }

        if (matches) {
            return true;
        }
    }

    return false;
}
Aggregations