Use of org.dmg.pmml.SimplePredicateDocument.SimplePredicate in project knime-core by KNIME.
Class FromDecisionTreeNodeModel, method addRules:
/**
 * Adds the rules to {@code rs}, one {@link SimpleRule} per leaf reachable from {@code node} (recursively).
 * <p>
 * The predicate of each rule is an AND of the split predicates found on the path from the leaf up to the root,
 * collected leaf-first. Every ancestor split is expected to be a {@link DecisionTreeNodeSplitPMML} (the code
 * casts to it). When fewer than two operands are collected (e.g. a leaf directly below the root, or a tree that
 * consists of a single leaf), {@code True} operands are added so that the compound predicate stays well formed.
 * Depending on the settings in {@code m_rulesToTable}, score distributions (record counts and, optionally,
 * probabilities rounded to 7 significant digits with {@link RoundingMode#HALF_EVEN}) and rule statistics are
 * attached as well.
 *
 * @param rs The output {@link RuleSet}.
 * @param parents The parent stack; inner nodes are pushed before and popped after descending into them.
 * @param node The actual node.
 */
private void addRules(final RuleSet rs, final List<DecisionTreeNode> parents, final DecisionTreeNode node) {
    if (node.isLeaf()) {
        final SimpleRule rule = rs.addNewSimpleRule();
        if (m_rulesToTable.getScorePmmlRecordCount().getBooleanValue()) {
            // This increases the PMML quite significantly
            final MathContext mc = new MathContext(7, RoundingMode.HALF_EVEN);
            final boolean computeProbability = m_rulesToTable.getScorePmmlProbability().getBooleanValue();
            BigDecimal sum = BigDecimal.ZERO;
            if (computeProbability) {
                sum = new BigDecimal(node.getClassCounts().entrySet().stream()
                    .mapToDouble(e -> e.getValue().doubleValue()).sum(), mc);
            }
            for (final Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
                final ScoreDistribution scoreDistrib = rule.addNewScoreDistribution();
                scoreDistrib.setValue(entry.getKey().toString());
                scoreDistrib.setRecordCount(entry.getValue());
                if (computeProbability) {
                    final double count = entry.getValue().doubleValue();
                    // zero counts are written as exactly 0 (also avoids dividing when the sum might be zero)
                    if (Double.compare(count, 0.0) == 0) {
                        scoreDistrib.setProbability(BigDecimal.ZERO);
                    } else {
                        scoreDistrib.setProbability(new BigDecimal(count, mc).divide(sum, mc));
                    }
                }
            }
        }
        final CompoundPredicate and = rule.addNewCompoundPredicate();
        and.setBooleanOperator(BooleanOperator.AND);
        // Walk from the leaf towards the root. The parent is checked before it is dereferenced, so a root
        // that is itself a leaf simply contributes no split predicates.
        DecisionTreeNode child = node;
        DecisionTreeNode parent = child.getParent();
        while (parent != null) {
            final PMMLPredicate pmmlPredicate =
                ((DecisionTreeNodeSplitPMML) parent).getSplitPred()[parent.getIndex(child)];
            if (pmmlPredicate instanceof PMMLSimplePredicate) {
                final PMMLSimplePredicate simple = (PMMLSimplePredicate) pmmlPredicate;
                final SimplePredicate predicate = and.addNewSimplePredicate();
                copy(predicate, simple);
            } else if (pmmlPredicate instanceof PMMLCompoundPredicate) {
                final PMMLCompoundPredicate compound = (PMMLCompoundPredicate) pmmlPredicate;
                final CompoundPredicate predicate = and.addNewCompoundPredicate();
                copy(predicate, compound);
            } else if (pmmlPredicate instanceof PMMLSimpleSetPredicate) {
                final PMMLSimpleSetPredicate simpleSet = (PMMLSimpleSetPredicate) pmmlPredicate;
                copy(and.addNewSimpleSetPredicate(), simpleSet);
            } else if (pmmlPredicate instanceof PMMLTruePredicate) {
                and.addNewTrue();
            } else if (pmmlPredicate instanceof PMMLFalsePredicate) {
                and.addNewFalse();
            }
            child = parent;
            parent = child.getParent();
        }
        // A compound predicate needs at least two operands: pad a single (or missing) condition with True.
        while (and.sizeOfFalseArray() + and.sizeOfCompoundPredicateArray() + and.sizeOfSimplePredicateArray()
            + and.sizeOfSimpleSetPredicateArray() + and.sizeOfTrueArray() < 2) {
            and.addNewTrue();
        }
        if (m_rulesToTable.getProvideStatistics().getBooleanValue()) {
            rule.setNbCorrect(node.getOwnClassCount());
            rule.setRecordCount(node.getEntireClassCount());
        }
        rule.setScore(node.getMajorityClass().toString());
    } else {
        parents.add(node);
        for (int i = 0; i < node.getChildCount(); ++i) {
            addRules(rs, parents, node.getChildAt(i));
        }
        parents.remove(node);
    }
}
Use of org.dmg.pmml.SimplePredicateDocument.SimplePredicate in project knime-core by KNIME.
Class LiteralConditionParser, method parseCondition:
/**
 * {@inheritDoc}
 * <p>
 * The predicate kinds are probed in a fixed order: compound, simple, simple set, true, false. The first one that
 * is present on {@code node} decides the result. A {@code False} predicate is rejected with an
 * {@link IllegalArgumentException}; a node carrying none of these predicates yields an
 * {@link IllegalStateException}.
 */
