Search in sources :

Example 11 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class TreeEnsembleLearnerNodeView method getSubtree.

/**
 * Collects the given node and all of its descendants into a flat list.
 *
 * The pending nodes are kept on a stack (last in, first out), so the result
 * is in depth-first order; because the children of a node are pushed in
 * ascending index order, the child with the highest index is visited first.
 *
 * @param node the root of the subtree to collect
 * @return the root followed by all of its descendants
 */
private List<DecisionTreeNode> getSubtree(final DecisionTreeNode node) {
    final List<DecisionTreeNode> collected = new ArrayList<DecisionTreeNode>();
    final LinkedList<DecisionTreeNode> pending = new LinkedList<DecisionTreeNode>();
    pending.push(node);
    // Traverse the tree depth first using the list as a stack
    while (!pending.isEmpty()) {
        final DecisionTreeNode current = pending.pop();
        collected.add(current);
        for (int childIdx = 0; childIdx < current.getChildCount(); childIdx++) {
            pending.push(current.getChildAt(childIdx));
        }
    }
    return collected;
}
Also used : ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 12 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class DecTreeNodeView2 method changeSelectedHiLite.

// /////////////////////////////
// routines for HiLite Support
// /////////////////////////////
/**
 * Hilites or unhilites all items that are covered by the branches currently
 * selected in the tree. One hilite event is fired per selected path, using
 * the covered pattern of the last node on that path.
 *
 * @param state if true hilite, otherwise unhilite selection
 */
private void changeSelectedHiLite(final boolean state) {
    final TreePath[] selection = m_jTree.getSelectionPaths();
    if (selection == null) {
        // nothing selected, so there is nothing to (un)hilite
        return;
    }
    for (final TreePath path : selection) {
        assert (path != null);
        if (path == null) {
            // with assertions disabled, a null entry ends the processing
            return;
        }
        final Object leaf = path.getLastPathComponent();
        assert (leaf != null);
        assert (leaf instanceof DecisionTreeNode);
        final Set<RowKey> covered = ((DecisionTreeNode) leaf).coveredPattern();
        if (!state) {
            m_hiLiteHdl.fireUnHiLiteEvent(covered);
        } else {
            m_hiLiteHdl.fireHiLiteEvent(covered);
        }
    }
}
Also used : TreePath(javax.swing.tree.TreePath) RowKey(org.knime.core.data.RowKey) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 13 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class PMMLDecisionTreeTranslator method addTreeNode.

/**
 * A recursive function which converts each KNIME Tree node to a
 * corresponding PMML element.
 *
 * For every node this writes the id, the score (majority class), an optional
 * record count, an optional default child, the predicate that leads from the
 * parent to this node, and one score distribution entry per class count. It
 * then recurses into the children of all nodes that are not leaves.
 *
 * @param pmmlNode the desired PMML element
 * @param node A KNIME DecisionTree node
 * @param mapper used via {@code getDerivedFieldName(...)} to translate split
 *            attribute names before they are written as PMML field names
 * @throws IllegalArgumentException if the parent is of an unsupported split
 *             node type, or if the node is in neither partition of a binary
 *             nominal split
 */
