Search in sources :

Example 6 with DecisionTree

use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.

the class PMMLDecisionTreeTranslator method addTreeNode.

/**
 * Recursively converts a KNIME decision tree node and all of its children
 * into the corresponding PMML element.
 *
 * @param pmmlNode the PMML element that is filled from {@code node}
 * @param node the KNIME decision tree node to convert
 * @param mapper translates split attribute names into derived field names
 * @throws IllegalArgumentException if the type of the parent node is not
 *         supported, or if the node is in neither partition of a binary
 *         nominal split
 */
private static void addTreeNode(final NodeDocument.Node pmmlNode, final DecisionTreeNode node, final DerivedFieldMapper mapper) {
    pmmlNode.setId(String.valueOf(node.getOwnIndex()));
    pmmlNode.setScore(node.getMajorityClass().toString());
    // the record count is only written when it is positive
    if (node.getEntireClassCount() > 0) {
        pmmlNode.setRecordCount(node.getEntireClassCount());
    }
    if (node instanceof DecisionTreeNodeSplitPMML) {
        int defaultChild = ((DecisionTreeNodeSplitPMML) node).getDefaultChildIndex();
        // the default child is only written when its index is non-negative
        if (defaultChild > -1) {
            pmmlNode.setDefaultChild(String.valueOf(defaultChild));
        }
    }
    // the predicate of this node is derived from the split held by its parent
    DecisionTreeNode parent = node.getParent();
    if (parent == null) {
        // When the parent is null, it is the root Node.
        // For root node, the predicate is always True.
        pmmlNode.addNewTrue();
    } else if (parent instanceof DecisionTreeNodeSplitContinuous) {
        // SimplePredicate case
        DecisionTreeNodeSplitContinuous splitNode = (DecisionTreeNodeSplitContinuous) parent;
        // look the child position up once instead of once per comparison
        int nodeIndex = splitNode.getIndex(node);
        if (nodeIndex == 0) {
            SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
            pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
            pmmlSimplePredicate.setOperator(Operator.LESS_OR_EQUAL);
            pmmlSimplePredicate.setValue(String.valueOf(splitNode.getThreshold()));
        } else if (nodeIndex == 1) {
            // the second child gets a True predicate instead of a threshold comparison
            pmmlNode.addNewTrue();
        }
        // any other index: no predicate is written
    } else if (parent instanceof DecisionTreeNodeSplitNominalBinary) {
        // SimpleSetPredicate case
        DecisionTreeNodeSplitNominalBinary splitNode = (DecisionTreeNodeSplitNominalBinary) parent;
        SimpleSetPredicate pmmlSimpleSetPredicate = pmmlNode.addNewSimpleSetPredicate();
        pmmlSimpleSetPredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimpleSetPredicate.setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN);
        ArrayType pmmlArray = pmmlSimpleSetPredicate.addNewArray();
        pmmlArray.setType(ArrayType.Type.STRING);
        DataCell[] splitValues = splitNode.getSplitValues();
        // pick the value indices of the partition this node belongs to
        int partition = splitNode.getIndex(node);
        final List<Integer> indices;
        if (partition == SplitNominalBinary.LEFT_PARTITION) {
            indices = splitNode.getLeftChildIndices();
        } else if (partition == SplitNominalBinary.RIGHT_PARTITION) {
            indices = splitNode.getRightChildIndices();
        } else {
            throw new IllegalArgumentException("Split node is neither " + "contained in the right nor in the left partition.");
        }
        // NOTE(review): values are joined with single spaces and are not
        // quoted; values that contain whitespace may need quoting to form a
        // valid PMML array - confirm against the PMML Array specification.
        StringBuilder classSet = new StringBuilder();
        for (Integer i : indices) {
            if (classSet.length() > 0) {
                classSet.append(" ");
            }
            classSet.append(splitValues[i].toString());
        }
        pmmlArray.setN(BigInteger.valueOf(indices.size()));
        XmlCursor xmlCursor = pmmlArray.newCursor();
        try {
            xmlCursor.setTextValue(classSet.toString());
        } finally {
            // release the cursor even if writing the text value fails
            xmlCursor.dispose();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominal) {
        // one child per nominal value: equality predicate on that value
        DecisionTreeNodeSplitNominal splitNode = (DecisionTreeNodeSplitNominal) parent;
        SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimplePredicate.setOperator(Operator.EQUAL);
        int nodeIndex = parent.getIndex(node);
        pmmlSimplePredicate.setValue(String.valueOf(splitNode.getSplitValues()[nodeIndex].toString()));
    } else if (parent instanceof DecisionTreeNodeSplitPMML) {
        DecisionTreeNodeSplitPMML splitNode = (DecisionTreeNodeSplitPMML) parent;
        int nodeIndex = parent.getIndex(node);
        // get the PMML predicate of the current node from its parent
        PMMLPredicate predicate = splitNode.getSplitPred()[nodeIndex];
        if (predicate instanceof PMMLCompoundPredicate) {
            // surrogates as used in GBT
            exportCompoundPredicate(pmmlNode, (PMMLCompoundPredicate) predicate, mapper);
        } else {
            // note: this rewrites the split attribute on the predicate object itself
            predicate.setSplitAttribute(mapper.getDerivedFieldName(predicate.getSplitAttribute()));
            // delegate the writing to the predicate translator
            PMMLPredicateTranslator.exportTo(predicate, pmmlNode);
        }
    } else {
        throw new IllegalArgumentException("Node Type " + parent.getClass() + " is not supported!");
    }
    // adding score distribution (class counts)
    for (Entry<DataCell, Double> entry : node.getClassCounts().entrySet()) {
        ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(entry.getKey().toString());
        pmmlScoreDist.setRecordCount(entry.getValue());
    }
    // adding children
    if (!(node instanceof DecisionTreeNodeLeaf)) {
        for (int i = 0; i < node.getChildCount(); i++) {
            addTreeNode(pmmlNode.addNewNode(), node.getChildAt(i), mapper);
        }
    }
}
Also used : DecisionTreeNodeSplitNominal(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitNominal) ArrayType(org.dmg.pmml.ArrayType) Entry(java.util.Map.Entry) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) SimplePredicate(org.dmg.pmml.SimplePredicateDocument.SimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicateDocument.SimpleSetPredicate) XmlCursor(org.apache.xmlbeans.XmlCursor) BigInteger(java.math.BigInteger) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) DecisionTreeNodeSplitNominalBinary(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitNominalBinary) DataCell(org.knime.core.data.DataCell) DecisionTreeNodeSplitContinuous(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitContinuous) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 7 with DecisionTree

