Use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.
In the class TreeEnsembleLearnerNodeView, method getSubtree:
/**
 * Collects the given node together with all of its descendants.
 *
 * <p>Pending nodes are kept on a LIFO stack, so the result is in depth-first
 * order. Children are pushed in ascending index order, hence the child with
 * the highest index is visited first.</p>
 *
 * @param node the root of the subtree to collect
 * @return a list that starts with {@code node} and is followed by all of its
 *         descendants
 */
private List<DecisionTreeNode> getSubtree(final DecisionTreeNode node) {
    final List<DecisionTreeNode> collected = new ArrayList<DecisionTreeNode>();
    final LinkedList<DecisionTreeNode> stack = new LinkedList<DecisionTreeNode>();
    stack.push(node);
    while (!stack.isEmpty()) {
        // take the most recently pushed node
        final DecisionTreeNode current = stack.pop();
        collected.add(current);
        for (int childIdx = 0; childIdx < current.getChildCount(); childIdx++) {
            stack.push(current.getChildAt(childIdx));
        }
    }
    return collected;
}
Use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.
In the class DecTreeNodeView2, method changeSelectedHiLite:
// /////////////////////////////
// routines for HiLite Support
// /////////////////////////////
/**
 * Hilites or unhilites all items that are covered by the branches currently
 * selected in the tree.
 *
 * @param state if true the covered items are hilited, otherwise they are
 *            unhilited
 */
private void changeSelectedHiLite(final boolean state) {
    final TreePath[] selection = m_jTree.getSelectionPaths();
    if (selection == null) {
        // nothing selected
        return;
    }
    for (final TreePath path : selection) {
        assert (path != null);
        if (path == null) {
            // with assertions disabled a null entry ends the whole operation;
            // later paths are not processed
            return;
        }
        final Object selectedNode = path.getLastPathComponent();
        assert (selectedNode != null);
        assert (selectedNode instanceof DecisionTreeNode);
        final Set<RowKey> covered = ((DecisionTreeNode) selectedNode).coveredPattern();
        if (!state) {
            m_hiLiteHdl.fireUnHiLiteEvent(covered);
        } else {
            m_hiLiteHdl.fireHiLiteEvent(covered);
        }
    }
}
Use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.
In the class PMMLDecisionTreeTranslator, method addTreeNode:
/**
 * Recursively fills a PMML node element from a KNIME decision tree node.
 *
 * <p>The element receives, in this order: id, score and (if positive) the
 * record count; the default child for PMML split nodes; the predicate that
 * the parent's split defines for this node; the score distribution; and
 * finally one nested element per child.</p>
 *
 * @param pmmlNode the PMML element to fill
 * @param node the KNIME decision tree node to convert
 * @param mapper translates split attribute names into derived field names
 * @throws IllegalArgumentException if the parent's split type is not
 *             supported, or if the node is neither the left nor the right
 *             partition of a binary nominal split
 */