private static void addTreeNode(final NodeDocument.Node pmmlNode, final DecisionTreeNode node, final DerivedFieldMapper mapper) {
    pmmlNode.setId(String.valueOf(node.getOwnIndex()));
    pmmlNode.setScore(node.getMajorityClass().toString());
    // The record count is only written when it is positive.
    // NOTE(review): the original comment here was truncated to "read in and
    // then exported again"; presumably it explained that the count may be
    // missing for trees that were read from PMML -- confirm.
    if (node.getEntireClassCount() > 0) {
        pmmlNode.setRecordCount(node.getEntireClassCount());
    }
    // PMML split nodes may carry a default child; it is only written when the
    // index is greater than -1
    if (node instanceof DecisionTreeNodeSplitPMML) {
        int defaultChild = ((DecisionTreeNodeSplitPMML) node).getDefaultChildIndex();
        if (defaultChild > -1) {
            pmmlNode.setDefaultChild(String.valueOf(defaultChild));
        }
    }
    // adding score and stuff from parent
    // (the predicate of a node is derived from the type of its parent split)
    DecisionTreeNode parent = node.getParent();
    if (parent == null) {
        // When the parent is null, it is the root Node.
        // For root node, the predicate is always True.
        pmmlNode.addNewTrue();
    } else if (parent instanceof DecisionTreeNodeSplitContinuous) {
        // SimplePredicate case
        // child 0 gets "field <= threshold", child 1 gets a True predicate;
        // for any other child index no predicate is written
        DecisionTreeNodeSplitContinuous splitNode = (DecisionTreeNodeSplitContinuous) parent;
        if (splitNode.getIndex(node) == 0) {
            SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
            pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
            pmmlSimplePredicate.setOperator(Operator.LESS_OR_EQUAL);
            pmmlSimplePredicate.setValue(String.valueOf(splitNode.getThreshold()));
        } else if (splitNode.getIndex(node) == 1) {
            pmmlNode.addNewTrue();
        }
    } else if (parent instanceof DecisionTreeNodeSplitNominalBinary) {
        // SimpleSetPredicate case
        // the predicate is "field isIn {values}", where the values are those
        // of the partition (left or right) this node belongs to
        DecisionTreeNodeSplitNominalBinary splitNode = (DecisionTreeNodeSplitNominalBinary) parent;
        SimpleSetPredicate pmmlSimpleSetPredicate = pmmlNode.addNewSimpleSetPredicate();
        pmmlSimpleSetPredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimpleSetPredicate.setBooleanOperator(SimpleSetPredicate.BooleanOperator.IS_IN);
        ArrayType pmmlArray = pmmlSimpleSetPredicate.addNewArray();
        pmmlArray.setType(ArrayType.Type.STRING);
        DataCell[] splitValues = splitNode.getSplitValues();
        List<Integer> indices = null;
        if (splitNode.getIndex(node) == SplitNominalBinary.LEFT_PARTITION) {
            indices = splitNode.getLeftChildIndices();
        } else if (splitNode.getIndex(node) == SplitNominalBinary.RIGHT_PARTITION) {
            indices = splitNode.getRightChildIndices();
        } else {
            throw new IllegalArgumentException("Split node is neither " + "contained in the right nor in the left partition.");
        }
        // build the array content: the selected split values joined by a
        // single space (values are written as-is, without quoting)
        StringBuilder classSet = new StringBuilder();
        for (Integer i : indices) {
            if (classSet.length() > 0) {
                classSet.append(" ");
            }
            classSet.append(splitValues[i].toString());
        }
        pmmlArray.setN(BigInteger.valueOf(indices.size()));
        // the array text is set through a cursor, which is disposed afterwards
        XmlCursor xmlCursor = pmmlArray.newCursor();
        xmlCursor.setTextValue(classSet.toString());
        xmlCursor.dispose();
    } else if (parent instanceof DecisionTreeNodeSplitNominal) {
        // multi-way nominal split: "field == value", where the value is the
        // split value at this node's index among the parent's children
        DecisionTreeNodeSplitNominal splitNode = (DecisionTreeNodeSplitNominal) parent;
        SimplePredicate pmmlSimplePredicate = pmmlNode.addNewSimplePredicate();
        pmmlSimplePredicate.setField(mapper.getDerivedFieldName(splitNode.getSplitAttr()));
        pmmlSimplePredicate.setOperator(Operator.EQUAL);
        int nodeIndex = parent.getIndex(node);
        pmmlSimplePredicate.setValue(String.valueOf(splitNode.getSplitValues()[nodeIndex].toString()));
    } else if (parent instanceof DecisionTreeNodeSplitPMML) {
        DecisionTreeNodeSplitPMML splitNode = (DecisionTreeNodeSplitPMML) parent;
        int nodeIndex = parent.getIndex(node);
        // get the PMML predicate of the current node from its parent
        PMMLPredicate predicate = splitNode.getSplitPred()[nodeIndex];
        if (predicate instanceof PMMLCompoundPredicate) {
            // surrogates as used in GBT
            exportCompoundPredicate(pmmlNode, (PMMLCompoundPredicate) predicate, mapper);
        } else {
            // NOTE(review): this overwrites the split attribute on the
            // predicate object returned by getSplitPred(); if that is the
            // model's own instance, exporting modifies the model -- confirm
            predicate.setSplitAttribute(mapper.getDerivedFieldName(predicate.getSplitAttribute()));
            // delegate the writing to the predicate translator
            PMMLPredicateTranslator.exportTo(predicate, pmmlNode);
        }
    } else {
        throw new IllegalArgumentException("Node Type " + parent.getClass() + " is not supported!");
    }
    // adding score distribution (class counts)
    // one ScoreDistribution element per entry of the node's class count map
    Set<Entry<DataCell, Double>> classCounts = node.getClassCounts().entrySet();
    Iterator<Entry<DataCell, Double>> iterator = classCounts.iterator();
    while (iterator.hasNext()) {
        Entry<DataCell, Double> entry = iterator.next();
        DataCell cell = entry.getKey();
        Double freq = entry.getValue();
        ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
        pmmlScoreDist.setValue(cell.toString());
        pmmlScoreDist.setRecordCount(freq);
    }
    // adding children
    // (recursion: each child gets a freshly added nested PMML node)
    if (!(node instanceof DecisionTreeNodeLeaf)) {
        for (int i = 0; i < node.getChildCount(); i++) {
            addTreeNode(pmmlNode.addNewNode(), node.getChildAt(i), mapper);
        }
    }
}
Also used : DecisionTreeNodeSplitNominal(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitNominal) ArrayType(org.dmg.pmml.ArrayType) Entry(java.util.Map.Entry) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) SimplePredicate(org.dmg.pmml.SimplePredicateDocument.SimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicateDocument.SimpleSetPredicate) XmlCursor(org.apache.xmlbeans.XmlCursor) BigInteger(java.math.BigInteger) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) DecisionTreeNodeSplitNominalBinary(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitNominalBinary) DataCell(org.knime.core.data.DataCell) DecisionTreeNodeSplitContinuous(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitContinuous) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 14 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class DecisionTreeLearnerNodeModel2 method execute.

