use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.
the class DecisionTreeLearnerNodeModel method execute.
/**
 * Learns a decision tree from the input table and returns it as a PMML port
 * object.
 * <p>
 * The steps are: build the in-memory table (rows with a missing class value
 * are counted and reported as a warning), induce the tree via
 * {@code buildTree}, optionally filter illegal nominal attribute values,
 * prune, attach hilite and color information and finally wrap the tree in a
 * PMML port object.
 *
 * @param data the input port objects; the entry at {@code DATA_INPORT} is
 *            the training table, the entry at index 1 is the optional PMML
 *            input and may be {@code null}
 * @param exec the execution context for this run
 * @return a one-element array holding the PMML port object to which the
 *         learned decision tree was added
 * @throws Exception any type of exception, e.g. for cancellation,
 *             invalid input,...
 * @see NodeModel#execute(PortObject[], ExecutionContext)
 */
@Override
protected PortObject[] execute(final PortObject[] data, final ExecutionContext exec) throws Exception {
    // collects the warning text that is shown once execution has finished
    final StringBuilder warnings = new StringBuilder();
    final ParallelProcessing threadBudget = new ParallelProcessing(m_parallelProcessing.getIntValue());
    if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Number available threads: " + threadBudget.getMaxNumberThreads() + " used threads: " + threadBudget.getCurrentThreadsInUse());
    }
    exec.setProgress("Preparing...");

    // sanity check of the input (only active with assertions enabled)
    assert (data != null && data[DATA_INPORT] != null);
    final BufferedDataTable trainingTable = (BufferedDataTable) data[DATA_INPORT];

    // remember the first column that carries a color handler
    String colorColumnName = null;
    for (DataColumnSpec columnSpec : trainingTable.getDataTableSpec()) {
        if (columnSpec.getColorHandler() != null) {
            colorColumnName = columnSpec.getName();
            break;
        }
    }

    // a tree can only be learned from at least two records
    if (trainingTable.getRowCount() < 2) {
        throw new IllegalArgumentException("Input data table must have at least 2 records!");
    }

    // locate the class column
    final String classColumnName = m_classifyColumn.getStringValue();
    final int classColumnIndex = trainingTable.getDataTableSpec().findColumnIndex(classColumnName);
    assert classColumnIndex > -1;

    // load the data into the initial in-memory table
    exec.setProgress("Create initial In-Memory table...");
    final InMemoryTableCreator creator = new InMemoryTableCreator(trainingTable, classColumnIndex, m_minNumberRecordsPerNode.getIntValue(), m_skipColumns.getBooleanValue());
    final InMemoryTable rootTable = creator.createInMemoryTable(exec.createSubExecutionContext(0.05));
    final int droppedRows = creator.getRemovedRowsDueToMissingClassValue();
    if (droppedRows == trainingTable.getRowCount()) {
        throw new IllegalArgumentException("Class column contains only " + "missing values");
    }
    if (droppedRows > 0) {
        warnings.append(droppedRows).append(" rows removed due to missing class value;");
    }

    // total weight of all rows; denominator of the progress reports
    m_alloverRowCount = rootTable.getSumOfWeights();
    // incremented whenever a leaf is created; since a leaf ends the
    // recursion this is used for progress indication
    m_finishedCounter = new AtomicDouble(0);
    m_numberAttributes = rootTable.getNumAttributes();

    // pick the split quality measure: Gini if configured, gain ratio otherwise
    final SplitQualityMeasure qualityMeasure = m_splitQualityMeasureType.getStringValue().equals(SPLIT_QUALITY_GINI)
        ? new SplitQualityGini()
        : new SplitQualityGainRatio();

    // reset the node id counter, then induce the tree
    m_counter.set(0);
    exec.setMessage("Building tree...");
    final DecisionTreeNode root = buildTree(rootTable, exec, 0, qualityMeasure, threadBudget);

    final boolean binaryNominalSplits = m_binaryNominalSplitMode.getBooleanValue();
    final boolean filterFromParent = m_filterNominalValuesFromParent.getBooleanValue();
    if (binaryNominalSplits && filterFromParent) {
        // traverse tree nodes and remove from the children the attribute
        // values that were filtered out further up in the tree. "Bug" 3124
        root.filterIllegalAttributes(Collections.EMPTY_MAP);
    }

    // The missing value strategy has to be set explicitly as the default in
    // PMML is none, which means rows with missing values are not classified.
    final DecisionTree decisionTree = new DecisionTree(root, m_classifyColumn.getStringValue(), PMMLMissingValueStrategy.LAST_PREDICTION);
    decisionTree.setColorColumn(colorColumnName);

    // prune the tree
    exec.setMessage("Prune tree with " + m_pruningMethod.getStringValue() + "...");
    pruneTree(decisionTree);

    // add highlight patterns and color information
    exec.setMessage("Adding hilite and color info to tree...");
    addHiliteAndColorInfo(trainingTable, decisionTree);
    LOGGER.info("Decision tree consisting of " + decisionTree.getNumberNodes() + " nodes created with pruning method " + m_pruningMethod.getStringValue());

    // publish collected warnings, if any
    if (warnings.length() > 0) {
        setWarningMessage(warnings.toString());
    }

    // reset the number of available threads
    threadBudget.reset();

    // wrap the tree in the PMML output; the PMML input port is optional
    exec.setMessage("Creating PMML decision tree model...");
    final PMMLPortObject pmmlIn = (PMMLPortObject) data[1];
    final DataTableSpec tableSpec = trainingTable.getSpec();
    final PMMLPortObjectSpec pmmlOutSpec = createPMMLPortObjectSpec(pmmlIn == null ? null : pmmlIn.getSpec(), tableSpec);
    final PMMLPortObject pmmlOut = new PMMLPortObject(pmmlOutSpec, pmmlIn, trainingTable.getSpec());
    pmmlOut.addModelTranslater(new PMMLDecisionTreeTranslator(decisionTree));
    m_decisionTree = decisionTree;
    return new PortObject[] { pmmlOut };
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.
the class DecisionTreeLearnerNodeModel method buildTree.
/**
 * Recursively induces the decision tree.
 * <p>
 * A leaf is created when the table is pure enough, when no valid split is
 * found, or when the partitioner reports that the split is not useful.
 * Otherwise the table is partitioned and one child is built per partition.
 * A partition is built in the calling thread when it is small (number of
 * rows times number of attributes below 10000) or when no further thread is
 * available; otherwise a {@code ParallelBuilding} thread is started for it.
 *
 * @param table the {@link InMemoryTable} representing the data for this
 *            node to determine the split and after that perform
 *            partitioning
 * @param exec the execution context for progress information
 * @param depth the current recursion depth
 * @param splitQualityMeasure the quality measure handed to the split finder
 * @param parallelProcessing queried for free threads and handed on to the
 *            recursive calls and build threads
 * @return the root node of the subtree induced from {@code table}
 * @throws CanceledExecutionException propagated from
 *             {@code exec.checkCanceled()}
 * @throws IllegalAccessException declared by this method; the callee
 *             raising it is not visible in this block
 * @throws RuntimeException if one of the parallel build threads reported an
 *             exception; that exception is attached as the cause
 */
