use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf in project knime-core by knime.
the class TreeNodeRegression method createDecisionTreeNode.
/**
 * Converts this regression tree node, and recursively all of its children, into a
 * {@link DecisionTreeNode} of the decision tree model.
 *
 * <p>The node's mean, formatted with {@code DoubleFormat.formatDouble}, becomes the majority
 * cell. The distribution map holds exactly that cell, mapped to {@code m_totalSum}.
 * A node without children becomes a {@link DecisionTreeNodeLeaf}. Otherwise a
 * {@link DecisionTreeNodeSplitPMML} is created whose child predicates are taken from the
 * children's conditions. This node draws its id from {@code idGenerator} before any of its
 * children do, so ids are assigned in pre-order.
 *
 * @param idGenerator source of node ids; {@code inc()} is called once for this node
 *            (and once more per descendant through the recursion)
 * @param metaData meta data used to look up the name of the split attribute
 * @return the converted node (leaf or split)
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    final DataCell meanCell = new StringCell(DoubleFormat.formatDouble(m_mean));
    final int childCount = getNrChildren();
    final LinkedHashMap<DataCell, Double> distribution = new LinkedHashMap<DataCell, Double>();
    distribution.put(meanCell, m_totalSum);
    // take this node's id before recursing into the children (pre-order numbering)
    final int nodeId = idGenerator.inc();
    if (childCount == 0) {
        return new DecisionTreeNodeLeaf(nodeId, meanCell, distribution);
    }
    final int attributeIndex = getSplitAttributeIndex();
    assert attributeIndex >= 0 : "non-leaf node has no split";
    final String attributeName = metaData.getAttributeMetaData(attributeIndex).getAttributeName();
    final PMMLPredicate[] predicates = new PMMLPredicate[childCount];
    final DecisionTreeNode[] children = new DecisionTreeNode[childCount];
    for (int c = 0; c < childCount; c++) {
        final TreeNodeRegression child = getChild(c);
        // the predicate is derived before the child's subtree is converted
        predicates[c] = child.getCondition().toPMMLPredicate();
        children[c] = child.createDecisionTreeNode(idGenerator, metaData);
    }
    return new DecisionTreeNodeSplitPMML(nodeId, meanCell, distribution, attributeName, predicates, children);
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf in project knime-core by knime.
the class PMMLDecisionTreeTranslator method addTreeNode.
/**
 * Recursively converts a KNIME decision tree node, and every node below it, into the
 * corresponding PMML {@code Node} element.
 *
 * <p>The predicate written for {@code node} is determined by its parent:
 * <ul>
 * <li>the root gets a {@code True} predicate;</li>
 * <li>children of the legacy split node types get a predicate built from the parent's split
 * attribute, threshold or split values;</li>
 * <li>children of a {@link DecisionTreeNodeSplitPMML} get the predicate stored in the
 * parent.</li>
 * </ul>
 * Exporting does not modify the predicates held by the model.
 *
 * @param pmmlNode the (already created) PMML element to fill
 * @param node the KNIME decision tree node to convert
 * @param mapper translates split attribute names into the field names written to the PMML
 * @throws IllegalArgumentException if the parent node type is not supported, or if a child
 *             of a binary nominal split lies in neither partition
 */
