use of org.knime.base.node.mine.decisiontree2.PMMLPredicate in project knime-core by knime.
the class TreeModelPMMLTranslator method setValuesFromPMMLCompoundPredicate.
/**
 * Fills a PMML compound predicate element from a KNIME compound predicate.
 * The boolean operator is transferred first; afterwards every child predicate
 * is appended to the target in iteration order, recursing for nested
 * compound predicates.
 *
 * @param to the PMML compound predicate element to populate
 * @param from the KNIME compound predicate providing operator and children
 * @throws IllegalStateException if the boolean operator or a child predicate
 *             type is not handled
 */
private static void setValuesFromPMMLCompoundPredicate(final CompoundPredicate to, final PMMLCompoundPredicate from) {
    final PMMLBooleanOperator operator = from.getBooleanOperator();
    switch (operator) {
        case AND:
            to.setBooleanOperator(CompoundPredicate.BooleanOperator.AND);
            break;
        case OR:
            to.setBooleanOperator(CompoundPredicate.BooleanOperator.OR);
            break;
        case SURROGATE:
            to.setBooleanOperator(CompoundPredicate.BooleanOperator.SURROGATE);
            break;
        case XOR:
            to.setBooleanOperator(CompoundPredicate.BooleanOperator.XOR);
            break;
        default:
            throw new IllegalStateException("Unknown boolean predicate \"" + operator + "\".");
    }

    // Append each child to the target; the type checks run in a fixed order
    // and the first match decides which PMML element is created.
    for (final PMMLPredicate child : from.getPredicates()) {
        if (child instanceof PMMLSimplePredicate) {
            final PMMLSimplePredicate simple = (PMMLSimplePredicate) child;
            setValuesFromPMMLSimplePredicate(to.addNewSimplePredicate(), simple);
            continue;
        }
        if (child instanceof PMMLSimpleSetPredicate) {
            final PMMLSimpleSetPredicate simpleSet = (PMMLSimpleSetPredicate) child;
            setValuesFromPMMLSimpleSetPredicate(to.addNewSimpleSetPredicate(), simpleSet);
            continue;
        }
        if (child instanceof PMMLTruePredicate) {
            to.addNewTrue();
            continue;
        }
        if (child instanceof PMMLFalsePredicate) {
            to.addNewFalse();
            continue;
        }
        if (child instanceof PMMLCompoundPredicate) {
            // nested compound predicate: create the element, then fill it recursively
            setValuesFromPMMLCompoundPredicate(to.addNewCompoundPredicate(), (PMMLCompoundPredicate) child);
            continue;
        }
        throw new IllegalStateException("Unknown predicate type \"" + child + "\".");
    }
}
use of org.knime.base.node.mine.decisiontree2.PMMLPredicate in project knime-core by knime.
the class TreeNodeNominalBinaryConditionTest method testToPMMLPredicate.
/**
 * Checks the output of
 * {@link TreeNodeNominalBinaryCondition#toPMMLPredicate()} for two
 * conditions built on the same nominal test column: the first must yield a
 * plain "is in" set predicate, the second a compound OR predicate made of an
 * "is not in" set predicate followed by an "is missing" simple predicate.
 *
 * @throws Exception if creating the test data or a condition fails
 */
