Example 26 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class DecisionTreeLearnerNodeModel method execute.

/**
 * Start of decision tree induction.
 *
 * @param exec the execution context for this run
 * @param data the input data to build the decision tree from
 * @return a one-element array holding the PMML decision tree model (no
 *         data table is created)
 * @throws Exception any type of exception, e.g. cancellation or invalid
 *         input
 * @see NodeModel#execute(BufferedDataTable[],ExecutionContext)
 */
@Override
protected PortObject[] execute(final PortObject[] data, final ExecutionContext exec) throws Exception {
    // holds the warning message displayed after execution
    StringBuilder warningMessageSb = new StringBuilder();
    ParallelProcessing parallelProcessing = new ParallelProcessing(m_parallelProcessing.getIntValue());
    if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("Number available threads: " + parallelProcessing.getMaxNumberThreads() + " used threads: " + parallelProcessing.getCurrentThreadsInUse());
    }
    exec.setProgress("Preparing...");
    // check input data
    assert (data != null && data[DATA_INPORT] != null);
    BufferedDataTable inData = (BufferedDataTable) data[DATA_INPORT];
    // get column with color information
    String colorColumn = null;
    for (DataColumnSpec s : inData.getDataTableSpec()) {
        if (s.getColorHandler() != null) {
            colorColumn = s.getName();
            break;
        }
    }
    // the data table must have at least 2 records
    if (inData.getRowCount() <= 1) {
        throw new IllegalArgumentException("Input data table must have at least 2 records!");
    }
    // get class column index
    int classColumnIndex = inData.getDataTableSpec().findColumnIndex(m_classifyColumn.getStringValue());
    assert classColumnIndex > -1;
    // create initial In-Memory table
    exec.setProgress("Create initial In-Memory table...");
    InMemoryTableCreator tableCreator = new InMemoryTableCreator(inData, classColumnIndex, m_minNumberRecordsPerNode.getIntValue(), m_skipColumns.getBooleanValue());
    InMemoryTable initialTable = tableCreator.createInMemoryTable(exec.createSubExecutionContext(0.05));
    int removedRows = tableCreator.getRemovedRowsDueToMissingClassValue();
    if (removedRows == inData.getRowCount()) {
        throw new IllegalArgumentException("Class column contains only " + "missing values");
    }
    if (removedRows > 0) {
        warningMessageSb.append(removedRows);
        warningMessageSb.append(" rows removed due to missing class value;");
    }
    // the overall row count is used to report progress
    m_alloverRowCount = initialTable.getSumOfWeights();
    // set the finishing counter
    // this counter is incremented whenever a leaf node is created; since
    // leaves end the recursion, it can be used for progress indication
    m_finishedCounter = new AtomicDouble(0);
    // get the number of attributes
    m_numberAttributes = initialTable.getNumAttributes();
    // create the quality measure
    final SplitQualityMeasure splitQualityMeasure;
    if (m_splitQualityMeasureType.getStringValue().equals(SPLIT_QUALITY_GINI)) {
        splitQualityMeasure = new SplitQualityGini();
    } else {
        splitQualityMeasure = new SplitQualityGainRatio();
    }
    // build the tree
    // before this set the node counter to 0
    m_counter.set(0);
    exec.setMessage("Building tree...");
    DecisionTreeNode root = buildTree(initialTable, exec, 0, splitQualityMeasure, parallelProcessing);
    boolean isBinaryNominal = m_binaryNominalSplitMode.getBooleanValue();
    boolean isFilterInvalidAttributeValues = m_filterNominalValuesFromParent.getBooleanValue();
    if (isBinaryNominal && isFilterInvalidAttributeValues) {
        // traverse tree nodes and remove from the children the attribute
        // values that were filtered out further up in the tree. "Bug" 3124
        root.filterIllegalAttributes(Collections.EMPTY_MAP);
    }
    // the decision tree model saved as PMML at the second out-port
    // The missing value strategy has to be set explicitly because the PMML
    // default is "none", which means rows with missing values are not
    // classified.
    DecisionTree decisionTree = new DecisionTree(root, m_classifyColumn.getStringValue(),
            PMMLMissingValueStrategy.LAST_PREDICTION);
    decisionTree.setColorColumn(colorColumn);
    // prune the tree
    exec.setMessage("Prune tree with " + m_pruningMethod.getStringValue() + "...");
    pruneTree(decisionTree);
    // add highlight patterns and color information
    exec.setMessage("Adding hilite and color info to tree...");
    addHiliteAndColorInfo(inData, decisionTree);
    LOGGER.info("Decision tree consisting of " + decisionTree.getNumberNodes() + " nodes created with pruning method " + m_pruningMethod.getStringValue());
    // set the warning message if available
    if (warningMessageSb.length() > 0) {
        setWarningMessage(warningMessageSb.toString());
    }
    // reset the number of available threads
    parallelProcessing.reset();
    parallelProcessing = null;
    // no data out-table is created; only the PMML model is returned
    exec.setMessage("Creating PMML decision tree model...");
    // handle the optional PMML input
    PMMLPortObject inPMMLPort = (PMMLPortObject) data[1];
    DataTableSpec inSpec = inData.getSpec();
    PMMLPortObjectSpec outPortSpec = createPMMLPortObjectSpec(inPMMLPort == null ? null : inPMMLPort.getSpec(), inSpec);
    PMMLPortObject outPMMLPort = new PMMLPortObject(outPortSpec, inPMMLPort, inData.getSpec());
    outPMMLPort.addModelTranslater(new PMMLDecisionTreeTranslator(decisionTree));
    m_decisionTree = decisionTree;
    return new PortObject[] { outPMMLPort };
}
Also used : DecisionTree(org.knime.base.node.mine.decisiontree2.model.DecisionTree) DataTableSpec(org.knime.core.data.DataTableSpec) PMMLPortObjectSpec(org.knime.core.node.port.pmml.PMMLPortObjectSpec) PMMLDecisionTreeTranslator(org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) DataColumnSpec(org.knime.core.data.DataColumnSpec) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) BufferedDataTable(org.knime.core.node.BufferedDataTable) PortObject(org.knime.core.node.port.PortObject) PMMLPortObject(org.knime.core.node.port.pmml.PMMLPortObject) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)
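
The splitQualityMeasure selected above (SplitQualityGini vs. SplitQualityGainRatio) determines how candidate splits are scored inside buildTree. As a rough, self-contained illustration of the Gini criterion only (the class and method names below are hypothetical and this is not the KNIME SplitQualityGini implementation), the impurity of a class distribution and the weighted impurity after a split could be computed like this:

// Illustrative sketch of the Gini criterion; not part of the KNIME code above.
public final class GiniSketch {

    /** Gini impurity: 1 - sum(p_i^2) over the class frequencies. */
    static double giniImpurity(final double[] classCounts) {
        double total = 0;
        for (double c : classCounts) {
            total += c;
        }
        if (total <= 0) {
            return 0;
        }
        double sumSquared = 0;
        for (double c : classCounts) {
            double p = c / total;
            sumSquared += p * p;
        }
        return 1.0 - sumSquared;
    }

    /** Weighted Gini impurity of the partitions produced by a candidate split. */
    static double giniAfterSplit(final double[][] partitionCounts) {
        double total = 0;
        double weighted = 0;
        for (double[] partition : partitionCounts) {
            double partitionTotal = 0;
            for (double c : partition) {
                partitionTotal += c;
            }
            weighted += partitionTotal * giniImpurity(partition);
            total += partitionTotal;
        }
        return total > 0 ? weighted / total : 0;
    }
}

A lower weighted impurity after the split indicates a better split; the gain ratio alternative instead normalizes the information gain of a split by its intrinsic (split) information.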

Example 27 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class DecisionTreeLearnerNodeModel method buildTree.

/**
 * Recursively induces the decision tree.
 *
 * @param table the {@link InMemoryTable} representing the data for this
 *            node; used to determine the split and then partitioned
 * @param exec the execution context for progress information
 * @param depth the current recursion depth
 * @param splitQualityMeasure the measure used to score candidate splits
 * @param parallelProcessing keeps track of the threads available for
 *            parallel subtree construction
 * @return the root node of the induced (sub)tree
 */
private DecisionTreeNode buildTree(final InMemoryTable table, final ExecutionContext exec, final int depth, final SplitQualityMeasure splitQualityMeasure, final ParallelProcessing parallelProcessing) throws CanceledExecutionException, IllegalAccessException {
    exec.checkCanceled();
    // derive this node's id from the counter
    int nodeId = m_counter.getAndIncrement();
    DataCell majorityClass = table.getMajorityClassAsCell();
    LinkedHashMap<DataCell, Double> frequencies = table.getClassFrequencies();
    // if the distribution allows for a leaf
    if (table.isPureEnough()) {
        // free memory
        table.freeUnderlyingDataRows();
        double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
        exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
        return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
    } else {
        // find the best splits for all attributes
        SplitFinder splitFinder = new SplitFinder(table, splitQualityMeasure, m_averageSplitpoint.getBooleanValue(), m_minNumberRecordsPerNode.getIntValue(), m_binaryNominalSplitMode.getBooleanValue(), m_maxNumNominalsForCompleteComputation.getIntValue());
        // check for enough memory
        checkMemory();
        // get the best split among the best attribute splits
        Split split = splitFinder.getSplit();
        // if no best split could be evaluated, create a leaf node
        if (split == null || !split.isValidSplit()) {
            table.freeUnderlyingDataRows();
            double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
            exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
            return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
        }
        // partition the attribute lists according to this split
        Partitioner partitioner = new Partitioner(table, split, m_minNumberRecordsPerNode.getIntValue());
        if (!partitioner.couldBeUsefulPartitioned()) {
            table.freeUnderlyingDataRows();
            double value = m_finishedCounter.incrementAndGet(table.getSumOfWeights());
            exec.setProgress(value / m_alloverRowCount, "Created node with id " + nodeId + " at level " + depth);
            return new DecisionTreeNodeLeaf(nodeId, majorityClass, frequencies);
        }
        // get the just created partitions
        InMemoryTable[] partitionTables = partitioner.getPartitionTables();
        // recursively build the child nodes
        DecisionTreeNode[] children = new DecisionTreeNode[partitionTables.length];
        ArrayList<ParallelBuilding> threads = new ArrayList<ParallelBuilding>();
        int i = 0;
        for (InMemoryTable partitionTable : partitionTables) {
            exec.checkCanceled();
            if (partitionTable.getNumberDataRows() * m_numberAttributes < 10000 || !parallelProcessing.isThreadAvailable()) {
                children[i] = buildTree(partitionTable, exec, depth + 1, splitQualityMeasure, parallelProcessing);
            } else {
                String threadName = "Build thread, node: " + nodeId + "." + i;
                ParallelBuilding buildThread = new ParallelBuilding(threadName, partitionTable, exec, depth + 1, i, splitQualityMeasure, parallelProcessing);
                LOGGER.debug("Start new parallel building thread: " + threadName);
                threads.add(buildThread);
                buildThread.start();
            }
            i++;
        }
        // collect the results from the parallel building threads; sequentially
        // built children are already assigned to the child array
        for (ParallelBuilding buildThread : threads) {
            children[buildThread.getThreadIndex()] = buildThread.getResultNode();
            exec.checkCanceled();
            if (buildThread.getException() != null) {
                for (ParallelBuilding buildThread2 : threads) {
                    buildThread2.stop();
                }
                throw new RuntimeException(buildThread.getException().getMessage());
            }
        }
        threads.clear();
        if (split instanceof SplitContinuous) {
            double splitValue = ((SplitContinuous) split).getBestSplitValue();
            // return new DecisionTreeNodeSplitContinuous(nodeId,
            // majorityClass, frequencies, split
            // .getSplitAttributeName(), children, splitValue);
            String splitAttribute = split.getSplitAttributeName();
            PMMLPredicate[] splitPredicates = new PMMLPredicate[] { new PMMLSimplePredicate(splitAttribute, PMMLOperator.LESS_OR_EQUAL, Double.toString(splitValue)), new PMMLSimplePredicate(splitAttribute, PMMLOperator.GREATER_THAN, Double.toString(splitValue)) };
            return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
        } else if (split instanceof SplitNominalNormal) {
            // else the attribute is nominal
            DataCell[] splitValues = ((SplitNominalNormal) split).getSplitValues();
            // return new DecisionTreeNodeSplitNominal(nodeId, majorityClass,
            // frequencies, split.getSplitAttributeName(),
            // splitValues, children);
            int num = children.length;
            PMMLPredicate[] splitPredicates = new PMMLPredicate[num];
            String splitAttribute = split.getSplitAttributeName();
            for (int j = 0; j < num; j++) {
                splitPredicates[j] = new PMMLSimplePredicate(splitAttribute, PMMLOperator.EQUAL, splitValues[j].toString());
            }
            return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
        } else {
            // binary nominal
            SplitNominalBinary splitNominalBinary = (SplitNominalBinary) split;
            DataCell[] splitValues = splitNominalBinary.getSplitValues();
            // return new DecisionTreeNodeSplitNominalBinary(nodeId,
            // majorityClass, frequencies, split
            // .getSplitAttributeName(), splitValues,
            // splitNominalBinary.getIntMappingsLeftPartition(),
            // splitNominalBinary.getIntMappingsRightPartition(),
            // children/* children[0]=left, ..[1] right */);
            String splitAttribute = split.getSplitAttributeName();
            int[][] indices = new int[][] { splitNominalBinary.getIntMappingsLeftPartition(), splitNominalBinary.getIntMappingsRightPartition() };
            PMMLPredicate[] splitPredicates = new PMMLPredicate[2];
            for (int j = 0; j < splitPredicates.length; j++) {
                PMMLSimpleSetPredicate pred = new PMMLSimpleSetPredicate(splitAttribute, PMMLSetOperator.IS_IN);
                pred.setArrayType(PMMLArrayType.STRING);
                LinkedHashSet<String> values = new LinkedHashSet<String>();
                for (int index : indices[j]) {
                    values.add(splitValues[index].toString());
                }
                pred.setValues(values);
                splitPredicates[j] = pred;
            }
            return new DecisionTreeNodeSplitPMML(nodeId, majorityClass, frequencies, splitAttribute, splitPredicates, children);
        }
    }
}
Also used : LinkedHashSet(java.util.LinkedHashSet) ArrayList(java.util.ArrayList) SettingsModelString(org.knime.core.node.defaultnodesettings.SettingsModelString) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) PMMLSimpleSetPredicate(org.knime.base.node.mine.decisiontree2.PMMLSimpleSetPredicate) PMMLSimplePredicate(org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) DataCell(org.knime.core.data.DataCell) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)
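
All three branches above follow the same pattern: each child node receives a predicate that a row must satisfy in order to be routed into it. As a minimal standalone sketch (deliberately not using the KNIME PMMLPredicate classes, whose evaluation is handled by the PMML infrastructure; the class below is hypothetical), routing a value through a continuous threshold split or a binary nominal split looks roughly like this:

import java.util.Set;

// Hypothetical, simplified routing logic mirroring the predicates built above.
final class SplitPredicateSketch {

    /** Continuous split: child 0 covers value <= threshold, child 1 covers value > threshold. */
    static int routeContinuous(final double value, final double threshold) {
        return value <= threshold ? 0 : 1;
    }

    /** Binary nominal split: child 0 covers the left value set, child 1 the remaining values. */
    static int routeBinaryNominal(final String value, final Set<String> leftValues) {
        return leftValues.contains(value) ? 0 : 1;
    }
}

Child 0 corresponds to the LESS_OR_EQUAL predicate (or the left partition) and child 1 to the GREATER_THAN predicate (or the right partition), mirroring the order in which the predicates are created above.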

Example 28 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class Pruner method mdlPruning.

/**
 * Prunes a {@link DecisionTree} according to the minimum description length
 * (MDL) principle.
 *
 * @param decTree the decision tree to prune
 */
public static void mdlPruning(final DecisionTree decTree) {
    // traverse the tree depth first (in-fix)
    DecisionTreeNode root = decTree.getRootNode();
    mdlPruningRecurse(root);
}
Also used : DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)
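
Only the entry point is shown here; the actual work happens in mdlPruningRecurse, which is not part of this example. As an illustrative sketch of the MDL idea under a deliberately simplified cost model (the Node class and the costs below are hypothetical, not KNIME's), a post-order pass compares the cost of keeping a subtree with the cost of collapsing it into a leaf and prunes whenever the leaf is at least as cheap:

// Illustrative MDL-style pruning on a toy tree structure; not the KNIME implementation.
final class MdlPruningSketch {

    static final class Node {
        Node[] children = new Node[0];     // empty array means the node is a leaf
        double errorsAsLeaf;               // misclassified weight if collapsed to a leaf
        double splitDescriptionCost = 1.0; // cost of encoding the split itself
    }

    /** Returns the cost of the (possibly pruned) subtree rooted at node. */
    static double prune(final Node node) {
        if (node.children.length == 0) {
            return node.errorsAsLeaf;
        }
        double subtreeCost = node.splitDescriptionCost;
        for (Node child : node.children) {
            subtreeCost += prune(child);   // prune children first (post-order)
        }
        if (node.errorsAsLeaf <= subtreeCost) {
            node.children = new Node[0];   // the leaf is at least as cheap: collapse
            return node.errorsAsLeaf;
        }
        return subtreeCost;
    }
}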

Example 29 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class TreeNodeClassification method createDecisionTreeNode.

/**
 * Creates the {@link DecisionTreeNode} model that is used in KNIME's decision tree.
 *
 * @param idGenerator generates unique ids for the created nodes
 * @param metaData the tree meta data used to resolve attribute names
 * @return a DecisionTreeNode representing this node and its subtree
 */
public DecisionTreeNode createDecisionTreeNode(final MutableInteger idGenerator, final TreeMetaData metaData) {
    DataCell majorityCell = new StringCell(getMajorityClassName());
    double[] targetDistribution = getTargetDistribution();
    int initSize = (int) (targetDistribution.length / 0.75 + 1.0);
    LinkedHashMap<DataCell, Double> scoreDistributionMap = new LinkedHashMap<DataCell, Double>(initSize);
    NominalValueRepresentation[] targets = getTargetMetaData().getValues();
    for (int i = 0; i < targetDistribution.length; i++) {
        String cl = targets[i].getNominalValue();
        double d = targetDistribution[i];
        scoreDistributionMap.put(new StringCell(cl), d);
    }
    final int nrChildren = getNrChildren();
    if (nrChildren == 0) {
        return new DecisionTreeNodeLeaf(idGenerator.inc(), majorityCell, scoreDistributionMap);
    } else {
        int id = idGenerator.inc();
        DecisionTreeNode[] childNodes = new DecisionTreeNode[nrChildren];
        int splitAttributeIndex = getSplitAttributeIndex();
        assert splitAttributeIndex >= 0 : "non-leaf node has no split";
        String splitAttribute = metaData.getAttributeMetaData(splitAttributeIndex).getAttributeName();
        PMMLPredicate[] childPredicates = new PMMLPredicate[nrChildren];
        for (int i = 0; i < nrChildren; i++) {
            final TreeNodeClassification treeNode = getChild(i);
            TreeNodeCondition cond = treeNode.getCondition();
            childPredicates[i] = cond.toPMMLPredicate();
            childNodes[i] = treeNode.createDecisionTreeNode(idGenerator, metaData);
        }
        return new DecisionTreeNodeSplitPMML(id, majorityCell, scoreDistributionMap, splitAttribute, childPredicates, childNodes);
    }
}
Also used : NominalValueRepresentation(org.knime.base.node.mine.treeensemble.data.NominalValueRepresentation) PMMLPredicate(org.knime.base.node.mine.decisiontree2.PMMLPredicate) LinkedHashMap(java.util.LinkedHashMap) DecisionTreeNodeLeaf(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) StringCell(org.knime.core.data.def.StringCell) DecisionTreeNodeSplitPMML(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) DataCell(org.knime.core.data.DataCell) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode)
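
One detail worth noting: the initSize computation pre-sizes the LinkedHashMap so that it can hold all targetDistribution.length score-distribution entries without rehashing, given HashMap's default load factor of 0.75. The same pattern in isolation (the helper class below is hypothetical):

import java.util.LinkedHashMap;
import java.util.Map;

// Stand-alone version of the capacity calculation used above.
final class MapSizingSketch {

    static <K, V> Map<K, V> presizedLinkedHashMap(final int expectedEntries) {
        // With the default load factor of 0.75, this capacity avoids any resize.
        int initialCapacity = (int) (expectedEntries / 0.75 + 1.0);
        return new LinkedHashMap<K, V>(initialCapacity);
    }
}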

Example 30 with DecisionTreeNode

use of org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode in project knime-core by knime.

the class TreeEnsembleLearnerNodeView method updateHiLite.

// ///////////////////////////////////////////////////
private void updateHiLite(final boolean state) {
    DecisionTreeNode selected = m_graph.getSelected();
    if (selected == null) {
        return;
    }
    Set<RowKey> covPat = new HashSet<RowKey>(selected.coveredPattern());
    if (state) {
        m_hiLiteHdl.fireHiLiteEvent(covPat);
    } else {
        m_hiLiteHdl.fireUnHiLiteEvent(covPat);
    }
}
Also used : RowKey(org.knime.core.data.RowKey) DecisionTreeNode(org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode) HashSet(java.util.HashSet)

Aggregations

DecisionTreeNode (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNode) 50
RowKey (org.knime.core.data.RowKey) 18
HashSet (java.util.HashSet) 14
LinkedList (java.util.LinkedList) 14
DecisionTreeNodeLeaf (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeLeaf) 12
ArrayList (java.util.ArrayList) 10
DecisionTreeNodeSplitPMML (org.knime.base.node.mine.decisiontree2.model.DecisionTreeNodeSplitPMML) 10
DataCell (org.knime.core.data.DataCell) 9
Action (javax.swing.Action) 8
JMenu (javax.swing.JMenu) 8
CollapseBranchAction (org.knime.base.node.mine.decisiontree2.view.graph.CollapseBranchAction) 8
ExpandBranchAction (org.knime.base.node.mine.decisiontree2.view.graph.ExpandBranchAction) 8
PMMLPredicate (org.knime.base.node.mine.decisiontree2.PMMLPredicate) 7
PMMLDecisionTreeTranslator (org.knime.base.node.mine.decisiontree2.PMMLDecisionTreeTranslator) 5
DecisionTree (org.knime.base.node.mine.decisiontree2.model.DecisionTree) 5
TreePath (javax.swing.tree.TreePath) 4
PortObject (org.knime.core.node.port.PortObject) 4
PMMLPortObject (org.knime.core.node.port.pmml.PMMLPortObject) 4
LinkedHashMap (java.util.LinkedHashMap) 3
PMMLSimplePredicate (org.knime.base.node.mine.decisiontree2.PMMLSimplePredicate) 3