use of org.dmg.pmml.tree.LeafNode in project drools by kiegroup.
the class KiePMMLTreeModelNodeASTFactoryTest method isFinalLeaf.
/**
 * Checks {@code isFinalLeaf} against three node shapes:
 * a {@link LeafNode}, a {@link ClassifierNode} without children, and a
 * {@link ClassifierNode} that has had a child node added.
 *
 * A new factory is built for every assertion, using an empty map, an empty list,
 * {@code RETURN_NULL_PREDICTION} and a {@code STRING} target type.
 */
@Test
public void isFinalLeaf() {
    Node node = new LeafNode();
    DATA_TYPE targetType = DATA_TYPE.STRING;

    // A LeafNode is expected to be reported as a final leaf.
    // (A duplicate, un-asserted invocation that preceded this assertion was removed:
    // its result was discarded, so it verified nothing.)
    assertTrue(KiePMMLTreeModelNodeASTFactory.factory(new HashMap<>(), Collections.emptyList(), TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION, targetType).isFinalLeaf(node));

    // A ClassifierNode to which no child has been added is also expected to be a final leaf.
    node = new ClassifierNode();
    assertTrue(KiePMMLTreeModelNodeASTFactory.factory(new HashMap<>(), Collections.emptyList(), TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION, targetType).isFinalLeaf(node));

    // After a child node is added, the same node must no longer be a final leaf.
    node.addNodes(new LeafNode());
    assertFalse(KiePMMLTreeModelNodeASTFactory.factory(new HashMap<>(), Collections.emptyList(), TreeModel.NoTrueChildStrategy.RETURN_NULL_PREDICTION, targetType).isFinalLeaf(node));
}
use of org.dmg.pmml.tree.LeafNode in project jpmml-r by jpmml.
the class RandomForestConverter method encodeNode.
/**
 * Recursively builds the PMML node for row {@code i} of the parallel forest vectors.
 *
 * A row whose {@code bestvar} entry is zero becomes a {@link LeafNode} scored with the
 * encoded {@code nodepred} entry. Any other row becomes a {@link BranchNode}; its
 * {@code bestvar} entry minus one selects the schema feature to split on, and the
 * {@code leftDaughter}/{@code rightDaughter} entries (one-based, zero meaning "no child")
 * select the rows to recurse into.
 *
 * @param predicate       predicate attached to the node being created
 * @param i               zero-based row index into the parallel vectors
 * @param scoreEncoder    converts a {@code nodepred} entry into a leaf score
 * @param leftDaughter    per-row index of the left child, or zero
 * @param rightDaughter   per-row index of the right child, or zero
 * @param bestvar         per-row index of the split feature, or zero for a leaf
 * @param xbestsplit      per-row split value
 * @param nodepred        per-row prediction
 * @param categoryManager category state inherited from the parent; forked per child for categorical splits
 * @param schema          source of the features referenced by {@code bestvar}
 * @return the node (with identifier {@code i + 1}) and its whole subtree
 * @throws IllegalArgumentException if a boolean feature is split at a value other than 0.5
 */