private DecisionTreeNode buildTree(final InMemoryTable table, final ExecutionContext exec, final int depth, final SplitQualityMeasure splitQualityMeasure, final ParallelProcessing parallelProcessing) throws CanceledExecutionException, IllegalAccessException {
    exec.checkCanceled();
    // derive this node's id from the counter
    int nodeId = m_counter.getAndIncrement();
    // both values are read before the table's rows may be freed below
    DataCell majorityClass = table.getMajorityClassAsCell();
    LinkedHashMap<DataCell, Double> frequencies = table.getClassFrequencies();

    // if the distribution allows for a leaf
    if (table.isPureEnough()) {
        return finishAsLeaf(table, exec, nodeId, depth, majorityClass, frequencies);
    }

    // find the best splits for all attributes
    SplitFinder splitFinder = new SplitFinder(table, splitQualityMeasure, m_averageSplitpoint.getBooleanValue(), m_minNumberRecordsPerNode.getIntValue(), m_binaryNominalSplitMode.getBooleanValue(), m_maxNumNominalsForCompleteComputation.getIntValue());
    // check for enough memory
    checkMemory();
    // get the best split among the best attribute splits
    Split split = splitFinder.getSplit();
    // if no best split could be evaluated, create a leaf node
    if (split == null || !split.isValidSplit()) {
        return finishAsLeaf(table, exec, nodeId, depth, majorityClass, frequencies);
    }

    // partition the attribute lists according to this split
    Partitioner partitioner = new Partitioner(table, split, m_minNumberRecordsPerNode.getIntValue());
    if (!partitioner.couldBeUsefulPartitioned()) {
        return finishAsLeaf(table, exec, nodeId, depth, majorityClass, frequencies);
    }

    // get the just created partitions
    InMemoryTable[] partitionTables = partitioner.getPartitionTables();
    // recursively build the child nodes
    DecisionTreeNode[] children = new DecisionTreeNode[partitionTables.length];
    ArrayList<ParallelBuilding> threads = new ArrayList<ParallelBuilding>();
    int i = 0;
    for (InMemoryTable partitionTable : partitionTables) {
        exec.checkCanceled();
        if (partitionTable.getNumberDataRows() * m_numberAttributes < 10000 || !parallelProcessing.isThreadAvailable()) {
            // small partition or no free thread: build in the calling thread
            children[i] = buildTree(partitionTable, exec, depth + 1, splitQualityMeasure, parallelProcessing);
        } else {
            String threadName = "Build thread, node: " + nodeId + "." + i;
            ParallelBuilding buildThread = new ParallelBuilding(threadName, partitionTable, exec, depth + 1, i, splitQualityMeasure, parallelProcessing);
            LOGGER.debug("Start new parallel building thread: " + threadName);
            threads.add(buildThread);
            buildThread.start();
        }
        i++;
    }

    // Collect the children built by the parallel threads; the children built
    // in the calling thread are already assigned to the child array.
    for (ParallelBuilding buildThread : threads) {
        children[buildThread.getThreadIndex()] = buildThread.getResultNode();
        exec.checkCanceled();
        Throwable failure = buildThread.getException();
        if (failure != null) {
            // NOTE(review): stop() presumably maps to Thread.stop(), which is
            // deprecated and unsafe - confirm against ParallelBuilding.
            for (ParallelBuilding otherThread : threads) {
                otherThread.stop();
            }
            // keep the original message but chain the cause so that the
            // stack trace of the failing build thread is not lost
            throw new RuntimeException(failure.getMessage(), failure);
        }
    }
    threads.clear();

    String splitAttribute = split.getSplitAttributeName();
    PMMLPredicate[] splitPredicates;
    if (split instanceof SplitContinuous) {
        // continuous attribute: "<= value" for the first child, "> value"
        // for the second child
        double splitValue = ((SplitContinuous) split).getBestSplitValue();
        splitPredicates = new PMMLPredicate[] { new PMMLSimplePredicate(splitAttribute, PMMLOperator.LESS_OR_EQUAL, Double.toString(splitValue)), new PMMLSimplePredicate(splitAttribute, PMMLOperator.GREATER_THAN, Double.toString(splitValue)) };
    } else if (split instanceof SplitNominalNormal) {
        // nominal attribute: one equality predicate per child
        DataCell[] splitValues = ((SplitNominalNormal) split).getSplitValues();
        int num = children.length;
        splitPredicates = new PMMLPredicate[num];
        for (int j = 0; j < num; j++) {
            splitPredicates[j] = new PMMLSimplePredicate(splitAttribute, PMMLOperator.EQUAL, splitValues[j].toString());
        }
    } else {
        // binary nominal: one "is in" set predicate per side, where
        // index 0 is the left and index 1 the right partition
        SplitNominalBinary splitNominalBinary = (SplitNominalBinary) split;
        DataCell[] splitValues = splitNominalBinary.getSplitValues();
        int[][] indices = new int[][] { splitNominalBinary.getIntMappingsLeftPartition(), splitNominalBinary.getIntMappingsRightPartition() };
        splitPredicates = new PMMLPredicate[2];
        for (int j = 0; j < splitPredicates.length; j++) {
            PMMLSimpleSetPredicate pred = new PMMLSimpleSetPredicate(splitAttribute, PMMLSetOperator.IS_IN);
            pred.setArrayType(PMMLArrayType.STRING);
            LinkedHashSet<String> values = new LinkedHashSet<String>();
            for (int index : indices[j]) {
                values.add(splitValues[index].toString());
            }
            pred.setValues(values);
            splitPredicates[j] = pred;
        }
    }
    return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
}