@Override
public TreeNodeCondition parseCondition(final Node node) {
    final CompoundPredicate compoundPredicate = node.getCompoundPredicate();
    if (compoundPredicate != null) {
        return handleCompoundPredicate(compoundPredicate);
    }

    final SimplePredicate simplePredicate = node.getSimplePredicate();
    if (simplePredicate != null) {
        return handleSimplePredicate(simplePredicate, false);
    }

    final SimpleSetPredicate simpleSetPredicate = node.getSimpleSetPredicate();
    if (simpleSetPredicate != null) {
        return handleSimpleSetPredicate(simpleSetPredicate, false);
    }

    if (node.getTrue() != null) {
        return TreeNodeTrueCondition.INSTANCE;
    }

    if (node.getFalse() != null) {
        // there is no KNIME counterpart for an always-false condition
        throw new IllegalArgumentException("There is no False condition in KNIME.");
    }

    throw new IllegalStateException("The pmmlNode contains no valid Predicate.");
}
Use of org.dmg.pmml.SimplePredicateDocument.SimplePredicate in project knime-core by KNIME.
Class PMMLConditionTranslator, method parseCompoundPredicate:
/**
 * Create a KNIME compound predicate from a PMML compound predicate. Note that the "order" of the sub-predicates is
 * important (because of surrogate predicates). Therefore, the child elements are visited with an {@link XmlCursor}
 * and each one is translated in exactly the order in which it appears in the document.
 * <p>
 * Child elements that are not predicates (e.g. PMML {@code Extension} elements) are skipped.
 *
 * @param xmlCompoundPredicate the PMML Compound Predicate element
 * @return the KNIME Compound Predicate, whose sub-predicates keep the document order of the PMML element
 */
protected PMMLCompoundPredicate parseCompoundPredicate(final CompoundPredicate xmlCompoundPredicate) {
    final LinkedList<PMMLPredicate> subPredicates = new LinkedList<PMMLPredicate>();
    final XmlCursor xmlCursor = xmlCompoundPredicate.newCursor();
    try {
        if (xmlCursor.toFirstChild()) {
            do {
                final XmlObject xmlElement = xmlCursor.getObject();
                // Translating the typed element directly keeps the document order, even when several
                // sub-predicates (simple and simple-set alike) reference the same field.
                if (xmlElement instanceof SimplePredicate) {
                    subPredicates.add(parseSimplePredicate((SimplePredicate) xmlElement));
                } else if (xmlElement instanceof CompoundPredicate) {
                    subPredicates.add(parseCompoundPredicate((CompoundPredicate) xmlElement));
                } else if (xmlElement instanceof SimpleSetPredicate) {
                    subPredicates.add(parseSimpleSetPredicate((SimpleSetPredicate) xmlElement));
                } else if (xmlElement instanceof TrueDocument.True) {
                    subPredicates.add(new PMMLTruePredicate());
                } else if (xmlElement instanceof FalseDocument.False) {
                    subPredicates.add(new PMMLFalsePredicate());
                }
            } while (xmlCursor.toNextSibling());
        }
    } finally {
        // cursors hold resources until they are explicitly disposed
        xmlCursor.dispose();
    }
    final String operator = xmlCompoundPredicate.getBooleanOperator().toString();
    final PMMLCompoundPredicate compoundPredicate = newCompoundPredicate(operator);
    compoundPredicate.setPredicates(subPredicates);
    return compoundPredicate;
}
Use of org.dmg.pmml.SimplePredicateDocument.SimplePredicate in project knime-core by KNIME.
Class PMMLDecisionTreeTranslator, method addTreeNode:
/**
 * A recursive function which converts each KNIME tree node to a corresponding PMML element.
 * <p>
 * The predicate written for {@code node} is derived from the split of its parent (the root gets {@code True}).
 * The node's class counts are written as score distributions, and all children of non-leaf nodes are appended
 * recursively.
 *
 * @param pmmlNode the desired PMML element
 * @param node A KNIME DecisionTree node
 * @param mapper translates split attribute names into the field names written to PMML
 * @throws IllegalArgumentException if a nominal binary split contains the node in neither partition, or the
 *             parent's node type is not supported
 */