private <P extends Number> Node encodeNode(Predicate predicate, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, CategoryManager categoryManager, Schema schema) {
    final Integer nodeId = Integer.valueOf(i + 1);
    final int splitVar = ValueUtil.asInt(bestvar.get(i));

    // No split variable: this row is a leaf.
    if (splitVar == 0) {
        final P leafValue = nodepred.get(i);
        return new LeafNode(scoreEncoder.encode(leafValue), predicate).setId(nodeId);
    }

    final Feature feature = schema.getFeature(splitVar - 1);
    final Double splitValue = xbestsplit.get(i);

    final CategoryManager leftManager;
    final CategoryManager rightManager;
    final Predicate leftPredicate;
    final Predicate rightPredicate;

    // The boolean check deliberately precedes the categorical one, as in the original ordering.
    if (feature instanceof BooleanFeature) {
        final BooleanFeature booleanFeature = (BooleanFeature) feature;
        if (splitValue.doubleValue() != 0.5d) {
            throw new IllegalArgumentException();
        }
        leftManager = categoryManager;
        rightManager = categoryManager;
        leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0));
        rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1));
    } else if (feature instanceof CategoricalFeature) {
        final CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        final String featureName = categoricalFeature.getName();
        final List<?> allValues = categoricalFeature.getValues();
        final java.util.function.Predicate<Object> inheritedFilter = categoryManager.getValueFilter(featureName);

        final List<Object> leftValues = selectValues(allValues, inheritedFilter, splitValue, true);
        final List<Object> rightValues = selectValues(allValues, inheritedFilter, splitValue, false);

        // Each child gets its own fork of the category manager for this feature.
        leftManager = categoryManager.fork(featureName, leftValues);
        rightManager = categoryManager.fork(featureName, rightValues);

        leftPredicate = createPredicate(categoricalFeature, leftValues);
        rightPredicate = createPredicate(categoricalFeature, rightValues);
    } else {
        final ContinuousFeature continuousFeature = feature.toContinuousFeature();
        leftManager = categoryManager;
        rightManager = categoryManager;
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, splitValue);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, splitValue);
    }

    final Node branch = new BranchNode(null, predicate).setId(nodeId);
    final List<Node> children = branch.getNodes();

    // Daughter indices are one-based; zero means the child is absent.
    final int leftIndex = ValueUtil.asInt(leftDaughter.get(i));
    if (leftIndex != 0) {
        children.add(encodeNode(leftPredicate, leftIndex - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, leftManager, schema));
    }

    final int rightIndex = ValueUtil.asInt(rightDaughter.get(i));
    if (rightIndex != 0) {
        children.add(encodeNode(rightPredicate, rightIndex - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, rightManager, schema));
    }

    return branch;
}
use of org.dmg.pmml.tree.LeafNode in project jpmml-r by jpmml.
the class BinaryTreeConverter method encodeNode.
/**
 * Recursively converts one node of an R tree structure into a PMML {@link Node}.
 *
 * The node is read from the {@code nodeID}, {@code terminal}, {@code psplit},
 * {@code ssplits}, {@code prediction}, {@code left} and {@code right} elements of
 * {@code tree}. A terminal node becomes a {@link LeafNode} whose score is filled in by
 * {@code encodeScore}; any other node becomes a {@link BranchNode} with exactly two
 * children built from {@code left} and {@code right}.
 *
 * @param tree      the R vector describing the current node
 * @param predicate predicate attached to the node being created
 * @param schema    source of the feature referenced by the split's variable name
 * @return the node and its whole subtree
 * @throws IllegalArgumentException if {@code ssplits} is not empty, or if the split's
 *         variable name has no entry in {@code featureIndexes}
 */