@Test
public void testToPMMLPredicate() throws Exception {
    final TreeEnsembleLearnerConfiguration config = new TreeEnsembleLearnerConfiguration(false);
    final TestDataGenerator dataGen = new TestDataGenerator(config);
    final TreeNominalColumnData col = dataGen.createNominalAttributeColumn("A,A,B,C,C,D", "testcol", 0);

    // first condition: expected to translate into a single set predicate
    final TreeNodeNominalBinaryCondition firstCondition =
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(1), true, false);
    final PMMLPredicate firstPredicate = firstCondition.toPMMLPredicate();
    assertThat(firstPredicate, instanceOf(PMMLSimpleSetPredicate.class));
    final PMMLSimpleSetPredicate firstSet = (PMMLSimpleSetPredicate) firstPredicate;
    assertEquals("Wrong attribute", col.getMetaData().getAttributeName(), firstSet.getSplitAttribute());
    assertEquals("Wrong set predicate", PMMLSetOperator.IS_IN, firstSet.getSetOperator());
    assertArrayEquals("Wrong values", new String[] { "A" }, firstSet.getValues().toArray(new String[1]));

    // second condition: expected to translate into a compound OR predicate
    final TreeNodeNominalBinaryCondition secondCondition =
        new TreeNodeNominalBinaryCondition(col.getMetaData(), BigInteger.valueOf(2), false, true);
    final PMMLPredicate secondPredicate = secondCondition.toPMMLPredicate();
    assertEquals("Wrong attribute", col.getMetaData().getAttributeName(), secondPredicate.getSplitAttribute());
    assertThat(secondPredicate, instanceOf(PMMLCompoundPredicate.class));
    final PMMLCompoundPredicate compound = (PMMLCompoundPredicate) secondPredicate;
    assertEquals("Wrong boolean operator", PMMLBooleanOperator.OR, compound.getBooleanOperator());
    final LinkedList<PMMLPredicate> children = compound.getPredicates();
    assertEquals("Number of predicates did not match.", 2, children.size());

    // child 0: the "is not in" set predicate
    assertThat(children.get(0), instanceOf(PMMLSimpleSetPredicate.class));
    final PMMLSimpleSetPredicate nestedSet = (PMMLSimpleSetPredicate) children.get(0);
    assertEquals("Wrong attribute", col.getMetaData().getAttributeName(), nestedSet.getSplitAttribute());
    assertEquals("Wrong set predicate", PMMLSetOperator.IS_NOT_IN, nestedSet.getSetOperator());
    assertArrayEquals("Wrong values", new String[] { "B" }, nestedSet.getValues().toArray(new String[1]));

    // child 1: the "is missing" simple predicate
    assertThat(children.get(1), instanceOf(PMMLSimplePredicate.class));
    final PMMLSimplePredicate missingCheck = (PMMLSimplePredicate) children.get(1);
    assertEquals("Should be isMissing", PMMLOperator.IS_MISSING, missingCheck.getOperator());
}
use of org.knime.base.node.mine.decisiontree2.PMMLPredicate in project knime-core by knime.
the class PMMLDecisionTreeTranslator method addTreeNode.
/**
 * A recursive function which converts each KNIME Tree node to a
 * corresponding PMML element.
 *
 * <p>For the given node this writes, in order: id, score and (if positive)
 * record count; the default child for PMML split nodes; the predicate that
 * leads from the parent to this node; one score distribution entry per
 * class count; and finally one nested PMML node per child, unless the node
 * is a leaf.</p>
 *
 * @param pmmlNode the desired PMML element
 * @param node A KNIME DecisionTree node
 * @param mapper used to translate split attribute names into the field
 *            names written to the PMML predicates
 * @throws IllegalArgumentException if the parent node type is not supported
 *             or if, for a binary nominal split, the node belongs to neither
 *             the left nor the right partition
 */
