Search in sources :

Example 56 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class Pruner method trainingErrorPruning.

// private static double estimatedError(final double all, final double error,
// final double zValue) {
// double f = error / all;
// double z = zValue;
// double N = all;
// 
// double estimatedError =
// (f + z * z / (2 * N) + z
// * Math.sqrt(f / N - f * f / N + z * z / (4 * N * N)))
// / (1 + z * z / N);
// 
// // return the weighted value
// return estimatedError * all;
// }
// 
/**
 * Applies training error pruning to a {@link DecisionTree}: a subtree is
 * collapsed into a leaf whenever, measured on the training data, it does
 * not produce fewer errors than its root node would produce as a leaf.
 *
 * @param decTree the decision tree to prune
 */
public static void trainingErrorPruning(final DecisionTree decTree) {
    // The recursion prunes the children of a node before the node itself.
    // NOTE(review): the PruningResult returned for the root is discarded
    // here, so a leaf created in place of the root is not written back to
    // the tree by this method - confirm whether that is intended.
    trainingErrorPruningRecurse(decTree.getRootNode());
}
Also used : DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 57 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class Pruner method mdlPruningRecurse.

// /**
// * The general idea is to recursively prune the children and then compare
// * the potential leaf estimated error with the actual estimated error
// * including the length of the children.
// *
// * @param node the node to prune
// * @param zValue the z value according to which the error is estimated
// *            calculated from the confidence value
// *
// * @return the resulting description length after pruning; this value is
// *         used in higher levels of the recursion, i.e. for the parent node
// */
// private static PruningResult estimatedErrorPruningRecurse(
// final DecisionTreeNode node, final double zValue) {
// 
// // if this is a child, just return the estimated error
// if (node.isLeaf()) {
// double error = node.getEntireClassCount() - node.getOwnClassCount();
// double estimatedError =
// estimatedError(node.getEntireClassCount(), error, zValue);
// 
// return new PruningResult(estimatedError, node);
// }
// 
// // holds the estimated errors of the children
// double[] childDescriptionLength = new double[node.getChildCount()];
// DecisionTreeNodeSplit splitNode = (DecisionTreeNodeSplit)node;
// // prune all children
// DecisionTreeNode[] children = splitNode.getChildren();
// int count = 0;
// for (DecisionTreeNode childNode : children) {
// 
// PruningResult result =
// estimatedErrorPruningRecurse(childNode, zValue);
// childDescriptionLength[count] = result.getQualityValue();
// 
// // replace the child with the one from the result (could of course
// // be the same)
// splitNode.replaceChild(childNode, result.getNode());
// 
// count++;
// }
// 
// // calculate the estimated error if this would be a leaf
// double error = node.getEntireClassCount() - node.getOwnClassCount();
// double leafEstimatedError =
// estimatedError(node.getEntireClassCount(), error, zValue);
// 
// // calculate the current estimated error (sum of estimated errors of the
// // children)
// double currentEstimatedError = 0;
// for (double childDescLength : childDescriptionLength) {
// currentEstimatedError += childDescLength;
// }
// 
// // define the return node
// DecisionTreeNode returnNode = node;
// double returnEstimatedError = currentEstimatedError;
// 
// // if the possible leaf costs are smaller, replace this node
// // with a leaf (tolerance is 0.1)
// if (leafEstimatedError <= currentEstimatedError + 0.1) {
// DecisionTreeNodeLeaf newLeaf =
// new DecisionTreeNodeLeaf(node.getOwnIndex(), node
// .getMajorityClass(), node.getClassCounts());
// newLeaf.setParent((DecisionTreeNode)node.getParent());
// newLeaf.setPrefix(node.getPrefix());
// returnNode = newLeaf;
// returnEstimatedError = leafEstimatedError;
// }
// 
// return new PruningResult(returnEstimatedError, returnNode);
// }
// 
// /**
// * Prunes a {@link DecisionTree} according to the estimated error pruning
// * (Quinlan 87).
// *
// * @param decTree the decision tree to prune
// * @param confidence the confidence value according to which the error is
// *            estimated
// */
// public static void estimatedErrorPruning(final DecisionTree decTree,
// final double confidence) {
// 
// // traverse the tree depth first (in-fix)
// DecisionTreeNode root = decTree.getRootNode();
// // double zValue = xnormi(1 - confidence);
// estimatedErrorPruningRecurse(root, zValue);
// }
/**
 * Recursive step of the MDL based pruning. All children are pruned first;
 * afterwards the description length this node would have as a leaf is
 * compared with the description length of the (already pruned) subtree.
 * If the leaf is not more expensive, the split is replaced by a new leaf.
 *
 * @param node the node to prune
 *
 * @return the node to use in place of <code>node</code> together with its
 *         description length; both are consumed by the recursion step of
 *         the parent node
 */