private static void addTreeNode(final NodeDocument.Node pmmlNode, final DecisionTreeNode node, final DerivedFieldMapper mapper) {
    pmmlNode.setId(String.valueOf(node.getOwnIndex()));
    pmmlNode.setScore(node.getMajorityClass().toString());
    // read in and then exported again
    if (node.getEntireClassCount() > 0) {
        pmmlNode.setRecordCount(node.getEntireClassCount());
    }
    if (node instanceof DecisionTreeNodeSplitPMML) {
        final int defaultChild = ((DecisionTreeNodeSplitPMML) node).getDefaultChildIndex();
        if (defaultChild > -1) {
            pmmlNode.setDefaultChild(String.valueOf(defaultChild));
        }
    }
    // adding the predicate, which is defined by the split of the parent
    final DecisionTreeNode parent = node.getParent();
    if (parent == null) {
        // When the parent is null, it is the root Node.
        // For root node, the predicate is always True.
        pmmlNode.addNewTrue();
    } else if (parent instanceof DecisionTreeNodeSplitContinuous) {
        // SimplePredicate case
        final DecisionTreeNodeSplitContinuous splitNode = (DecisionTreeNodeSplitContinuous) parent;
        final int childIndex = splitNode.getIndex(node);
        if (childIndex == 0) {
            final SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
            pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
            pmmlSimplePredicate.setOperator(Operator.LESS_OR_EQUAL);
            pmmlSimplePredicate.setValue(String.valueOf(splitNode.getThreshold()));
        } else if (childIndex == 1) {
            pmmlNode.addNewTrue();
        }
        // NOTE(review): any other child index writes no predicate at all - confirm continuous splits are binary.
    } else if (parent instanceof DecisionTreeNodeSplitNominalBinary) {
        // SimpleSetPredicate case
        final DecisionTreeNodeSplitNominalBinary splitNode = (DecisionTreeNodeSplitNominalBinary) parent;
        final SimpleSetPredicate pmmlSimpleSetPredicate = pmmlNode.addNewSimpleSetPredicate();
        pmmlSimpleSetPredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimpleSetPredicate.setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN);
        final ArrayType pmmlArray = pmmlSimpleSetPredicate.addNewArray();
        pmmlArray.setType(ArrayType.Type.STRING);
        final DataCell[] splitValues = splitNode.getSplitValues();
        final int childIndex = splitNode.getIndex(node);
        final List<Integer> indices;
        if (childIndex == SplitNominalBinary.LEFT_PARTITION) {
            indices = splitNode.getLeftChildIndices();
        } else if (childIndex == SplitNominalBinary.RIGHT_PARTITION) {
            indices = splitNode.getRightChildIndices();
        } else {
            throw new IllegalArgumentException("Split node is neither " + "contained in the right nor in the left partition.");
        }
        // the array content is the space separated list of the values in this partition
        final StringBuilder classSet = new StringBuilder();
        for (final Integer i : indices) {
            if (classSet.length() > 0) {
                classSet.append(" ");
            }
            classSet.append(splitValues[i].toString());
        }
        pmmlArray.setN(BigInteger.valueOf(indices.size()));
        final XmlCursor xmlCursor = pmmlArray.newCursor();
        try {
            xmlCursor.setTextValue(classSet.toString());
        } finally {
            xmlCursor.dispose();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominal) {
        final DecisionTreeNodeSplitNominal splitNode = (DecisionTreeNodeSplitNominal) parent;
        final SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimplePredicate.setOperator(Operator.EQUAL);
        final int nodeIndex = parent.getIndex(node);
        pmmlSimplePredicate.setValue(splitNode.getSplitValues()[nodeIndex].toString());
    } else if (parent instanceof DecisionTreeNodeSplitPMML) {
        final DecisionTreeNodeSplitPMML splitNode = (DecisionTreeNodeSplitPMML) parent;
        final int nodeIndex = parent.getIndex(node);
        // get the PMML predicate of the current node from its parent
        final PMMLPredicate predicate = splitNode.getSplitPred()[nodeIndex];
        if (predicate instanceof PMMLCompoundPredicate) {
            // surrogates as used in GBT
            exportCompoundPredicate(pmmlNode, (PMMLCompoundPredicate) predicate, mapper);
        } else {
            // The predicate belongs to the in-memory model: map its attribute only for the duration of the
            // export and restore it afterwards, so the model is not altered and repeated exports are stable.
            final String splitAttribute = predicate.getSplitAttribute();
            predicate.setSplitAttribute(mapper.getDerivedFieldName(splitAttribute));
            try {
                // delegate the writing to the predicate translator
                PMMLPredicateTranslator.exportTo(predicate, pmmlNode);
            } finally {
                predicate.setSplitAttribute(splitAttribute);
            }
        }
    } else {
        throw new IllegalArgumentException("Node Type " + parent.getClass() + " is not supported!");
    }
    // adding score distribution (class counts)
    for (final Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
        final ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(entry.getKey().toString());
        pmmlScoreDist.setRecordCount(entry.getValue());
    }
    // adding children
    if (!(node instanceof DecisionTreeNodeLeaf)) {
        for (int i = 0; i < node.getChildCount(); i++) {
            addTreeNode(pmmlNode.addNewNode(), node.getChildAt(i), mapper);
        }
    }
}
Use of org.dmg.pmml.SimplePredicateDocument.SimplePredicate in project knime-core by KNIME.
Class PMMLPredicateTranslator, method exportTo:
/**
 * Appends {@code predicate} as a new child element of {@code compound}. Nested compound predicates are exported
 * recursively into a child compound element.
 * <p>
 * This is basically a duplicate of the other export methods, but the PMML element types share no common parent
 * class for the {@code addNew...} methods, so the code is not really reusable.
 *
 * @param predicate the predicate to export; predicate types not handled here are ignored
 * @param compound the CompoundPredicate element to add the predicate to
 */