private static void addTreeNode(final NodeDocument.Node pmmlNode, final DecisionTreeNode node, final DerivedFieldMapper mapper) {
    pmmlNode.setId(String.valueOf(node.getOwnIndex()));
    pmmlNode.setScore(node.getMajorityClass().toString());
    // the record count is only written when it is known
    // (presumably for trees that were read in and then exported again)
    if (node.getEntireClassCount() > 0) {
        pmmlNode.setRecordCount(node.getEntireClassCount());
    }
    if (node instanceof DecisionTreeNodeSplitPMML) {
        int defaultChild = ((DecisionTreeNodeSplitPMML) node).getDefaultChildIndex();
        if (defaultChild > -1) {
            // NOTE(review): the PMML spec defines defaultChild as the *id* of a child Node,
            // but this writes the child *index*. Confirm the PMML reader expects an index.
            pmmlNode.setDefaultChild(String.valueOf(defaultChild));
        }
    }
    // adding the predicate, which is derived from the parent
    DecisionTreeNode parent = node.getParent();
    if (parent == null) {
        // the root node's predicate is always True
        pmmlNode.addNewTrue();
    } else if (parent instanceof DecisionTreeNodeSplitContinuous) {
        // SimplePredicate case: first child is "<= threshold", second child takes the rest
        DecisionTreeNodeSplitContinuous splitNode = (DecisionTreeNodeSplitContinuous) parent;
        int childIndex = splitNode.getIndex(node);
        if (childIndex == 0) {
            SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
            pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
            pmmlSimplePredicate.setOperator(Operator.LESS_OR_EQUAL);
            pmmlSimplePredicate.setValue(String.valueOf(splitNode.getThreshold()));
        } else if (childIndex == 1) {
            pmmlNode.addNewTrue();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominalBinary) {
        // SimpleSetPredicate case: the child's partition is written as an "isIn" string array.
        // Must be tested before DecisionTreeNodeSplitNominal, which is checked next.
        DecisionTreeNodeSplitNominalBinary splitNode = (DecisionTreeNodeSplitNominalBinary) parent;
        SimpleSetPredicate pmmlSimpleSetPredicate = pmmlNode.addNewSimpleSetPredicate();
        pmmlSimpleSetPredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimpleSetPredicate.setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN);
        ArrayType pmmlArray = pmmlSimpleSetPredicate.addNewArray();
        pmmlArray.setType(ArrayType.Type.STRING);
        DataCell[] splitValues = splitNode.getSplitValues();
        int childIndex = splitNode.getIndex(node);
        final List<Integer> indices;
        if (childIndex == SplitNominalBinary.LEFT_PARTITION) {
            indices = splitNode.getLeftChildIndices();
        } else if (childIndex == SplitNominalBinary.RIGHT_PARTITION) {
            indices = splitNode.getRightChildIndices();
        } else {
            throw new IllegalArgumentException("Split node is neither " + "contained in the right nor in the left partition.");
        }
        // space separated list of the partition's values
        StringBuilder classSet = new StringBuilder();
        for (Integer i : indices) {
            if (classSet.length() > 0) {
                classSet.append(" ");
            }
            classSet.append(splitValues[i].toString());
        }
        pmmlArray.setN(BigInteger.valueOf(indices.size()));
        XmlCursor xmlCursor = pmmlArray.newCursor();
        xmlCursor.setTextValue(classSet.toString());
        xmlCursor.dispose();
    } else if (parent instanceof DecisionTreeNodeSplitNominal) {
        // SimplePredicate case: one "equal" predicate per nominal value
        DecisionTreeNodeSplitNominal splitNode = (DecisionTreeNodeSplitNominal) parent;
        SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimplePredicate.setOperator(Operator.EQUAL);
        int nodeIndex = parent.getIndex(node);
        pmmlSimplePredicate.setValue(splitNode.getSplitValues()[nodeIndex].toString());
    } else if (parent instanceof DecisionTreeNodeSplitPMML) {
        DecisionTreeNodeSplitPMML splitNode = (DecisionTreeNodeSplitPMML) parent;
        int nodeIndex = parent.getIndex(node);
        // get the PMML predicate of the current node from its parent
        PMMLPredicate predicate = splitNode.getSplitPred()[nodeIndex];
        if (predicate instanceof PMMLCompoundPredicate) {
            // surrogates as used in GBT
            exportCompoundPredicate(pmmlNode, (PMMLCompoundPredicate) predicate, mapper);
        } else {
            // The predicate belongs to the in-memory model. Switch it to the derived field name
            // only while it is written, then restore it, so that exporting does not alter the
            // model (a second export would otherwise map an already mapped name).
            String splitAttribute = predicate.getSplitAttribute();
            predicate.setSplitAttribute(mapper.getDerivedFieldName(splitAttribute));
            try {
                // delegate the writing to the predicate translator
                PMMLPredicateTranslator.exportTo(predicate, pmmlNode);
            } finally {
                predicate.setSplitAttribute(splitAttribute);
            }
        }
    } else {
        throw new IllegalArgumentException("Node Type " + parent.getClass() + " is not supported!");
    }
    // adding score distribution (class counts)
    for (Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
        ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(entry.getKey().toString());
        pmmlScoreDist.setRecordCount(entry.getValue());
    }
    // adding children
    if (!(node instanceof DecisionTreeNodeLeaf)) {
        for (int i = 0; i < node.getChildCount(); i++) {
            addTreeNode(pmmlNode.addNewNode(), node.getChildAt(i), mapper);
        }
    }
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf in project knime-core by knime.
the class Pruner method mdlPruningRecurse.
// /**
// * The general idea is to recursively prune the children and then compare
// * the potential leaf estimated error with the actual estimated error
// * including the length of the children.
// *
// * @param node the node to prune
// * @param zValue the z value according to which the error is estimated
// * calculated from the confidence value
// *
// * @return the resulting description length after pruning; this value is
// * used in higher levels of the recursion, i.e. for the parent node
// */
// private static PruningResult estimatedErrorPruningRecurse(
// final DecisionTreeNode node, final double zValue) {
//
// // if this is a child, just return the estimated error
// if (node.isLeaf()) {
// double error = node.getEntireClassCount() - node.getOwnClassCount();
// double estimatedError =
// estimatedError(node.getEntireClassCount(), error, zValue);
//
// return new PruningResult(estimatedError, node);
// }
//
// // holds the estimated errors of the children
// double[] childDescriptionLength = new double[node.getChildCount()];
// DecisionTreeNodeSplit splitNode = (DecisionTreeNodeSplit)node;
// // prune all children
// DecisionTreeNode[] children = splitNode.getChildren();
// int count = 0;
// for (DecisionTreeNode childNode : children) {
//
// PruningResult result =
// estimatedErrorPruningRecurse(childNode, zValue);
// childDescriptionLength[count] = result.getQualityValue();
//
// // replace the child with the one from the result (could of course
// // be the same)
// splitNode.replaceChild(childNode, result.getNode());
//
// count++;
// }
//
// // calculate the estimated error if this would be a leaf
// double error = node.getEntireClassCount() - node.getOwnClassCount();
// double leafEstimatedError =
// estimatedError(node.getEntireClassCount(), error, zValue);
//
// // calculate the current estimated error (sum of estimated errors of the
// // children)
// double currentEstimatedError = 0;
// for (double childDescLength : childDescriptionLength) {
// currentEstimatedError += childDescLength;
// }
//
// // define the return node
// DecisionTreeNode returnNode = node;
// double returnEstimatedError = currentEstimatedError;
//
// // if the possible leaf costs are smaller, replace this node
// // with a leaf (tolerance is 0.1)
// if (leafEstimatedError <= currentEstimatedError + 0.1) {
// DecisionTreeNodeLeaf newLeaf =
// new DecisionTreeNodeLeaf(node.getOwnIndex(), node
// .getMajorityClass(), node.getClassCounts());
// newLeaf.setParent((DecisionTreeNode)node.getParent());
// newLeaf.setPrefix(node.getPrefix());
// returnNode = newLeaf;
// returnEstimatedError = leafEstimatedError;
// }
//
// return new PruningResult(returnEstimatedError, returnNode);
// }
//
// /**
// * Prunes a {@link DecisionTree} according to the estimated error pruning
// * (Quinlan 87).
// *
// * @param decTree the decision tree to prune
// * @param confidence the confidence value according to which the error is
// * estimated
// */
// public static void estimatedErrorPruning(final DecisionTree decTree,
// final double confidence) {
//
// // traverse the tree depth first (in-fix)
// DecisionTreeNode root = decTree.getRootNode();
// // double zValue = xnormi(1 - confidence);
// estimatedErrorPruningRecurse(root, zValue);
// }
/**
 * Prunes the subtree rooted at {@code node} bottom-up, using a minimum description length
 * criterion.
 *
 * <p>All children are pruned recursively first. Then the cost of keeping this node as a split
 * is compared with the cost of turning it into a leaf:
 * <ul>
 * <li>a leaf costs its error ({@code getEntireClassCount() - getOwnClassCount()}) plus one
 * bit;</li>
 * <li>a split costs one bit, plus {@code log2(childCount)} bits, plus the costs of its pruned
 * children.</li>
 * </ul>
 * When both costs are equal, the leaf is chosen.
 *
 * @param node the node to prune; its children are replaced in place by their pruned versions
 *
 * @return the resulting description length after pruning, together with the node representing
 *         the pruned subtree (either {@code node} itself or a newly created leaf). The length
 *         is used in higher levels of the recursion, i.e. for the parent node.
 */
private static PruningResult mdlPruningRecurse(final DecisionTreeNode node) {
    if (node.isLeaf()) {
        // error of the leaf plus one bit for encoding the node itself
        final double leafError = node.getEntireClassCount() - node.getOwnClassCount();
        return new PruningResult(leafError + 1.0, node);
    }
    final double[] prunedChildCosts = new double[node.getChildCount()];
    final DecisionTreeNodeSplit split = (DecisionTreeNodeSplit) node;
    final DecisionTreeNode[] originalChildren = split.getChildren();
    for (int k = 0; k < originalChildren.length; k++) {
        final DecisionTreeNode child = originalChildren[k];
        final PruningResult pruned = mdlPruningRecurse(child);
        prunedChildCosts[k] = pruned.getQualityValue();
        // the pruned node may be the very same child or a freshly created leaf
        split.replaceChild(child, pruned.getNode());
    }
    final double costAsLeaf = node.getEntireClassCount() - node.getOwnClassCount() + 1.0;
    // one bit for the node plus log2(#children) bits, then each pruned child's cost in order
    double costAsSplit = 1.0 + Math.log(node.getChildCount()) / Math.log(2);
    for (final double childCost : prunedChildCosts) {
        costAsSplit += childCost;
    }
    if (costAsLeaf <= costAsSplit) {
        // collapsing into a leaf is not more expensive: replace this node with a leaf
        final DecisionTreeNodeLeaf leaf = new DecisionTreeNodeLeaf(node.getOwnIndex(), node.getMajorityClass(), node.getClassCounts());
        leaf.setParent(node.getParent());
        leaf.setPrefix(node.getPrefix());
        return new PruningResult(costAsLeaf, leaf);
    }
    return new PruningResult(costAsSplit, node);
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf in project knime-core by knime.
the class DecisionTreeLearnerNodeModel method buildTree.
/**
 * Recursively induces the decision tree.
 *
 * <p>A leaf is created in any of these cases:
 * <ul>
 * <li>the data of this node is pure enough;</li>
 * <li>no valid split can be found;</li>
 * <li>the best split would not partition the data usefully.</li>
 * </ul>
 * Otherwise the data is partitioned and one child is built per partition. A child is built in
 * the current thread when {@code rows * m_numberAttributes < 10000} or when no further thread
 * is available; otherwise it is built in a separate {@code ParallelBuilding} thread.
 *
 * @param table the {@link InMemoryTable} representing the data for this
 *            node to determine the split and after that perform
 *            partitioning
 * @param exec the execution context for progress information
 * @param depth the current recursion depth
 * @param splitQualityMeasure the measure used to evaluate candidate splits
 * @param parallelProcessing decides whether another thread is available for building a child
 * @return the root node of the induced (sub)tree
 * @throws CanceledExecutionException if the execution was canceled
 * @throws IllegalAccessException declared here but not thrown directly by this body;
 *             presumably propagated from a callee (TODO confirm)
 * @throws RuntimeException if a parallel building thread failed; its exception is the cause
 */
private DecisionTreeNode buildTree(final InMemoryTable table, final ExecutionContext exec, final int depth, final SplitQualityMeasure splitQualityMeasure, final ParallelProcessing parallelProcessing) throws CanceledExecutionException, IllegalAccessException {
    exec.checkCanceled();
    // derive this node's id from the counter
    int nodeId = m_counter.getAndIncrement();
    DataCell majorityClass = table.getMajorityClassAsCell();
    LinkedHashMap<DataCell, Double> frequencies = table.getClassFrequencies();
    // if the distribution allows for a leaf
    if (table.isPureEnough()) {
        return createLeafAndReportProgress(table, exec, nodeId, depth, majorityClass, frequencies);
    }
    // find the best splits for all attributes
    SplitFinder splitFinder = new SplitFinder(table, splitQualityMeasure, m_averageSplitpoint.getBooleanValue(),
        m_minNumberRecordsPerNode.getIntValue(), m_binaryNominalSplitMode.getBooleanValue(),
        m_maxNumNominalsForCompleteComputation.getIntValue());
    // check for enough memory
    checkMemory();
    // get the best split among the best attribute splits
    Split split = splitFinder.getSplit();
    // if no best split could be evaluated, create a leaf node
    if (split == null || !split.isValidSplit()) {
        return createLeafAndReportProgress(table, exec, nodeId, depth, majorityClass, frequencies);
    }
    // partition the attribute lists according to this split
    Partitioner partitioner = new Partitioner(table, split, m_minNumberRecordsPerNode.getIntValue());
    if (!partitioner.couldBeUsefulPartitioned()) {
        return createLeafAndReportProgress(table, exec, nodeId, depth, majorityClass, frequencies);
    }
    // get the just created partitions
    InMemoryTable[] partitionTables = partitioner.getPartitionTables();
    // recursively build the child nodes, small partitions in this thread, large ones in parallel
    DecisionTreeNode[] children = new DecisionTreeNode[partitionTables.length];
    ArrayList<ParallelBuilding> threads = new ArrayList<ParallelBuilding>();
    for (int i = 0; i < partitionTables.length; i++) {
        InMemoryTable partitionTable = partitionTables[i];
        exec.checkCanceled();
        if (partitionTable.getNumberDataRows() * m_numberAttributes < 10000 || !parallelProcessing.isThreadAvailable()) {
            children[i] = buildTree(partitionTable, exec, depth + 1, splitQualityMeasure, parallelProcessing);
        } else {
            String threadName = "Build thread, node: " + nodeId + "." + i;
            ParallelBuilding buildThread = new ParallelBuilding(threadName, partitionTable, exec, depth + 1, i, splitQualityMeasure, parallelProcessing);
            LOGGER.debug("Start new parallel building thread: " + threadName);
            threads.add(buildThread);
            buildThread.start();
        }
    }
    // collect the results of the parallel threads; children built in this thread
    // are already assigned to the child array
    for (ParallelBuilding buildThread : threads) {
        // presumably blocks until the thread has finished (TODO confirm)
        children[buildThread.getThreadIndex()] = buildThread.getResultNode();
        exec.checkCanceled();
        Throwable failure = buildThread.getException();
        if (failure != null) {
            for (ParallelBuilding otherThread : threads) {
                // NOTE(review): ParallelBuilding looks like a Thread subclass. If so, Thread.stop()
                // is deprecated and unsafe, and on Java 20+ it throws
                // UnsupportedOperationException. Consider a cooperative cancel instead.
                otherThread.stop();
            }
            // keep the failure as cause so that its type and stack trace are not lost
            throw new RuntimeException(failure.getMessage(), failure);
        }
    }
    threads.clear();
    String splitAttribute = split.getSplitAttributeName();
    if (split instanceof SplitContinuous) {
        // binary split: "<= value" for the first child, "> value" for the second
        double splitValue = ((SplitContinuous) split).getBestSplitValue();
        PMMLPredicate[] splitPredicates = new PMMLPredicate[] {
            new PMMLSimplePredicate(splitAttribute, PMMLOperator.LESS_OR_EQUAL, Double.toString(splitValue)),
            new PMMLSimplePredicate(splitAttribute, PMMLOperator.GREATER_THAN, Double.toString(splitValue)) };
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
    } else if (split instanceof SplitNominalNormal) {
        // nominal split: one "equal" predicate per split value
        DataCell[] splitValues = ((SplitNominalNormal) split).getSplitValues();
        int num = children.length;
        PMMLPredicate[] splitPredicates = new PMMLPredicate[num];
        for (int j = 0; j < num; j++) {
            splitPredicates[j] = new PMMLSimplePredicate(splitAttribute, PMMLOperator.EQUAL, splitValues[j].toString());
        }
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
    } else {
        // binary nominal: children[0] = left partition, children[1] = right partition
        SplitNominalBinary splitNominalBinary = (SplitNominalBinary) split;
        DataCell[] splitValues = splitNominalBinary.getSplitValues();
        int[][] indices = new int[][] { splitNominalBinary.getIntMappingsLeftPartition(), splitNominalBinary.getIntMappingsRightPartition() };
        PMMLPredicate[] splitPredicates = new PMMLPredicate[2];
        for (int j = 0; j < splitPredicates.length; j++) {
            PMMLSimpleSetPredicate pred = new PMMLSimpleSetPredicate(splitAttribute, PMMLSetOperator.IS_IN);
            pred.setArrayType(PMMLArrayType.STRING);
            LinkedHashSet<String> values = new LinkedHashSet<String>();
            for (int index : indices[j]) {
                values.add(splitValues[index].toString());
            }
            pred.setValues(values);
            splitPredicates[j] = pred;
        }
        return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
    }
}

/**
 * Finishes a node as a leaf: frees the rows of {@code table}, adds the table's sum of weights
 * to the finished counter, reports the progress and creates the leaf.
 *
 * @param table the data of the node, whose rows are freed
 * @param exec the execution context for progress information
 * @param nodeId the id of the node
 * @param depth the recursion depth of the node, used in the progress message
 * @param majorityClass the majority class of the node
 * @param frequencies the class frequencies of the node
 * @return the new leaf
 */
private DecisionTreeNode createLeafAndReportProgress(final InMemoryTable table, final ExecutionContext exec, final int nodeId, final int depth, final DataCell majorityClass, final LinkedHashMap<DataCell, Double> frequencies) {
    // free memory
    table.freeUnderlyingDataRows();
    double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
    exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
    return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf in project knime-core by knime.
the class TreeNodeRegression method createDecisionTreeNode.
/**
 * Converts this regression tree node, and recursively all of its children, into a
 * {@link DecisionTreeNode} of the decision tree model.
 *
 * <p>The node's mean, formatted with {@code DoubleFormat.formatDouble}, becomes the majority
 * cell. The distribution map holds exactly that cell, mapped to {@code m_totalSum}.
 * A node without children becomes a {@link DecisionTreeNodeLeaf}. Otherwise a
 * {@link DecisionTreeNodeSplitPMML} is created whose child predicates are taken from the
 * children's conditions. This node draws its id from {@code idGenerator} before any of its
 * children do, so ids are assigned in pre-order.
 *
 * @param idGenerator source of node ids; {@code inc()} is called once for this node
 *            (and once more per descendant through the recursion)
 * @param metaData meta data used to look up the name of the split attribute
 * @return the converted node (leaf or split)
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    final DataCell meanCell = new StringCell(DoubleFormat.formatDouble(m_mean));
    final int childCount = getNrChildren();
    final LinkedHashMap<DataCell, Double> distribution = new LinkedHashMap<DataCell, Double>();
    distribution.put(meanCell, m_totalSum);
    // take this node's id before recursing into the children (pre-order numbering)
    final int nodeId = idGenerator.inc();
    if (childCount == 0) {
        return new DecisionTreeNodeLeaf(nodeId, meanCell, distribution);
    }
    final int attributeIndex = getSplitAttributeIndex();
    assert attributeIndex >= 0 : "non-leaf node has no split";
    final String attributeName = metaData.getAttributeMetaData(attributeIndex).getAttributeName();
    final PMMLPredicate[] predicates = new PMMLPredicate[childCount];
    final DecisionTreeNode[] children = new DecisionTreeNode[childCount];
    for (int c = 0; c < childCount; c++) {
        final TreeNodeRegression child = getChild(c);
        // the predicate is derived before the child's subtree is converted
        predicates[c] = child.getCondition().toPMMLPredicate();
        children[c] = child.createDecisionTreeNode(idGenerator, metaData);
    }
    return new DecisionTreeNodeSplitPMML(nodeId, meanCell, distribution, attributeName, predicates, children);
}
Aggregations