use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML in project knime-core by knime.
the class FromDecisionTreeNodeModel method addRules.
/**
 * Adds the rules to {@code rs} (recursively on each leaf).
 * <p>
 * For every leaf one {@link SimpleRule} is created. Its condition is an {@code AND} compound predicate
 * holding the split predicates found on the path from the leaf up to the root (leaf-most predicate first).
 * Inner nodes only recurse into their children.
 *
 * @param rs The output {@link RuleSet}.
 * @param parents The parent stack. The current inner node is added while its children are visited and
 *            removed afterwards; this method does not read it otherwise.
 * @param node The actual node.
 */
private void addRules(final RuleSet rs, final List<DecisionTreeNode> parents, final DecisionTreeNode node) {
    if (!node.isLeaf()) {
        parents.add(node);
        for (int i = 0; i < node.getChildCount(); ++i) {
            addRules(rs, parents, node.getChildAt(i));
        }
        parents.remove(node);
        return;
    }

    final SimpleRule rule = rs.addNewSimpleRule();
    if (m_rulesToTable.getScorePmmlRecordCount().getBooleanValue()) {
        // This increases the PMML quite significantly
        final MathContext mc = new MathContext(7, RoundingMode.HALF_EVEN);
        final boolean computeProbability = m_rulesToTable.getScorePmmlProbability().getBooleanValue();
        BigDecimal sum = BigDecimal.ZERO;
        if (computeProbability) {
            // total of all class counts; the denominator of every probability below
            sum = new BigDecimal(
                node.getClassCounts().entrySet().stream().mapToDouble(e -> e.getValue().doubleValue()).sum(), mc);
        }
        for (final Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
            final ScoreDistribution scoreDistrib = rule.addNewScoreDistribution();
            scoreDistrib.setValue(entry.getKey().toString());
            scoreDistrib.setRecordCount(entry.getValue());
            if (computeProbability) {
                if (Double.compare(entry.getValue().doubleValue(), 0.0) == 0) {
                    // a zero count is written directly, which also avoids dividing when the sum is zero
                    scoreDistrib.setProbability(BigDecimal.ZERO);
                } else {
                    scoreDistrib.setProbability(new BigDecimal(entry.getValue().doubleValue(), mc).divide(sum, mc));
                }
            }
        }
    }

    final CompoundPredicate and = rule.addNewCompoundPredicate();
    and.setBooleanOperator(BooleanOperator.AND);
    // Walk from the leaf towards the root. The parent is checked BEFORE it is dereferenced so that a tree
    // consisting of a single root leaf (which has no parent) no longer fails with a NullPointerException.
    for (DecisionTreeNode n = node; n.getParent() != null; n = n.getParent()) {
        final DecisionTreeNode parent = n.getParent();
        final PMMLPredicate pmmlPredicate =
            ((DecisionTreeNodeSplitPMML) parent).getSplitPred()[parent.getIndex(n)];
        if (pmmlPredicate instanceof PMMLSimplePredicate) {
            copy(and.addNewSimplePredicate(), (PMMLSimplePredicate) pmmlPredicate);
        } else if (pmmlPredicate instanceof PMMLCompoundPredicate) {
            copy(and.addNewCompoundPredicate(), (PMMLCompoundPredicate) pmmlPredicate);
        } else if (pmmlPredicate instanceof PMMLSimpleSetPredicate) {
            copy(and.addNewSimpleSetPredicate(), (PMMLSimpleSetPredicate) pmmlPredicate);
        } else if (pmmlPredicate instanceof PMMLTruePredicate) {
            and.addNewTrue();
        } else if (pmmlPredicate instanceof PMMLFalsePredicate) {
            and.addNewFalse();
        }
    }
    // Simple fix for the case when a single condition was used: pad with True until the compound
    // predicate has at least two operands.
    while (and.getFalseList().size() + and.getCompoundPredicateList().size()
        + and.getSimplePredicateList().size() + and.getSimpleSetPredicateList().size()
        + and.getTrueList().size() < 2) {
        and.addNewTrue();
    }
    if (m_rulesToTable.getProvideStatistics().getBooleanValue()) {
        rule.setNbCorrect(node.getOwnClassCount());
        rule.setRecordCount(node.getEntireClassCount());
    }
    rule.setScore(node.getMajorityClass().toString());
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML in project knime-core by knime.
the class TreeNodeRegression method createDecisionTreeNode.
/**
 * Converts this regression tree node and, recursively, all of its children into the
 * {@link DecisionTreeNode} representation.
 * <p>
 * The formatted mean serves as the node's majority cell and is mapped to the total sum in the
 * distribution map. A node without children is turned into a {@link DecisionTreeNodeLeaf}; every other
 * node becomes a {@link DecisionTreeNodeSplitPMML} whose predicates are derived from the conditions of
 * its children.
 *
 * @param idGenerator supplies the node ids; {@code inc()} is called once per created node, for a parent
 *            before any of its children
 * @param metaData used to look up the name of the split attribute
 * @return the converted node
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    final DataCell meanCell = new StringCell(DoubleFormat.formatDouble(m_mean));
    final int childCount = getNrChildren();
    final LinkedHashMap<DataCell, Double> distribution = new LinkedHashMap<>();
    distribution.put(meanCell, m_totalSum);

    // a node without children is a leaf
    if (childCount == 0) {
        return new DecisionTreeNodeLeaf(idGenerator.inc(), meanCell, distribution);
    }

    // the id of this node is drawn before the children are converted
    final int nodeId = idGenerator.inc();
    final DecisionTreeNode[] convertedChildren = new DecisionTreeNode[childCount];
    final int attributeIndex = getSplitAttributeIndex();
    assert attributeIndex >= 0 : "non-leaf node has no split";
    final String attributeName = metaData.getAttributeMetaData(attributeIndex).getAttributeName();
    final PMMLPredicate[] predicates = new PMMLPredicate[childCount];

    int index = 0;
    while (index < childCount) {
        final TreeNodeRegression child = getChild(index);
        predicates[index] = child.getCondition().toPMMLPredicate();
        convertedChildren[index] = child.createDecisionTreeNode(idGenerator, metaData);
        index++;
    }
    return new DecisionTreeNodeSplitPMML(nodeId, meanCell, distribution, attributeName, predicates,
        convertedChildren);
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML in project knime-core by knime.
the class PMMLDecisionTreeTranslator method addTreeNode.
/**
 * A recursive function which converts each KNIME Tree node to a
 * corresponding PMML element.
 * <p>
 * Written per node: id, score, record count (only if positive), the default child (only for
 * {@link DecisionTreeNodeSplitPMML} nodes with a non-negative default child index), the predicate that
 * leads from the parent to this node, the score distribution and finally the child nodes.
 *
 * @param pmmlNode the desired PMML element
 * @param node A KNIME DecisionTree node
 * @param mapper translates split attribute names into the field names written to the PMML
 * @throws IllegalArgumentException if the parent is of an unsupported split type, or if {@code node} is
 *             neither the left nor the right child of a binary nominal split
 */
private static void addTreeNode(final NodeDocument.Node pmmlNode, final DecisionTreeNode node, final DerivedFieldMapper mapper) {
    pmmlNode.setId(String.valueOf(node.getOwnIndex()));
    pmmlNode.setScore(node.getMajorityClass().toString());
    // the record count is only written when it is positive
    if (node.getEntireClassCount() > 0) {
        pmmlNode.setRecordCount(node.getEntireClassCount());
    }
    if (node instanceof DecisionTreeNodeSplitPMML) {
        final int defaultChild = ((DecisionTreeNodeSplitPMML) node).getDefaultChildIndex();
        if (defaultChild > -1) {
            pmmlNode.setDefaultChild(String.valueOf(defaultChild));
        }
    }

    // adding the predicate, which is determined by the parent's split
    final DecisionTreeNode parent = node.getParent();
    if (parent == null) {
        // When the parent is null, it is the root Node.
        // For root node, the predicate is always True.
        pmmlNode.addNewTrue();
    } else if (parent instanceof DecisionTreeNodeSplitContinuous) {
        // SimplePredicate case
        final DecisionTreeNodeSplitContinuous splitNode = (DecisionTreeNodeSplitContinuous) parent;
        final int childIndex = splitNode.getIndex(node);
        if (childIndex == 0) {
            final SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
            pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
            pmmlSimplePredicate.setOperator(Operator.LESS_OR_EQUAL);
            pmmlSimplePredicate.setValue(String.valueOf(splitNode.getThreshold()));
        } else if (childIndex == 1) {
            // the second child takes everything the first child did not match
            pmmlNode.addNewTrue();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominalBinary) {
        // SimpleSetPredicate case
        final DecisionTreeNodeSplitNominalBinary splitNode = (DecisionTreeNodeSplitNominalBinary) parent;
        final SimpleSetPredicate pmmlSimpleSetPredicate = pmmlNode.addNewSimpleSetPredicate();
        pmmlSimpleSetPredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimpleSetPredicate.setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN);
        final ArrayType pmmlArray = pmmlSimpleSetPredicate.addNewArray();
        pmmlArray.setType(ArrayType.Type.STRING);
        final DataCell[] splitValues = splitNode.getSplitValues();
        final int childIndex = splitNode.getIndex(node);
        final List<Integer> indices;
        if (childIndex == SplitNominalBinary.LEFT_PARTITION) {
            indices = splitNode.getLeftChildIndices();
        } else if (childIndex == SplitNominalBinary.RIGHT_PARTITION) {
            indices = splitNode.getRightChildIndices();
        } else {
            throw new IllegalArgumentException("Split node is neither " + "contained in the right nor in the left partition.");
        }
        // the array content is the space separated list of the values of this partition
        final StringBuilder classSet = new StringBuilder();
        for (final Integer i : indices) {
            if (classSet.length() > 0) {
                classSet.append(" ");
            }
            classSet.append(splitValues[i].toString());
        }
        pmmlArray.setN(BigInteger.valueOf(indices.size()));
        final XmlCursor xmlCursor = pmmlArray.newCursor();
        try {
            xmlCursor.setTextValue(classSet.toString());
        } finally {
            // release the cursor even if writing the text fails
            xmlCursor.dispose();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominal) {
        final DecisionTreeNodeSplitNominal splitNode = (DecisionTreeNodeSplitNominal) parent;
        final SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimplePredicate.setOperator(Operator.EQUAL);
        final int nodeIndex = parent.getIndex(node);
        pmmlSimplePredicate.setValue(String.valueOf(splitNode.getSplitValues()[nodeIndex].toString()));
    } else if (parent instanceof DecisionTreeNodeSplitPMML) {
        final DecisionTreeNodeSplitPMML splitNode = (DecisionTreeNodeSplitPMML) parent;
        final int nodeIndex = parent.getIndex(node);
        // get the PMML predicate of the current node from its parent
        final PMMLPredicate predicate = splitNode.getSplitPred()[nodeIndex];
        if (predicate instanceof PMMLCompoundPredicate) {
            // surrogates as used in GBT
            exportCompoundPredicate(pmmlNode, (PMMLCompoundPredicate) predicate, mapper);
        } else {
            // NOTE(review): this rewrites the split attribute on the predicate object itself
            predicate.setSplitAttribute(mapper.getDerivedFieldName(predicate.getSplitAttribute()));
            // delegate the writing to the predicate translator
            PMMLPredicateTranslator.exportTo(predicate, pmmlNode);
        }
    } else {
        throw new IllegalArgumentException("Node Type " + parent.getClass() + " is not supported!");
    }

    // adding score distribution (class counts)
    for (final Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
        final ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(entry.getKey().toString());
        pmmlScoreDist.setRecordCount(entry.getValue());
    }

    // adding children
    if (!(node instanceof DecisionTreeNodeLeaf)) {
        for (int i = 0; i < node.getChildCount(); i++) {
            addTreeNode(pmmlNode.addNewNode(), node.getChildAt(i), mapper);
        }
    }
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML in project knime-core by knime.
the class DecisionTreeLearnerNodeModel method buildTree.
/**
 * Recursively induces the decision tree.
 * <p>
 * A leaf is created when the table is pure enough, when no valid split is found or when the split does
 * not lead to a useful partitioning. Otherwise the table is partitioned and one child is built per
 * partition, either directly or in a separate {@code ParallelBuilding} thread, and a
 * {@link DecisionTreeNodeSplitPMML} holding one PMML predicate per child is returned.
 *
 * @param table the {@link InMemoryTable} representing the data for this
 *            node to determine the split and after that perform
 *            partitioning
 * @param exec the execution context for progress information
 * @param depth the current recursion depth
 * @param splitQualityMeasure the quality measure handed to the {@link SplitFinder}
 * @param parallelProcessing asked whether a thread is available for building a child in parallel;
 *            passed on to the child builds
 * @return the node induced for {@code table}
 * @throws CanceledExecutionException if the execution was canceled
 * @throws IllegalAccessException propagated from a callee (the raising call is not visible in this
 *             method)
 * @throws RuntimeException if one of the parallel build threads reported an exception; the reported
 *             exception is attached as cause
 */
private DecisionTreeNode buildTree(final InMemoryTable table, final ExecutionContext exec, final int depth, final SplitQualityMeasure splitQualityMeasure, final ParallelProcessing parallelProcessing) throws CanceledExecutionException, IllegalAccessException {
    exec.checkCanceled();
    // derive this node's id from the counter
    final int nodeId = m_counter.getAndIncrement();
    final DataCell majorityClass = table.getMajorityClassAsCell();
    final LinkedHashMap<DataCell, Double> frequencies = table.getClassFrequencies();

    // if the distribution allows for a leaf
    if (table.isPureEnough()) {
        return finishAsLeaf(table, exec, nodeId, depth, majorityClass, frequencies);
    }

    // find the best splits for all attributes
    final SplitFinder splitFinder = new SplitFinder(table, splitQualityMeasure, m_averageSplitpoint.getBooleanValue(), m_minNumberRecordsPerNode.getIntValue(), m_binaryNominalSplitMode.getBooleanValue(), m_maxNumNominalsForCompleteComputation.getIntValue());
    // check for enough memory
    checkMemory();
    // get the best split among the best attribute splits
    final Split split = splitFinder.getSplit();
    // if no best split could be evaluated, create a leaf node
    if (split == null || !split.isValidSplit()) {
        return finishAsLeaf(table, exec, nodeId, depth, majorityClass, frequencies);
    }

    // partition the attribute lists according to this split
    final Partitioner partitioner = new Partitioner(table, split, m_minNumberRecordsPerNode.getIntValue());
    if (!partitioner.couldBeUsefulPartitioned()) {
        return finishAsLeaf(table, exec, nodeId, depth, majorityClass, frequencies);
    }

    // get the just created partitions
    final InMemoryTable[] partitionTables = partitioner.getPartitionTables();
    // recursively build the child nodes
    final DecisionTreeNode[] children = new DecisionTreeNode[partitionTables.length];
    final ArrayList<ParallelBuilding> threads = new ArrayList<>();
    int i = 0;
    for (final InMemoryTable partitionTable : partitionTables) {
        exec.checkCanceled();
        // small partitions, or all partitions if no thread is available, are built in this thread
        if (partitionTable.getNumberDataRows() * m_numberAttributes < 10000 || !parallelProcessing.isThreadAvailable()) {
            children[i] = buildTree(partitionTable, exec, depth + 1, splitQualityMeasure, parallelProcessing);
        } else {
            final String threadName = "Build thread, node: " + nodeId + "." + i;
            final ParallelBuilding buildThread = new ParallelBuilding(threadName, partitionTable, exec, depth + 1, i, splitQualityMeasure, parallelProcessing);
            LOGGER.debug("Start new parallel building thread: " + threadName);
            threads.add(buildThread);
            buildThread.start();
        }
        i++;
    }

    // collect the results of the parallel builds; the children built in this thread are
    // already assigned to the child array
    for (final ParallelBuilding buildThread : threads) {
        children[buildThread.getThreadIndex()] = buildThread.getResultNode();
        exec.checkCanceled();
        if (buildThread.getException() != null) {
            for (final ParallelBuilding buildThread2 : threads) {
                buildThread2.stop();
            }
            // keep the original exception as cause so that its stack trace is not lost
            throw new RuntimeException(buildThread.getException().getMessage(), buildThread.getException());
        }
    }
    threads.clear();

    // All split types are expressed as PMML predicates on a DecisionTreeNodeSplitPMML node.
    final String splitAttribute = split.getSplitAttributeName();
    final PMMLPredicate[] splitPredicates;
    if (split instanceof SplitContinuous) {
        // first child: attribute <= split value, second child: attribute > split value
        final String splitValue = Double.toString(((SplitContinuous) split).getBestSplitValue());
        splitPredicates = new PMMLPredicate[] {
            new PMMLSimplePredicate(splitAttribute, PMMLOperator.LESS_OR_EQUAL, splitValue),
            new PMMLSimplePredicate(splitAttribute, PMMLOperator.GREATER_THAN, splitValue) };
    } else if (split instanceof SplitNominalNormal) {
        // else the attribute is nominal: one equality predicate per child
        final DataCell[] splitValues = ((SplitNominalNormal) split).getSplitValues();
        splitPredicates = new PMMLPredicate[children.length];
        for (int j = 0; j < splitPredicates.length; j++) {
            splitPredicates[j] = new PMMLSimplePredicate(splitAttribute, PMMLOperator.EQUAL, splitValues[j].toString());
        }
    } else {
        // binary nominal: one "is in" set predicate for the left and one for the right partition
        final SplitNominalBinary splitNominalBinary = (SplitNominalBinary) split;
        final DataCell[] splitValues = splitNominalBinary.getSplitValues();
        final int[][] indices = new int[][] { splitNominalBinary.getIntMappingsLeftPartition(), splitNominalBinary.getIntMappingsRightPartition() };
        splitPredicates = new PMMLPredicate[2];
        for (int j = 0; j < splitPredicates.length; j++) {
            final PMMLSimpleSetPredicate pred = new PMMLSimpleSetPredicate(splitAttribute, PMMLSetOperator.IS_IN);
            pred.setArrayType(PMMLArrayType.STRING);
            final LinkedHashSet<String> values = new LinkedHashSet<>();
            for (final int index : indices[j]) {
                values.add(splitValues[index].toString());
            }
            pred.setValues(values);
            splitPredicates[j] = pred;
        }
    }
    return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
}

/**
 * Turns the current node into a leaf: frees the rows of the table, adds the table's sum of weights to the
 * finished counter, reports the progress and creates the leaf node.
 */
private DecisionTreeNode finishAsLeaf(final InMemoryTable table, final ExecutionContext exec, final int nodeId, final int depth, final DataCell majorityClass, final LinkedHashMap<DataCell, Double> frequencies) {
    // free memory
    table.freeUnderlyingDataRows();
    final double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
    exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
    return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML in project knime-core by knime.
the class TreeNodeClassification method createDecisionTreeNode.
/**
 * Builds the {@link DecisionTreeNode} model used by the KNIME decision tree for this classification
 * node and, recursively, for all of its children.
 * <p>
 * The score distribution maps every nominal target value to the corresponding entry of the target
 * distribution, in the order of the target values. A node without children is turned into a
 * {@link DecisionTreeNodeLeaf}; every other node becomes a {@link DecisionTreeNodeSplitPMML} whose
 * predicates are derived from the conditions of its children.
 *
 * @param idGenerator supplies the node ids; {@code inc()} is called once per created node, for a parent
 *            before any of its children
 * @param metaData used to look up the name of the split attribute
 * @return a DecisionTreeNode
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    final DataCell majorityCell = new StringCell(getMajorityClassName());
    final double[] distribution = getTargetDistribution();
    // capacity chosen so that all entries fit at a load factor of 0.75
    final int capacity = (int) (distribution.length / 0.75 + 1.0);
    final LinkedHashMap<DataCell, Double> scoreDistribution = new LinkedHashMap<>(capacity);
    final NominalValueRepresentation[] targetValues = getTargetMetaData().getValues();
    for (int t = 0; t < distribution.length; t++) {
        scoreDistribution.put(new StringCell(targetValues[t].getNominalValue()), distribution[t]);
    }

    final int childCount = getNrChildren();
    // a node without children is a leaf
    if (childCount == 0) {
        return new DecisionTreeNodeLeaf(idGenerator.inc(), majorityCell, scoreDistribution);
    }

    // the id of this node is drawn before the children are converted
    final int nodeId = idGenerator.inc();
    final DecisionTreeNode[] convertedChildren = new DecisionTreeNode[childCount];
    final int attributeIndex = getSplitAttributeIndex();
    assert attributeIndex >= 0 : "non-leaf node has no split";
    final String attributeName = metaData.getAttributeMetaData(attributeIndex).getAttributeName();
    final PMMLPredicate[] predicates = new PMMLPredicate[childCount];

    int index = 0;
    while (index < childCount) {
        final TreeNodeClassification child = getChild(index);
        predicates[index] = child.getCondition().toPMMLPredicate();
        convertedChildren[index] = child.createDecisionTreeNode(idGenerator, metaData);
        index++;
    }
    return new DecisionTreeNodeSplitPMML(nodeId, majorityCell, scoreDistribution, attributeName, predicates,
        convertedChildren);
}
Aggregations