/**
 * Start of decision tree induction.
 *
 * Builds an in-memory table from the input data, induces the tree, optionally
 * filters illegal attribute values, prunes the tree, adds hilite and color
 * information, and wraps the result in a PMML port object.
 *
 * @param exec the execution context for this run
 * @param data the input port objects; the entry at {@code DATA_INPORT} is the
 *            data table to build the decision tree from and, if the PMML
 *            input is enabled, the entry at index 1 is the incoming PMML port
 * @return a one-element array holding the PMML port object that carries the
 *         learned decision tree model
 * @throws Exception any type of exception, e.g. for cancellation,
 *         invalid input,...
 * @see NodeModel#execute(BufferedDataTable[],ExecutionContext)
 */
@Override
protected PortObject[] execute(final PortObject[] data, final ExecutionContext exec) throws Exception {
    // holds the warning message displayed after execution
    m_warningMessageSb = new StringBuilder();
    ParallelProcessing parallelProcessing = new ParallelProcessing(m_parallelProcessing.getIntValue());
    if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Number available threads: " + parallelProcessing.getMaxNumberThreads() + " used threads: " + parallelProcessing.getCurrentThreadsInUse());
    }
    exec.setProgress("Preparing...");
    // check input data
    assert (data != null && data[DATA_INPORT] != null);
    BufferedDataTable inData = (BufferedDataTable) data[DATA_INPORT];
    // get column with color information
    // (the first column that has a color handler wins)
    String colorColumn = null;
    for (DataColumnSpec s : inData.getDataTableSpec()) {
        if (s.getColorHandler() != null) {
            colorColumn = s.getName();
            break;
        }
    }
    // the data table must have at least 2 records
    if (inData.size() <= 1) {
        throw new IllegalArgumentException("Input data table must have at least 2 records!");
    }
    // get class column index
    int classColumnIndex = inData.getDataTableSpec().findColumnIndex(m_classifyColumn.getStringValue());
    assert classColumnIndex > -1;
    // create initial In-Memory table
    exec.setProgress("Create initial In-Memory table...");
    InMemoryTableCreator tableCreator = new InMemoryTableCreator(inData, classColumnIndex, m_minNumberRecordsPerNode.getIntValue(), m_skipColumns.getBooleanValue());
    InMemoryTable initialTable = tableCreator.createInMemoryTable(exec.createSubExecutionContext(0.05));
    // rows without a class value are dropped by the table creator; fail if
    // nothing is left, otherwise only record a warning
    int removedRows = tableCreator.getRemovedRowsDueToMissingClassValue();
    if (removedRows == inData.size()) {
        throw new IllegalArgumentException("Class column contains only " + "missing values");
    }
    if (removedRows > 0) {
        m_warningMessageSb.append(removedRows);
        m_warningMessageSb.append(" rows removed due to missing class value;");
    }
    // the all over row count is used to report progress
    m_alloverRowCount = initialTable.getSumOfWeights();
    // set the finishing counter
    // this counter will always be incremented when a leaf node is
    // created, as this determines the recursion end and can thus
    // be used for progress indication
    m_finishedCounter = new AtomicDouble(0);
    // get the number of attributes
    m_numberAttributes = initialTable.getNumAttributes();
    // create the quality measure
    // (Gini if configured, gain ratio for every other setting)
    final SplitQualityMeasure splitQualityMeasure;
    if (m_splitQualityMeasureType.getStringValue().equals(SPLIT_QUALITY_GINI)) {
        splitQualityMeasure = new SplitQualityGini();
    } else {
        splitQualityMeasure = new SplitQualityGainRatio();
    }
    // build the tree
    // before this set the node counter to 0
    m_counter.set(0);
    exec.setMessage("Building tree...");
    final int firstSplitColIdx = initialTable.getAttributeIndex(m_firstSplitCol.getStringValue());
    DecisionTreeNode root = null;
    root = buildTree(initialTable, exec, 0, splitQualityMeasure, parallelProcessing, firstSplitColIdx);
    boolean isBinaryNominal = m_binaryNominalSplitMode.getBooleanValue();
    boolean isFilterInvalidAttributeValues = m_filterNominalValuesFromParent.getBooleanValue();
    if (isBinaryNominal && isFilterInvalidAttributeValues) {
        // traverse tree nodes and remove from the children the attribute
        // values that were filtered out further up in the tree. "Bug" 3124
        root.filterIllegalAttributes(Collections.<String, Set<String>>emptyMap());
    }
    // the decision tree model saved as PMML at the second out-port
    DecisionTree decisionTree = new DecisionTree(root, m_classifyColumn.getStringValue(), /* strategy has to be set explicitly as the default in PMML is
                    none, which means rows with missing values are not
                    classified. */
    PMMLMissingValueStrategy.get(m_missingValues.getStringValue()), PMMLNoTrueChildStrategy.get(m_noTrueChild.getStringValue()));
    decisionTree.setColorColumn(colorColumn);
    // prune the tree
    exec.setMessage("Prune tree with " + m_pruningMethod.getStringValue() + "...");
    pruneTree(decisionTree);
    // add highlight patterns and color information
    exec.setMessage("Adding hilite and color info to tree...");
    addHiliteAndColorInfo(inData, decisionTree);
    LOGGER.info("Decision tree consisting of " + decisionTree.getNumberNodes() + " nodes created with pruning method " + m_pruningMethod.getStringValue());
    // set the warning message if available
    if (m_warningMessageSb.length() > 0) {
        setWarningMessage(m_warningMessageSb.toString());
    }
    // reset the number available threads
    // NOTE(review): this is skipped when an exception is thrown earlier in
    // the method, since it is not in a finally block -- confirm that is fine
    parallelProcessing.reset();
    parallelProcessing = null;
    // no data out table is created; the only output is the PMML model port
    exec.setMessage("Creating PMML decision tree model...");
    // handle the optional PMML input
    PMMLPortObject inPMMLPort = m_pmmlInEnabled ? (PMMLPortObject) data[1] : null;
    DataTableSpec inSpec = inData.getSpec();
    PMMLPortObjectSpec outPortSpec = createPMMLPortObjectSpec(inPMMLPort == null ? null : inPMMLPort.getSpec(), inSpec);
    PMMLPortObject outPMMLPort = new PMMLPortObject(outPortSpec, inPMMLPort, inData.getSpec());
    outPMMLPort.addModelTranslater(new PMMLDecisionTreeTranslator(decisionTree));
    // keep a reference to the final tree in the node model's field
    m_decisionTree = decisionTree;
    return new PortObject[] { outPMMLPort };
}
Also used : DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) DataTableSpec(org.knime.core.data.DataTableSpec) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PMMLDecisionTreeTranslator(org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) DataColumnSpec(org.knime.core.data.DataColumnSpec) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) BufferedDataTable(org.knime.core.node.BufferedDataTable) PortObject(org.knime.core.node.port.PortObject) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Example 15 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class Pruner method trainingErrorPruning.