use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.

the class DecTreeToImageView method modelChanged.

/**
 * {@inheritDoc}
 *
 * Pushes the color column and the root node of the node model's decision
 * tree to the graph. If the model has no decision tree, both are set to
 * {@code null}. Does nothing when there is no node model.
 */
@Override
protected void modelChanged() {
    final DecTreeToImageNodeModel model = this.getNodeModel();
    if (model == null) {
        return;
    }
    final DecisionTree dt = model.getDecisionTree();
    if (dt != null) {
        // use the null-checked local rather than fetching the tree a second time
        m_graph.setColorColumn(dt.getColorColumn());
        m_graph.setRootNode(dt.getRootNode());
    } else {
        m_graph.setColorColumn(null);
        m_graph.setRootNode(null);
    }
}
Also used : DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree)

Example 8 with DecisionTree

use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.

the class DecisionTreeLearnerNodeModel2 method execute.

/**
 * Start of decision tree induction. Builds the tree from the input table,
 * prunes it, adds hilite and color information and wraps the result in a
 * PMML port object.
 *
 * @param data the input port objects; the training table is read from index
 *         {@code DATA_INPORT} and, if the optional PMML input is enabled,
 *         a PMML port object is read from index 1
 * @param exec the execution context for this run
 * @return a one-element array holding the PMML port object that carries the
 *         learned decision tree
 * @throws Exception any type of exception, e.g. for cancellation,
 *         invalid input,...; in particular an
 *         {@link IllegalArgumentException} if the table has fewer than 2
 *         rows or the class column contains only missing values
 * @see NodeModel#execute(BufferedDataTable[],ExecutionContext)
 */