public static void exportTo(final PMMLPredicate predicate, final CompoundPredicate compound) {
    if (predicate instanceof PMMLFalsePredicate) {
        compound.addNewFalse();
    } else if (predicate instanceof PMMLTruePredicate) {
        compound.addNewTrue();
    } else if (predicate instanceof PMMLSimplePredicate) {
        final PMMLSimplePredicate sp = (PMMLSimplePredicate) predicate;
        final SimplePredicate simplePred = compound.addNewSimplePredicate();
        initSimplePredicate(sp, simplePred);
    } else if (predicate instanceof PMMLSimpleSetPredicate) {
        final PMMLSimpleSetPredicate sp = (PMMLSimpleSetPredicate) predicate;
        final SimpleSetPredicate setPred = compound.addNewSimpleSetPredicate();
        initSimpleSetPred(sp, setPred);
    } else if (predicate instanceof PMMLCompoundPredicate) {
        final PMMLCompoundPredicate compPred = (PMMLCompoundPredicate) predicate;
        // The nested element must be created as a child of compound; a detached Factory.newInstance()
        // would never be attached and the whole sub-predicate would be lost.
        final CompoundPredicate cp = compound.addNewCompoundPredicate();
        cp.setBooleanOperator(getOperator(compPred.getBooleanOperator()));
        for (final PMMLPredicate pred : compPred.getPredicates()) {
            exportTo(pred, cp);
        }
    }
}
Aggregations