private Node encodeNode(RGenericVector tree, Predicate predicate, Schema schema) {
    RIntegerVector nodeId = tree.getIntegerElement("nodeID");
    RBooleanVector terminal = tree.getBooleanElement("terminal");
    RGenericVector psplit = tree.getGenericElement("psplit");
    RGenericVector ssplits = tree.getGenericElement("ssplits");
    RDoubleVector prediction = tree.getDoubleElement("prediction");
    RGenericVector left = tree.getGenericElement("left");
    RGenericVector right = tree.getGenericElement("right");

    Integer id = nodeId.asScalar();

    // Terminal node: create an unscored leaf and let encodeScore populate it.
    if ((Boolean.TRUE).equals(terminal.asScalar())) {
        Node result = new LeafNode(null, predicate).setId(id);
        return encodeScore(result, prediction, schema);
    }

    RNumberVector<?> splitpoint = psplit.getNumericElement("splitpoint");
    RStringVector variableName = psplit.getStringElement("variableName");

    // Only nodes with an empty "ssplits" element are handled.
    if (ssplits.size() > 0) {
        throw new IllegalArgumentException("Expected an empty 'ssplits' element in node " + id + ", but it has " + ssplits.size() + " element(s)");
    }

    Predicate leftPredicate;
    Predicate rightPredicate;

    String name = variableName.asScalar();
    Integer index = this.featureIndexes.get(name);
    if (index == null) {
        throw new IllegalArgumentException("No feature index is registered for split variable '" + name + "' in node " + id);
    }

    Feature feature = schema.getFeature(index);
    if (feature instanceof CategoricalFeature) {
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<?> values = categoricalFeature.getValues();
        // NOTE(review): unchecked cast; assumes the split point of a categorical split holds integers. Confirm against the producer.
        List<Integer> splitValues = (List<Integer>) splitpoint.getValues();
        // The boolean flag passed to selectValues distinguishes the left (true) and right (false) selections.
        leftPredicate = createPredicate(categoricalFeature, selectValues(values, splitValues, true));
        rightPredicate = createPredicate(categoricalFeature, selectValues(values, splitValues, false));
    } else {
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        Number value = splitpoint.asScalar();
        // Left child takes values <= split point, right child takes values > split point.
        leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
        rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
    }

    Node leftChild = encodeNode(left, leftPredicate, schema);
    Node rightChild = encodeNode(right, rightPredicate, schema);

    Node result = new BranchNode(null, predicate).setId(id).addNodes(leftChild, rightChild);
    return result;
}
use of org.dmg.pmml.tree.LeafNode in project jpmml-sparkml by jpmml.
the class TreeModelUtil method encodeDecisionTree.
/**
 * Encodes a Spark decision tree model as a PMML {@link TreeModel}.
 *
 * Regression models copy each Spark leaf's prediction into the node score.
 * Classification models replace each leaf with a {@link ClassifierNode} carrying the
 * label value at the predicted index, a record count and score distributions taken
 * from the leaf's impurity statistics. The finished tree is passed through a
 * {@link TreeModelCompactor} when the {@code HasTreeOptions.OPTION_COMPACT} option
 * evaluates to {@code true}.
 *
 * @param converter                used to look up the compaction option
 * @param model                    the Spark decision tree model
 * @param predicateManager         forwarded to {@code encodeTreeModel}
 * @param scoreDistributionManager builds score distributions for classification leaves
 * @param schema                   supplies the label and is forwarded to {@code encodeTreeModel}
 * @return the encoded tree model
 * @throws IllegalArgumentException if {@code model} is neither a regression nor a classification tree
 */
private static <M extends Model<M> & DecisionTreeModel> TreeModel encodeDecisionTree(ModelConverter<?> converter, M model, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
    final MiningFunction miningFunction;
    final ScoreEncoder leafEncoder;

    if (model instanceof DecisionTreeRegressionModel) {
        miningFunction = MiningFunction.REGRESSION;
        leafEncoder = new ScoreEncoder() {

            @Override
            public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode) {
                node.setScore(leafNode.prediction());
                return node;
            }
        };
    } else if (model instanceof DecisionTreeClassificationModel) {
        miningFunction = MiningFunction.CLASSIFICATION;
        leafEncoder = new ScoreEncoder() {

            // Resolved once, when the encoder is created.
            private CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

            @Override
            public Node encode(Node node, org.apache.spark.ml.tree.LeafNode leafNode) {
                // The incoming node is replaced; only its predicate is carried over.
                node = new ClassifierNode(null, node.getPredicate());

                int labelIndex = ValueUtil.asInt(leafNode.prediction());
                node.setScore(this.categoricalLabel.getValue(labelIndex));

                ImpurityCalculator impurity = leafNode.impurityStats();
                node.setRecordCount(ValueUtil.narrow(impurity.count()));
                scoreDistributionManager.addScoreDistributions(node, this.categoricalLabel.getValues(), impurity.stats());

                return node;
            }
        };
    } else {
        throw new IllegalArgumentException();
    }

    TreeModel treeModel = encodeTreeModel(miningFunction, leafEncoder, model, predicateManager, schema);

    // Null-safe check: compaction runs only when the option value is exactly TRUE.
    Boolean compact = (Boolean) converter.getOption(HasTreeOptions.OPTION_COMPACT, Boolean.TRUE);
    if (Boolean.TRUE.equals(compact)) {
        Visitor compactor = new TreeModelCompactor();
        compactor.applyTo(treeModel);
    }

    return treeModel;
}
Aggregations