@Override
protected PortObject[] execute(final PortObject[] data, final ExecutionContext exec) throws Exception {
    // holds the warning message displayed after execution
    m_warningMessageSb = new StringBuilder();
    ParallelProcessing parallelProcessing = new ParallelProcessing(m_parallelProcessing.getIntValue());
    if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Number available threads: " + parallelProcessing.getMaxNumberThreads() + " used threads: " + parallelProcessing.getCurrentThreadsInUse());
    }
    exec.setProgress("Preparing...");
    // check input data
    // NOTE(review): only checked via assert, i.e. not at all when assertions are disabled
    assert (data != null && data[DATA_INPORT] != null);
    BufferedDataTable inData = (BufferedDataTable) data[DATA_INPORT];
    // get column with color information (the first column that has a color handler)
    String colorColumn = null;
    for (DataColumnSpec s : inData.getDataTableSpec()) {
        if (s.getColorHandler() != null) {
            colorColumn = s.getName();
            break;
        }
    }
    // the data table must have at least 2 records
    if (inData.size() <= 1) {
        throw new IllegalArgumentException("Input data table must have at least 2 records!");
    }
    // get class column index
    int classColumnIndex = inData.getDataTableSpec().findColumnIndex(m_classifyColumn.getStringValue());
    // NOTE(review): a missing class column is only detected when assertions are enabled
    assert classColumnIndex > -1;
    // create initial In-Memory table
    exec.setProgress("Create initial In-Memory table...");
    InMemoryTableCreator tableCreator = new InMemoryTableCreator(inData, classColumnIndex, m_minNumberRecordsPerNode.getIntValue(), m_skipColumns.getBooleanValue());
    InMemoryTable initialTable = tableCreator.createInMemoryTable(exec.createSubExecutionContext(0.05));
    int removedRows = tableCreator.getRemovedRowsDueToMissingClassValue();
    // every row was removed, so there is nothing left to learn from
    if (removedRows == inData.size()) {
        throw new IllegalArgumentException("Class column contains only " + "missing values");
    }
    if (removedRows > 0) {
        m_warningMessageSb.append(removedRows);
        m_warningMessageSb.append(" rows removed due to missing class value;");
    }
    // the overall row count (sum of weights) is used to report progress
    m_alloverRowCount = initialTable.getSumOfWeights();
    // set the finishing counter
    // this counter will always be incremented when a leaf node is
    // created, as this determines the recursion end and can thus
    // be used for progress indication
    m_finishedCounter = new AtomicDouble(0);
    // get the number of attributes
    m_numberAttributes = initialTable.getNumAttributes();
    // create the quality measure: Gini if selected, gain ratio otherwise
    final SplitQualityMeasure splitQualityMeasure;
    if (m_splitQualityMeasureType.getStringValue().equals(SPLIT_QUALITY_GINI)) {
        splitQualityMeasure = new SplitQualityGini();
    } else {
        splitQualityMeasure = new SplitQualityGainRatio();
    }
    // build the tree
    // before this set the node counter to 0
    m_counter.set(0);
    exec.setMessage("Building tree...");
    final int firstSplitColIdx = initialTable.getAttributeIndex(m_firstSplitCol.getStringValue());
    DecisionTreeNode root = null;
    root = buildTree(initialTable, exec, 0, splitQualityMeasure, parallelProcessing, firstSplitColIdx);
    boolean isBinaryNominal = m_binaryNominalSplitMode.getBooleanValue();
    boolean isFilterInvalidAttributeValues = m_filterNominalValuesFromParent.getBooleanValue();
    if (isBinaryNominal && isFilterInvalidAttributeValues) {
        // traverse tree nodes and remove from the children the attribute
        // values that were filtered out further up in the tree. "Bug" 3124
        root.filterIllegalAttributes(Collections.<String, Set<String>>emptyMap());
    }
    // the decision tree model saved as PMML at the out-port
    DecisionTree decisionTree = new DecisionTree(root, m_classifyColumn.getStringValue(), /* the missing value strategy has to be set explicitly, as the
                    default in PMML is none, which means rows with missing
                    values are not classified. */
    PMMLMissingValueStrategy.get(m_missingValues.getStringValue()), PMMLNoTrueChildStrategy.get(m_noTrueChild.getStringValue()));
    decisionTree.setColorColumn(colorColumn);
    // prune the tree
    exec.setMessage("Prune tree with " + m_pruningMethod.getStringValue() + "...");
    pruneTree(decisionTree);
    // add highlight patterns and color information
    exec.setMessage("Adding hilite and color info to tree...");
    addHiliteAndColorInfo(inData, decisionTree);
    LOGGER.info("Decision tree consisting of " + decisionTree.getNumberNodes() + " nodes created with pruning method " + m_pruningMethod.getStringValue());
    // set the warning message if available
    if (m_warningMessageSb.length() > 0) {
        setWarningMessage(m_warningMessageSb.toString());
    }
    // reset the number of available threads
    parallelProcessing.reset();
    parallelProcessing = null;
    // no data table is produced; the only output is the PMML model
    exec.setMessage("Creating PMML decision tree model...");
    // handle the optional PMML input
    PMMLPortObject inPMMLPort = m_pmmlInEnabled ? (PMMLPortObject) data[1] : null;
    DataTableSpec inSpec = inData.getSpec();
    PMMLPortObjectSpec outPortSpec = createPMMLPortObjectSpec(inPMMLPort == null ? null : inPMMLPort.getSpec(), inSpec);
    PMMLPortObject outPMMLPort = new PMMLPortObject(outPortSpec, inPMMLPort, inData.getSpec());
    outPMMLPort.addModelTranslater(new PMMLDecisionTreeTranslator(decisionTree));
    // keep the learned tree in the node model
    m_decisionTree = decisionTree;
    return new PortObject[] { outPMMLPort };
}
Also used : DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) DataTableSpec(org.knime.core.data.DataTableSpec) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PMMLDecisionTreeTranslator(org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) DataColumnSpec(org.knime.core.data.DataColumnSpec) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) BufferedDataTable(org.knime.core.node.BufferedDataTable) PortObject(org.knime.core.node.port.PortObject) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 9 with DecisionTree