/**
 * Turns the given table into a leaf node: frees the table's underlying data
 * rows, adds the table's sum of weights to the finished counter and reports
 * the resulting fraction of the overall row count as progress.
 *
 * @param table the table whose rows end up in the leaf
 * @param exec the execution context used to report progress
 * @param nodeId the id of the node to create
 * @param depth the tree level of the node, used in the progress message
 * @param majorityClass the majority class of {@code table}
 * @param frequencies the class frequencies of {@code table}
 * @return the newly created leaf node
 */
private DecisionTreeNode finishAsLeaf(final InMemoryTable table, final ExecutionContext exec, final int nodeId, final int depth, final DataCell majorityClass, final LinkedHashMap<DataCell, Double> frequencies) {
    // free memory
    table.freeUnderlyingDataRows();
    double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
    exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
    return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.
the class Pruner method mdlPruning.
/**
 * Prunes a {@link DecisionTree} according to the minimum description length
 * (MDL) principle.
 * <p>
 * The work is delegated to {@code mdlPruningRecurse}, starting at the root
 * node of the given tree.
 *
 * @param decTree the decision tree to prune
 */
public static void mdlPruning(final DecisionTree decTree) {
    // start the recursive traversal at the root of the tree
    mdlPruningRecurse(decTree.getRootNode());
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.
the class TreeNodeClassification method createDecisionTreeNode.
/**
 * Converts this tree node, including all of its descendants, into the
 * {@link DecisionTreeNode} model that is used in the Decision Tree of KNIME.
 * <p>
 * A node without children becomes a {@link DecisionTreeNodeLeaf}; any other
 * node becomes a {@link DecisionTreeNodeSplitPMML} whose child predicates
 * are taken from the conditions of the children.
 *
 * @param idGenerator source of node ids; it is incremented once for this
 *            node before the children are converted
 * @param metaData used to look up the name of the split attribute
 * @return a DecisionTreeNode
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    final DataCell majorityClass = new StringCell(getMajorityClassName());

    // map every target value to its entry in the target distribution
    final double[] distribution = getTargetDistribution();
    // capacity chosen so that the map does not need to grow at the default
    // load factor of 0.75
    final int capacity = (int) (distribution.length / 0.75 + 1.0);
    final LinkedHashMap<DataCell, Double> classDistribution = new LinkedHashMap<DataCell, Double>(capacity);
    final NominalValueRepresentation[] targetValues = getTargetMetaData().getValues();
    for (int t = 0; t < distribution.length; t++) {
        classDistribution.put(new StringCell(targetValues[t].getNominalValue()), distribution[t]);
    }

    final int childCount = getNrChildren();
    // this node takes its id before any of its children do
    final int nodeId = idGenerator.inc();
    if (childCount == 0) {
        return new DecisionTreeNodeLeaf(nodeId, majorityClass, classDistribution);
    }

    final int splitIndex = getSplitAttributeIndex();
    assert splitIndex >= 0 : "non-leaf node has no split";
    final String splitAttributeName = metaData.getAttributeMetaData(splitIndex).getAttributeName();

    // convert each child and remember the predicate that leads to it
    final DecisionTreeNode[] convertedChildren = new DecisionTreeNode[childCount];
    final PMMLPredicate[] predicates = new PMMLPredicate[childCount];
    for (int c = 0; c < childCount; c++) {
        final TreeNodeClassification child = getChild(c);
        predicates[c] = child.getCondition().toPMMLPredicate();
        convertedChildren[c] = child.createDecisionTreeNode(idGenerator, metaData);
    }
    return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, classDistribution, splitAttributeName, predicates, convertedChildren);
}
use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.
the class TreeEnsembleLearnerNodeView method updateHiLite.
// ///////////////////////////////////////////////////
/**
 * Hilites or unhilites the patterns covered by the node currently selected
 * in the graph. Does nothing if no node is selected.
 *
 * @param state {@code true} to fire a hilite event, {@code false} to fire
 *            an unhilite event
 */
private void updateHiLite(final boolean state) {
    final DecisionTreeNode selectedNode = m_graph.getSelected();
    if (selectedNode == null) {
        return;
    }
    // copy the covered row keys into a fresh set before firing the event
    final Set<RowKey> coveredKeys = new HashSet<RowKey>(selectedNode.coveredPattern());
    if (!state) {
        m_hiLiteHdl.fireUnHiLiteEvent(coveredKeys);
        return;
    }
    m_hiLiteHdl.fireHiLiteEvent(coveredKeys);
}
Aggregations