private static void addTreeNode(final NodeDocument.Node pmmlNode, final DecisionTreeNode node, final DerivedFieldMapper mapper) {
    pmmlNode.setId(String.valueOf(node.getOwnIndex()));
    pmmlNode.setScore(node.getMajorityClass().toString());
    // the record count is only exported when it is positive
    if (node.getEntireClassCount() > 0) {
        pmmlNode.setRecordCount(node.getEntireClassCount());
    }
    if (node instanceof DecisionTreeNodeSplitPMML) {
        int defaultChild = ((DecisionTreeNodeSplitPMML) node).getDefaultChildIndex();
        // negative indices are not written as a default child
        if (defaultChild > -1) {
            pmmlNode.setDefaultChild(String.valueOf(defaultChild));
        }
    }
    // the predicate of a node is derived from the split stored in its parent
    DecisionTreeNode parent = node.getParent();
    if (parent == null) {
        // When the parent is null, it is the root Node.
        // For root node, the predicate is always True.
        pmmlNode.addNewTrue();
    } else if (parent instanceof DecisionTreeNodeSplitContinuous) {
        // SimplePredicate case
        DecisionTreeNodeSplitContinuous splitNode = (DecisionTreeNodeSplitContinuous) parent;
        final int childIndex = splitNode.getIndex(node);
        if (childIndex == 0) {
            // first child: attribute <= threshold
            SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
            pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
            pmmlSimplePredicate.setOperator(Operator.LESS_OR_EQUAL);
            pmmlSimplePredicate.setValue(String.valueOf(splitNode.getThreshold()));
        } else if (childIndex == 1) {
            // second child: written as True predicate
            pmmlNode.addNewTrue();
        }
        // NOTE: for any other index no predicate is written
    } else if (parent instanceof DecisionTreeNodeSplitNominalBinary) {
        // SimpleSetPredicate case
        DecisionTreeNodeSplitNominalBinary splitNode = (DecisionTreeNodeSplitNominalBinary) parent;
        SimpleSetPredicate pmmlSimpleSetPredicate = pmmlNode.addNewSimpleSetPredicate();
        pmmlSimpleSetPredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimpleSetPredicate.setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN);
        ArrayType pmmlArray = pmmlSimpleSetPredicate.addNewArray();
        pmmlArray.setType(ArrayType.Type.STRING);
        DataCell[] splitValues = splitNode.getSplitValues();
        // pick the value indices of the partition this node belongs to
        final int partition = splitNode.getIndex(node);
        final List<Integer> indices;
        if (partition == SplitNominalBinary.LEFT_PARTITION) {
            indices = splitNode.getLeftChildIndices();
        } else if (partition == SplitNominalBinary.RIGHT_PARTITION) {
            indices = splitNode.getRightChildIndices();
        } else {
            throw new IllegalArgumentException("Split node is neither " + "contained in the right nor in the left partition.");
        }
        // the array content is the space separated list of the partition's values
        StringBuilder classSet = new StringBuilder();
        for (Integer i : indices) {
            if (classSet.length() > 0) {
                classSet.append(" ");
            }
            classSet.append(splitValues[i].toString());
        }
        pmmlArray.setN(BigInteger.valueOf(indices.size()));
        XmlCursor xmlCursor = pmmlArray.newCursor();
        try {
            xmlCursor.setTextValue(classSet.toString());
        } finally {
            // release the cursor even if writing the text value fails
            xmlCursor.dispose();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominal) {
        // one child per nominal value: attribute == value at the child's index
        DecisionTreeNodeSplitNominal splitNode = (DecisionTreeNodeSplitNominal) parent;
        SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimplePredicate.setOperator(Operator.EQUAL);
        int nodeIndex = parent.getIndex(node);
        pmmlSimplePredicate.setValue(String.valueOf(splitNode.getSplitValues()[nodeIndex].toString()));
    } else if (parent instanceof DecisionTreeNodeSplitPMML) {
        DecisionTreeNodeSplitPMML splitNode = (DecisionTreeNodeSplitPMML) parent;
        int nodeIndex = parent.getIndex(node);
        // get the PMML predicate of the current node from its parent
        PMMLPredicate predicate = splitNode.getSplitPred()[nodeIndex];
        if (predicate instanceof PMMLCompoundPredicate) {
            // surrogates as used in GBT
            exportCompoundPredicate(pmmlNode, (PMMLCompoundPredicate) predicate, mapper);
        } else {
            // NOTE: this replaces the split attribute on the predicate object itself
            predicate.setSplitAttribute(mapper.getDerivedFieldName(predicate.getSplitAttribute()));
            // delegate the writing to the predicate translator
            PMMLPredicateTranslator.exportTo(predicate, pmmlNode);
        }
    } else {
        throw new IllegalArgumentException("Node Type " + parent.getClass() + " is not supported!");
    }
    // adding score distribution (class counts)
    for (final Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
        ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(entry.getKey().toString());
        pmmlScoreDist.setRecordCount(entry.getValue());
    }
    // adding children
    if (!(node instanceof DecisionTreeNodeLeaf)) {
        for (int i = 0; i < node.getChildCount(); i++) {
            addTreeNode(pmmlNode.addNewNode(), node.getChildAt(i), mapper);
        }
    }
}
use of org.knime.base.node.mine.decisiontree2.PMMLPredicate in project knime-core by knime.
the class DecisionTreeNodeSplitPMML method getChildUsedNominalSplitAttributeValues.
/**
 * {@inheritDoc}
 *
 * <p>This implementation asks the split predicate stored for the given
 * child for its used nominal values and, when that set is non-null, pairs
 * it with the predicate's split attribute; otherwise {@code null} is
 * returned.</p>
 */
@Override
protected Pair<String, Set<String>> getChildUsedNominalSplitAttributeValues(final int childIndex) {
    final PMMLPredicate childPredicate = m_splitPred[childIndex];
    final Set<String> usedValues = childPredicate.getUsedNominalSplitAttributeValues();
    if (usedValues == null) {
        // the predicate reports no nominal values
        return null;
    }
    final String attributeName = childPredicate.getSplitAttribute();
    return new Pair<String, Set<String>>(attributeName, usedValues);
}
use of org.knime.base.node.mine.decisiontree2.PMMLPredicate in project knime-core by knime.
the class DecisionTreeNodeSplitPMML method retainOnlyAttributeValuesInChild.
/**
 * {@inheritDoc}
 *
 * <p>This implementation forwards the retained values to the split
 * predicate stored for the given child and then sets that child's prefix to
 * the predicate's string representation.</p>
 */
@Override
protected void retainOnlyAttributeValuesInChild(final int childIndex, final Set<String> toBeRetained) {
    final PMMLPredicate childPredicate = m_splitPred[childIndex];
    childPredicate.retainOnlyAttributeValues(toBeRetained);
    // The decision tree view shows the "prefix" as the label of the
    // condition, so it is refreshed from the updated predicate.
    getChildAt(childIndex).setPrefix(childPredicate.toString());
}
Aggregations