use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.

the class Pruner method trainingErrorPruning.

// private static double estimatedError(final double all, final double error,
// final double zValue) {
// double f = error / all;
// double z = zValue;
// double N = all;
// 
// double estimatedError =
// (f + z * z / (2 * N) + z
// * Math.sqrt(f / N - f * f / N + z * z / (4 * N * N)))
// / (1 + z * z / N);
// 
// // return the weighted value
// return estimatedError * all;
// }
// 
/**
 * Applies training error pruning to a {@link DecisionTree}: whenever the
 * training error of a subtree equals the training error of the node it
 * hangs from, that subtree is pruned, because it gains nothing.
 *
 * @param decTree the decision tree to prune
 */
public static void trainingErrorPruning(final DecisionTree decTree) {
    // hand the root of the tree to the recursive pruning routine
    trainingErrorPruningRecurse(decTree.getRootNode());
}
Also used : DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 10 with DecisionTree

use of org.knime.base.node.mine.decisiontree2.model.DecisionTree in project knime-core by knime.

the class Pruner method mdlPruningRecurse.

// /**
// * The general idea is to recursively prune the children and then compare
// * the potential leaf estimated error with the actual estimated error
// * including the length of the children.
// *
// * @param node the node to prune
// * @param zValue the z value according to which the error is estimated
// *            calculated from the confidence value
// *
// * @return the resulting description length after pruning; this value is
// *         used in higher levels of the recursion, i.e. for the parent node
// */
// private static PruningResult estimatedErrorPruningRecurse(
// final DecisionTreeNode node, final double zValue) {
// 
// // if this is a leaf, just return the estimated error
// if (node.isLeaf()) {
// double error = node.getEntireClassCount() - node.getOwnClassCount();
// double estimatedError =
// estimatedError(node.getEntireClassCount(), error, zValue);
// 
// return new PruningResult(estimatedError, node);
// }
// 
// // holds the estimated errors of the children
// double[] childDescriptionLength = new double[node.getChildCount()];
// DecisionTreeNodeSplit splitNode = (DecisionTreeNodeSplit)node;
// // prune all children
// DecisionTreeNode[] children = splitNode.getChildren();
// int count = 0;
// for (DecisionTreeNode childNode : children) {
// 
// PruningResult result =
// estimatedErrorPruningRecurse(childNode, zValue);
// childDescriptionLength[count] = result.getQualityValue();
// 
// // replace the child with the one from the result (could of course
// // be the same)
// splitNode.replaceChild(childNode, result.getNode());
// 
// count++;
// }
// 
// // calculate the estimated error if this would be a leaf
// double error = node.getEntireClassCount() - node.getOwnClassCount();
// double leafEstimatedError =
// estimatedError(node.getEntireClassCount(), error, zValue);
// 
// // calculate the current estimated error (sum of estimated errors of the
// // children)
// double currentEstimatedError = 0;
// for (double childDescLength : childDescriptionLength) {
// currentEstimatedError += childDescLength;
// }
// 
// // define the return node
// DecisionTreeNode returnNode = node;
// double returnEstimatedError = currentEstimatedError;
// 
// // if the possible leaf costs are smaller, replace this node
// // with a leaf (tolerance is 0.1)
// if (leafEstimatedError <= currentEstimatedError + 0.1) {
// DecisionTreeNodeLeaf newLeaf =
// new DecisionTreeNodeLeaf(node.getOwnIndex(), node
// .getMajorityClass(), node.getClassCounts());
// newLeaf.setParent((DecisionTreeNode)node.getParent());
// newLeaf.setPrefix(node.getPrefix());
// returnNode = newLeaf;
// returnEstimatedError = leafEstimatedError;
// }
// 
// return new PruningResult(returnEstimatedError, returnNode);
// }
// 
// /**
// * Prunes a {@link DecisionTree} according to the estimated error pruning
// * (Quinlan 87).
// *
// * @param decTree the decision tree to prune
// * @param confidence the confidence value according to which the error is
// *            estimated
// */
// public static void estimatedErrorPruning(final DecisionTree decTree,
// final double confidence) {
// 
// // traverse the tree depth first (in-fix)
// DecisionTreeNode root = decTree.getRootNode();
// // double zValue = xnormi(1 - confidence);
// estimatedErrorPruningRecurse(root, zValue);
// }
/**
 * Prunes the subtree below the given node bottom-up: the children are
 * pruned first, then the description length of the node as a leaf is
 * compared with its description length as a split including the children.
 * If the leaf is not more expensive, the node is replaced by a leaf.
 *
 * @param node the node to prune
 *
 * @return the node to use after pruning together with its description
 *         length; the parent uses this value in its own comparison
 */