private static void addTreeNode(final NodeDocument.Node pmmlNode, final DecisionTreeNode node, final DerivedFieldMapper mapper) {
    pmmlNode.setId(String.valueOf(node.getOwnIndex()));
    pmmlNode.setScore(node.getMajorityClass().toString());
    // the record count is only written when it is positive
    if (node.getEntireClassCount() > 0) {
        pmmlNode.setRecordCount(node.getEntireClassCount());
    }
    if (node instanceof DecisionTreeNodeSplitPMML) {
        final int defaultIdx = ((DecisionTreeNodeSplitPMML) node).getDefaultChildIndex();
        if (defaultIdx >= 0) {
            pmmlNode.setDefaultChild(String.valueOf(defaultIdx));
        }
    }

    // the predicate of a node is defined by the split of its parent
    final DecisionTreeNode parent = node.getParent();
    if (parent == null) {
        // no parent means root node; its predicate is always True
        pmmlNode.addNewTrue();
    } else if (parent instanceof DecisionTreeNodeSplitContinuous) {
        // SimplePredicate case
        final DecisionTreeNodeSplitContinuous split = (DecisionTreeNodeSplitContinuous) parent;
        final int childIdx = split.getIndex(node);
        if (childIdx == 0) {
            final SimplePredicate lessOrEqual = pmmlNode.addNewSimplePredicate();
            lessOrEqual.setField(mapper.getDerivedFieldName(split.getSplitAttr()));
            lessOrEqual.setOperator(Operator.LESS_OR_EQUAL);
            lessOrEqual.setValue(String.valueOf(split.getThreshold()));
        } else if (childIdx == 1) {
            pmmlNode.addNewTrue();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominalBinary) {
        // SimpleSetPredicate case
        final DecisionTreeNodeSplitNominalBinary split = (DecisionTreeNodeSplitNominalBinary) parent;
        final SimpleSetPredicate setPredicate = pmmlNode.addNewSimpleSetPredicate();
        setPredicate.setField(mapper.getDerivedFieldName(split.getSplitAttr()));
        setPredicate.setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN);
        final ArrayType valueArray = setPredicate.addNewArray();
        valueArray.setType(ArrayType.Type.STRING);
        final DataCell[] splitValues = split.getSplitValues();
        final int partition = split.getIndex(node);
        final List<Integer> indices;
        if (partition == SplitNominalBinary.LEFT_PARTITION) {
            indices = split.getLeftChildIndices();
        } else if (partition == SplitNominalBinary.RIGHT_PARTITION) {
            indices = split.getRightChildIndices();
        } else {
            throw new IllegalArgumentException("Split node is neither " + "contained in the right nor in the left partition.");
        }
        // the array content is the space separated list of the partition's values
        final StringBuilder joined = new StringBuilder();
        for (final Integer valueIdx : indices) {
            if (joined.length() > 0) {
                joined.append(" ");
            }
            joined.append(splitValues[valueIdx].toString());
        }
        valueArray.setN(BigInteger.valueOf(indices.size()));
        final XmlCursor cursor = valueArray.newCursor();
        cursor.setTextValue(joined.toString());
        cursor.dispose();
    } else if (parent instanceof DecisionTreeNodeSplitNominal) {
        final DecisionTreeNodeSplitNominal split = (DecisionTreeNodeSplitNominal) parent;
        final SimplePredicate equalsPredicate = pmmlNode.addNewSimplePredicate();
        equalsPredicate.setField(mapper.getDerivedFieldName(split.getSplitAttr()));
        equalsPredicate.setOperator(Operator.EQUAL);
        final int childIdx = parent.getIndex(node);
        equalsPredicate.setValue(String.valueOf(split.getSplitValues()[childIdx].toString()));
    } else if (parent instanceof DecisionTreeNodeSplitPMML) {
        final DecisionTreeNodeSplitPMML split = (DecisionTreeNodeSplitPMML) parent;
        final int childIdx = parent.getIndex(node);
        // the parent holds one PMML predicate per child
        final PMMLPredicate predicate = split.getSplitPred()[childIdx];
        if (predicate instanceof PMMLCompoundPredicate) {
            // surrogates as used in GBT
            exportCompoundPredicate(pmmlNode, (PMMLCompoundPredicate) predicate, mapper);
        } else {
            predicate.setSplitAttribute(mapper.getDerivedFieldName(predicate.getSplitAttribute()));
            // delegate the writing to the predicate translator
            PMMLPredicateTranslator.exportTo(predicate, pmmlNode);
        }
    } else {
        throw new IllegalArgumentException("Node Type " + parent.getClass() + " is not supported!");
    }

    // score distribution: one entry per class with its count
    for (final Entry<DataCell, Double> classCount : node.getClassCounts().entrySet()) {
        final ScoreDistribution distribution = pmmlNode.addNewScoreDistribution();
        distribution.setValue(classCount.getKey().toString());
        distribution.setRecordCount(classCount.getValue());
    }

    // leaves have no children to descend into
    if (node instanceof DecisionTreeNodeLeaf) {
        return;
    }
    for (int childIdx = 0; childIdx < node.getChildCount(); childIdx++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChildAt(childIdx), mapper);
    }
}
Use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.
In the class DecisionTreeLearnerNodeModel2, method execute:
/**
 * Start of decision tree induction. Builds an in-memory table from the input
 * data, grows the tree, optionally filters illegal attribute values, prunes
 * the tree, adds hilite and color information and wraps the result into a
 * PMML port object.
 *
 * @param data the input port objects; the entry at {@code DATA_INPORT} is the
 *            table to build the decision tree from, and the entry at index 1
 *            is read as optional PMML input when PMML input is enabled
 * @param exec the execution context for this run
 * @return a single-element array holding the PMML port object that carries
 *         the decision tree model
 * @throws Exception any type of exception, e.g. for cancellation,
 * invalid input,...
 * @see NodeModel#execute(PortObject[], ExecutionContext)
 */