private static PruningResult mdlPruningRecurse(final DecisionTreeNode node) {
    if (node.isLeaf()) {
        // cost of a leaf: records outside its own class plus 1.0 for the
        // node itself
        double misclassified = node.getEntireClassCount() - node.getOwnClassCount();
        return new PruningResult(misclassified + 1.0, node);
    }

    // one slot per child for the description length of its pruned subtree
    double[] childCosts = new double[node.getChildCount()];
    DecisionTreeNodeSplit split = (DecisionTreeNodeSplit) node;
    DecisionTreeNode[] children = split.getChildren();
    for (int i = 0; i < children.length; i++) {
        DecisionTreeNode child = children[i];
        PruningResult pruned = mdlPruningRecurse(child);
        childCosts[i] = pruned.getQualityValue();
        // hook in whatever the recursion returned; this may well be the
        // unchanged child
        split.replaceChild(child, pruned.getNode());
    }

    // description length if this node were turned into a leaf
    double leafCost = node.getEntireClassCount() - node.getOwnClassCount() + 1.0;

    // description length of the subtree: 1.0 for the node, log2 of the
    // number of children, plus the lengths of all children
    double subtreeCost = 1.0 + Math.log(node.getChildCount()) / Math.log(2);
    for (int i = 0; i < childCosts.length; i++) {
        subtreeCost += childCosts[i];
    }

    if (leafCost <= subtreeCost) {
        // the leaf is at least as cheap as the subtree: replace this node
        DecisionTreeNodeLeaf leaf = new DecisionTreeNodeLeaf(node.getOwnIndex(), node.getMajorityClass(), node.getClassCounts());
        leaf.setParent((DecisionTreeNode) node.getParent());
        leaf.setPrefix(node.getPrefix());
        return new PruningResult(leafCost, leaf);
    }
    // otherwise keep the split node together with its subtree cost
    return new PruningResult(subtreeCost, node);
}
Also used : DecisionTreeNodeSplit(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplit) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 58 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class Pruner method trainingErrorPruningRecurse.

/**
 * Recursive step of the training error based pruning. All children are
 * pruned first; afterwards the error this node would have as a leaf is
 * compared with the summed error of its (already pruned) children. If the
 * leaf error, reduced by a tolerance of 0.001, does not exceed the summed
 * error, the split is replaced by a new leaf.
 *
 * @param node the node to prune
 *
 * @return the node to use in place of <code>node</code> together with its
 *         error; both are consumed by the recursion step of the parent node
 */
private static PruningResult trainingErrorPruningRecurse(final DecisionTreeNode node) {
    if (node.isLeaf()) {
        // error of a leaf: all records that are not of its own class
        double misclassified = node.getEntireClassCount() - node.getOwnClassCount();
        return new PruningResult(misclassified, node);
    }

    // one slot per child for the error of its pruned subtree
    double[] childErrors = new double[node.getChildCount()];
    // a node that is not a leaf is expected to be a split node
    DecisionTreeNodeSplit split = (DecisionTreeNodeSplit) node;
    DecisionTreeNode[] children = split.getChildren();
    for (int i = 0; i < children.length; i++) {
        DecisionTreeNode child = children[i];
        PruningResult pruned = trainingErrorPruningRecurse(child);
        childErrors[i] = pruned.getQualityValue();
        // hook in whatever the recursion returned; this may well be the
        // unchanged child
        split.replaceChild(child, pruned.getNode());
    }

    // error if this node were turned into a leaf
    double leafError = node.getEntireClassCount() - node.getOwnClassCount();

    // error of the subtree: sum over the errors of all children
    double subtreeError = 0.0;
    for (int i = 0; i < childErrors.length; i++) {
        subtreeError += childErrors[i];
    }

    if (leafError - 0.001 <= subtreeError) {
        // the subtree gains nothing over a leaf: replace this node
        DecisionTreeNodeLeaf leaf = new DecisionTreeNodeLeaf(node.getOwnIndex(), node.getMajorityClass(), node.getClassCounts());
        leaf.setParent((DecisionTreeNode) node.getParent());
        leaf.setPrefix(node.getPrefix());
        return new PruningResult(leafError, leaf);
    }
    // otherwise keep the split node together with the error of its subtree
    return new PruningResult(subtreeError, node);
}
Also used : DecisionTreeNodeSplit(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplit) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Aggregations

DecisionTreeNode (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)50 RowKey (org.knime.core.data.RowKey)18 HashSet (java.util.HashSet)14 LinkedList (java.util.LinkedList)14 DecisionTreeNodeLeaf (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf)12 ArrayList (java.util.ArrayList)10 DecisionTreeNodeSplitPMML (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML)10 DataCell (org.knime.core.data.DataCell)9 Action (javax.swing.Action)8 JMenu (javax.swing.JMenu)8 CollapseBranchAction (org.knime.base.node.mine.decisiontree2.view.graph.CollapseBranchAction)8 ExpandBranchAction (org.knime.base.node.mine.decisiontree2.view.graph.ExpandBranchAction)8 PMMLPredicate (org.knime.base.node.mine.decisiontree2.PMMLPredicate)7 PMMLDecisionTreeTranslator (org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator)5 DecisionTree (org.knime.base.node.mine.decisiontree2.model.DecisionTree)5 TreePath (javax.swing.tree.TreePath)4 PortObject (org.knime.core.node.port.PortObject)4 PMMLPortObject (org.knime.core.node.port.pmml.PMMLPortObject)4 LinkedHashMap (java.util.LinkedHashMap)3 PMMLSimplePredicate (org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate)3