private static PruningResult mdlPruningRecurse(final DecisionTreeNode node) {
    // a leaf costs its misclassified records plus one for the node itself
    if (node.isLeaf()) {
        final double misclassified = node.getEntireClassCount() - node.getOwnClassCount();
        return new PruningResult(misclassified + 1.0, node);
    }
    // description lengths of the (pruned) children, in child order
    final double[] lengths = new double[node.getChildCount()];
    final DecisionTreeNodeSplit split = (DecisionTreeNodeSplit) node;
    final DecisionTreeNode[] kids = split.getChildren();
    for (int i = 0; i < kids.length; i++) {
        final DecisionTreeNode kid = kids[i];
        final PruningResult pruned = mdlPruningRecurse(kid);
        lengths[i] = pruned.getQualityValue();
        // swap in the pruned child (may be the very same node)
        split.replaceChild(kid, pruned.getNode());
    }
    // cost of this node if it were turned into a leaf
    final double costAsLeaf = node.getEntireClassCount() - node.getOwnClassCount() + 1.0;
    // cost of this node as a split: the node itself, log2 of the child
    // count, and the description lengths of all children
    double costAsSplit = 1.0 + Math.log(node.getChildCount()) / Math.log(2);
    for (int i = 0; i < lengths.length; i++) {
        costAsSplit += lengths[i];
    }
    // replace the split by a leaf when the leaf is not more expensive
    if (costAsLeaf <= costAsSplit) {
        final DecisionTreeNodeLeaf leaf = new DecisionTreeNodeLeaf(node.getOwnIndex(), node.getMajorityClass(), node.getClassCounts());
        leaf.setParent(node.getParent());
        leaf.setPrefix(node.getPrefix());
        return new PruningResult(costAsLeaf, leaf);
    }
    return new PruningResult(costAsSplit, node);
}
Also used : DecisionTreeNodeSplit(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplit) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Aggregations

DecisionTree (org.knime.base.node.mine.decisiontree2.model.DecisionTree)24 DecisionTreeNode (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)11 IOException (java.io.IOException)9 CanceledExecutionException (org.knime.core.node.CanceledExecutionException)9 DataTableSpec (org.knime.core.data.DataTableSpec)8 InvalidSettingsException (org.knime.core.node.InvalidSettingsException)7 DataRow (org.knime.core.data.DataRow)6 BufferedInputStream (java.io.BufferedInputStream)5 File (java.io.File)5 FileInputStream (java.io.FileInputStream)5 GZIPInputStream (java.util.zip.GZIPInputStream)5 PMMLDecisionTreeTranslator (org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator)5 PMMLPortObject (org.knime.core.node.port.pmml.PMMLPortObject)5 DefaultTreeModel (javax.swing.tree.DefaultTreeModel)4 DecisionTreeNodeRenderer (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeRenderer)4 DataColumnSpec (org.knime.core.data.DataColumnSpec)4 BufferedDataTable (org.knime.core.node.BufferedDataTable)4 ModelContentRO (org.knime.core.node.ModelContentRO)4 PMMLPortObjectSpec (org.knime.core.node.port.pmml.PMMLPortObjectSpec)4 DecisionTreeNodeLeaf (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf)3