@Override
protected PortObject[] execute(final PortObject[] data, final ExecutionContext exec) throws Exception {
// holds the warning message displayed after execution
m_warningMessageSb = new StringBuilder();
ParallelProcessing parallelProcessing = new ParallelProcessing(m_parallelProcessing.getIntValue());
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Number available threads: " + parallelProcessing.getMaxNumberThreads() + " used threads: " + parallelProcessing.getCurrentThreadsInUse());
}
exec.setProgress("Preparing...");
// check input data (only verified when assertions are enabled)
assert (data != null && data[DATA_INPORT] != null);
BufferedDataTable inData = (BufferedDataTable) data[DATA_INPORT];
// get column with color information; the first column that has a color
// handler wins, null if there is none
String colorColumn = null;
for (DataColumnSpec s : inData.getDataTableSpec()) {
if (s.getColorHandler() != null) {
colorColumn = s.getName();
break;
}
}
// the data table must have at least 2 records
if (inData.size() <= 1) {
throw new IllegalArgumentException("Input data table must have at least 2 records!");
}
// get class column index
int classColumnIndex = inData.getDataTableSpec().findColumnIndex(m_classifyColumn.getStringValue());
// NOTE(review): only an assertion guards against an unknown class column;
// presumably the column is validated before execution - confirm
assert classColumnIndex > -1;
// create initial In-Memory table
exec.setProgress("Create initial In-Memory table...");
InMemoryTableCreator tableCreator = new InMemoryTableCreator(inData, classColumnIndex, m_minNumberRecordsPerNode.getIntValue(), m_skipColumns.getBooleanValue());
InMemoryTable initialTable = tableCreator.createInMemoryTable(exec.createSubExecutionContext(0.05));
int removedRows = tableCreator.getRemovedRowsDueToMissingClassValue();
// fail if every input row was removed because of a missing class value
if (removedRows == inData.size()) {
throw new IllegalArgumentException("Class column contains only " + "missing values");
}
// otherwise only report the removed rows in the warning message
if (removedRows > 0) {
m_warningMessageSb.append(removedRows);
m_warningMessageSb.append(" rows removed due to missing class value;");
}
// the all over row count is used to report progress
m_alloverRowCount = initialTable.getSumOfWeights();
// set the finishing counter
// this counter will always be incremented when a leaf node is
// created, as this determines the recursion end and can thus
// be used for progress indication
m_finishedCounter = new AtomicDouble(0);
// get the number of attributes
m_numberAttributes = initialTable.getNumAttributes();
// create the quality measure: Gini if configured, gain ratio otherwise
final SplitQualityMeasure splitQualityMeasure;
if (m_splitQualityMeasureType.getStringValue().equals(SPLIT_QUALITY_GINI)) {
splitQualityMeasure = new SplitQualityGini();
} else {
splitQualityMeasure = new SplitQualityGainRatio();
}
// build the tree
// before this set the node counter to 0
m_counter.set(0);
exec.setMessage("Building tree...");
// index of the attribute configured as first split column
final int firstSplitColIdx = initialTable.getAttributeIndex(m_firstSplitCol.getStringValue());
DecisionTreeNode root = null;
root = buildTree(initialTable, exec, 0, splitQualityMeasure, parallelProcessing, firstSplitColIdx);
boolean isBinaryNominal = m_binaryNominalSplitMode.getBooleanValue();
boolean isFilterInvalidAttributeValues = m_filterNominalValuesFromParent.getBooleanValue();
if (isBinaryNominal && isFilterInvalidAttributeValues) {
// traverse tree nodes and remove from the children the attribute
// values that were filtered out further up in the tree. "Bug" 3124
root.filterIllegalAttributes(Collections.<String, Set<String>>emptyMap());
}
// wrap the root into the decision tree model that is exported as PMML below
DecisionTree decisionTree = new DecisionTree(root, m_classifyColumn.getStringValue(), /* strategy has to be set explicitly as the default in PMML is
none, which means rows with missing values are not
classified. */
PMMLMissingValueStrategy.get(m_missingValues.getStringValue()), PMMLNoTrueChildStrategy.get(m_noTrueChild.getStringValue()));
decisionTree.setColorColumn(colorColumn);
// prune the tree
exec.setMessage("Prune tree with " + m_pruningMethod.getStringValue() + "...");
pruneTree(decisionTree);
// add highlight patterns and color information
exec.setMessage("Adding hilite and color info to tree...");
addHiliteAndColorInfo(inData, decisionTree);
LOGGER.info("Decision tree consisting of " + decisionTree.getNumberNodes() + " nodes created with pruning method " + m_pruningMethod.getStringValue());
// set the warning message if available
if (m_warningMessageSb.length() > 0) {
setWarningMessage(m_warningMessageSb.toString());
}
// reset the number available threads
// NOTE(review): this is not reached when one of the steps above throws -
// confirm whether reset() has to run in a finally block
parallelProcessing.reset();
parallelProcessing = null;
// no data table is produced; the only output is the PMML port object
exec.setMessage("Creating PMML decision tree model...");
// handle the optional PMML input
PMMLPortObject inPMMLPort = m_pmmlInEnabled ? (PMMLPortObject) data[1] : null;
DataTableSpec inSpec = inData.getSpec();
PMMLPortObjectSpec outPortSpec = createPMMLPortObjectSpec(inPMMLPort == null ? null : inPMMLPort.getSpec(), inSpec);
PMMLPortObject outPMMLPort = new PMMLPortObject(outPortSpec, inPMMLPort, inData.getSpec());
outPMMLPort.addModelTranslater(new PMMLDecisionTreeTranslator(decisionTree));
// keep the tree in the node model
m_decisionTree = decisionTree;
return new PortObject[] { outPMMLPort };
}
Use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.
In the class Pruner, method trainingErrorPruning:
// private static double estimatedError(final double all, final double error,
// final double zValue) {
// double f = error / all;
// double z = zValue;
// double N = all;
//
// double estimatedError =
// (f + z * z / (2 * N) + z
// * Math.sqrt(f / N - f * f / N + z * z / (4 * N * N)))
// / (1 + z * z / N);
//
// // return the weighted value
// return estimatedError * all;
// }
//
/**
 * Prunes a {@link DecisionTree} according to its training error: if a subtree
 * has the same error on the training data as the node it hangs below, the
 * subtree is pruned because it gains nothing.
 *
 * @param decTree the decision tree to prune
 */
public static void trainingErrorPruning(final DecisionTree decTree) {
    // the recursive helper starts at the root node and works through the tree
    trainingErrorPruningRecurse(decTree.getRootNode());
}
Aggregations