// private static double estimatedError(final double all, final double error,
// final double zValue) {
// double f = error / all;
// double z = zValue;
// double N = all;
// 
// double estimatedError =
// (f + z * z / (2 * N) + z
// * Math.sqrt(f / N - f * f / N + z * z / (4 * N * N)))
// / (1 + z * z / N);
// 
// // return the weighted value
// return estimatedError * all;
// }
// 
/**
 * Prunes a {@link DecisionTree} based on the training error: a subtree whose
 * training error is the same as that of the node it hangs from gains nothing
 * and is therefore removed.
 *
 * This method is only the entry point; the work is delegated to the
 * recursive helper, which is started at the root node of the tree.
 *
 * @param decTree the decision tree to prune
 */
public static void trainingErrorPruning(final DecisionTree decTree) {
    // start the recursive pruning at the root of the tree
    trainingErrorPruningRecurse(decTree.getRootNode());
}
Also used : DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)

Aggregations

DecisionTreeNode (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)50 RowKey (org.knime.core.data.RowKey)18 HashSet (java.util.HashSet)14 LinkedList (java.util.LinkedList)14 DecisionTreeNodeLeaf (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf)12 ArrayList (java.util.ArrayList)10 DecisionTreeNodeSplitPMML (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML)10 DataCell (org.knime.core.data.DataCell)9 Action (javax.swing.Action)8 JMenu (javax.swing.JMenu)8 CollapseBranchAction (org.knime.base.node.mine.decisiontree2.view.graph.CollapseBranchAction)8 ExpandBranchAction (org.knime.base.node.mine.decisiontree2.view.graph.ExpandBranchAction)8 PMMLPredicate (org.knime.base.node.mine.decisiontree2.PMMLPredicate)7 PMMLDecisionTreeTranslator (org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator)5 DecisionTree (org.knime.base.node.mine.decisiontree2.model.DecisionTree)5 TreePath (javax.swing.tree.TreePath)4 PortObject (org.knime.core.node.port.PortObject)4 PMMLPortObject (org.knime.core.node.port.pmml.PMMLPortObject)4 LinkedHashMap (java.util.LinkedHashMap)3 PMMLSimplePredicate (org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate)3