Search in sources :

Example 1 with AbstractTreeNode

use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode in project knime-core by knime.

Class TreeEnsembleLearnerNodeView2, method getSubtree.

// private void recreateHiLite() {
// HiLiteHandler handler = m_hiLiteHdl;
// if (handler == null) {
// return;
// }
// Set<RowKey> hilited = handler.getHiLitKeys();
// Set<AbstractTreeNode> toHilite = new HashSet<AbstractTreeNode>();
// AbstractTreeNode root = m_graph.getRootNode();
// 
// List<AbstractTreeNode> toProcess = new LinkedList<AbstractTreeNode>();
// if (null != root) {
// toProcess.add(0, root);
// }
// // Traverse the tree depth first (add(0)/remove(0) use the list as a stack)
// while (!toProcess.isEmpty()) {
// AbstractTreeNode curr = toProcess.remove(0);
// // bug 2695: if not all patterns are selected for hiliting, the
// // view would automatically hilite all branches that do not
// // cover any pattern, so such branches are skipped here
// if (curr.coveredPattern().isEmpty()) {
// continue;
// }
// if (hilited.containsAll(curr.coveredPattern())) {
// // hilite subtree starting from curr
// toHilite.addAll(getSubtree(curr));
// } else {
// for (int i = 0; i < curr.getChildCount(); i++) {
// toProcess.add(0, curr.getChildAt(i));
// }
// }
// }
// m_graph.hiLite(toHilite);
// }
/**
 * Collects {@code node} together with all of its descendants.
 *
 * <p>The traversal is depth-first pre-order: a node is added to the result before any of its
 * descendants. Children are pushed onto a stack in index order, so the last child is visited
 * first.
 *
 * @param node the root of the subtree to collect
 * @return a new list whose first element is {@code node}, followed by all of its descendants
 */
private List<AbstractTreeNode> getSubtree(final AbstractTreeNode node) {
    final List<AbstractTreeNode> subTree = new ArrayList<>();
    // LIFO work stack: this is a depth-first (not breadth-first) traversal.
    final LinkedList<AbstractTreeNode> toProcess = new LinkedList<>();
    toProcess.push(node);
    while (!toProcess.isEmpty()) {
        final AbstractTreeNode curr = toProcess.pop();
        subTree.add(curr);
        for (int i = 0; i < curr.getNrChildren(); i++) {
            toProcess.push(curr.getChild(i));
        }
    }
    return subTree;
}
Also used : ArrayList(java.util.ArrayList) AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) LinkedList(java.util.LinkedList)

Example 2 with AbstractTreeNode

use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode in project knime-core by knime.

Class AbstractTreeModelExporter, method recursivelyCheckUsedFields.

/**
 * Walks {@code node} and its descendants and passes each node's condition to
 * {@code addAllFieldsInCondition}, which (presumably) adds the referenced learning columns to
 * {@code usedLearningFields}.
 *
 * <p>A call does nothing once {@code usedLearningFields} already holds {@code numLearnCols}
 * entries. The check happens on entry to every recursive call, so descent stops as soon as all
 * columns have been seen.
 *
 * @param node the node whose condition (and subtree) is inspected
 * @param numLearnCols the total number of learning columns
 * @param usedLearningFields accumulator of used field names; modified in place
 */
private void recursivelyCheckUsedFields(final AbstractTreeNode node, final int numLearnCols, final Set<String> usedLearningFields) {
    final boolean everyColumnSeen = usedLearningFields.size() == numLearnCols;
    if (!everyColumnSeen) {
        final TreeNodeCondition nodeCondition = node.getCondition();
        addAllFieldsInCondition(nodeCondition, usedLearningFields);
        node.getChildren().forEach(child -> recursivelyCheckUsedFields(child, numLearnCols, usedLearningFields));
    }
}
Also used : TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Example 3 with AbstractTreeNode

use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode in project knime-core by knime.

Class RandomForestDistance, method computeDistance.

/**
 * {@inheritDoc}
 *
 * <p>Both rows are reduced to the columns returned by {@code getColumnIndices()} and converted
 * into predictor records. The proximity is the fraction of ensemble trees in which the two
 * records reach matching nodes with equal signatures. The returned distance is
 * {@code 1 - proximity}.
 *
 * <p>NOTE(review): if the ensemble has zero models, this evaluates {@code 0.0 / 0} and returns
 * {@code NaN}.
 */
@Override
public double computeDistance(final DataRow row1, final DataRow row2) throws DistanceMeasurementException {
    final int[] filterIndices = getColumnIndices().stream().mapToInt(Integer::intValue).toArray();
    final DataRow filtered1 = new FilterColumnRow(row1, filterIndices);
    final DataRow filtered2 = new FilterColumnRow(row2, filterIndices);
    final PredictorRecord rec1 = m_ensembleModel.createPredictorRecord(filtered1, m_learnTableSpec);
    final PredictorRecord rec2 = m_ensembleModel.createPredictorRecord(filtered2, m_learnTableSpec);

    final int treeCount = m_ensembleModel.getNrModels();
    int sharedLeafCount = 0;
    for (int t = 0; t < treeCount; t++) {
        final AbstractTreeModel<?> model = m_ensembleModel.getTreeModel(t);
        final AbstractTreeNode leafOf1 = model.findMatchingNode(rec1);
        final AbstractTreeNode leafOf2 = model.findMatchingNode(rec2);
        if (leafOf1.getSignature().equals(leafOf2.getSignature())) {
            sharedLeafCount++;
        }
    }
    final double proximity = (double) sharedLeafCount / treeCount;
    // distance is the complement of the proximity
    return 1 - proximity;
}
Also used : PredictorRecord(org.knime.base.node.mine.treeensemble2.data.PredictorRecord) AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) DataRow(org.knime.core.data.DataRow) FilterColumnRow(org.knime.base.data.filter.column.FilterColumnRow)

Example 4 with AbstractTreeNode

use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode in project knime-core by knime.

Class TreeModelPMMLTranslator, method addTreeNode.

/**
 * Populates {@code pmmlNode} from {@code node} and recursively appends one PMML child node per
 * child of {@code node}.
 *
 * <p>Each exported node receives the current value of {@code m_nodeIndex} as its id, and the
 * counter is then incremented. Content depends on the node type:
 * <ul>
 * <li>Classification nodes: the score is the majority class name, and the record count is the
 * sum of the target distribution. One score distribution (class name and count) is added per
 * target value.</li>
 * <li>Regression nodes: the score is the mean.</li>
 * </ul>
 * The node's condition is then translated into the matching PMML predicate.
 *
 * @param pmmlNode the PMML node to populate
 * @param node the tree node whose content is exported
 * @throws IllegalStateException if the node's condition is {@code null} or of an unsupported type
 */
private void addTreeNode(final Node pmmlNode, final AbstractTreeNode node) {
    int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    if (node instanceof TreeNodeClassification) {
        final TreeNodeClassification clazNode = (TreeNodeClassification) node;
        pmmlNode.setScore(clazNode.getMajorityClassName());
        float[] targetDistribution = clazNode.getTargetDistribution();
        NominalValueRepresentation[] targetVals = clazNode.getTargetMetaData().getValues();
        double sum = 0.0;
        // primitive loop variable: iterating as Float boxed every element
        for (float v : targetDistribution) {
            sum += v;
        }
        pmmlNode.setRecordCount(sum);
        // adding score distribution (class counts)
        for (int i = 0; i < targetDistribution.length; i++) {
            String className = targetVals[i].getNominalValue();
            double freq = targetDistribution[i];
            ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
            pmmlScoreDist.setValue(className);
            pmmlScoreDist.setRecordCount(freq);
        }
    } else if (node instanceof TreeNodeRegression) {
        final TreeNodeRegression regNode = (TreeNodeRegression) node;
        pmmlNode.setScore(Double.toString(regNode.getMean()));
    }
    TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        handleColumnCondition(colCondition, pmmlNode);
    } else if (condition instanceof AbstractTreeNodeSurrogateCondition) {
        final AbstractTreeNodeSurrogateCondition surrogateCond = (AbstractTreeNodeSurrogateCondition) condition;
        setValuesFromPMMLCompoundPredicate(pmmlNode.addNewCompoundPredicate(), surrogateCond.toPMMLPredicate());
    } else {
        // a null condition previously caused an NPE here instead of the intended exception
        final String conditionType = condition == null ? "null" : condition.getClass().getSimpleName();
        throw new IllegalStateException("Unsupported condition (not implemented): " + conditionType);
    }
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
Also used : NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution)

Example 5 with AbstractTreeNode

use of org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode in project knime-core by knime.

Class RandomForestClassificationTreeGraphView, method mousePressed.

/**
 * {@inheritDoc}
 *
 * <p>Delegates to the superclass handler. If the press hit a node and the preferred size of
 * that node's widget differs afterwards, the graph is laid out again and the view is
 * revalidated and repainted. The view is then scrolled so the pressed node's visible rectangle
 * is shown.
 */
@Override
public void mousePressed(final MouseEvent e) {
    final AbstractTreeNode clicked = nodeAtPoint(e.getPoint());
    if (clicked == null) {
        // no node under the cursor: default handling only
        super.mousePressed(e);
        return;
    }
    final Dimension sizeBefore = getWidgets().get(clicked).getPreferredSize();
    super.mousePressed(e);
    final Dimension sizeAfter = getWidgets().get(clicked).getPreferredSize();
    if (sizeBefore.equals(sizeAfter)) {
        return;
    }
    // the clicked widget changed size, so the layout must be recomputed
    layoutGraph();
    getView().revalidate();
    getView().repaint();
    // keep the clicked node inside the visible area
    getView().scrollRectToVisible(getVisible().get(clicked));
}
Also used : AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) Dimension(java.awt.Dimension)

Aggregations

- AbstractTreeNode (org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode): 4
- DataRow (org.knime.core.data.DataRow): 2
- Dimension (java.awt.Dimension): 1
- ArrayList (java.util.ArrayList): 1
- LinkedList (java.util.LinkedList): 1
- ScoreDistribution (org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution): 1
- FilterColumnRow (org.knime.base.data.filter.column.FilterColumnRow): 1
- NominalValueRepresentation (org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation): 1
- PredictorRecord (org.knime.base.node.mine.treeensemble2.data.PredictorRecord): 1
- TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData): 1
- TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition): 1
- TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 1
- ColumnSample (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample): 1
- ColumnSampleStrategy (org.knime.base.node.mine.treeensemble2.sample.column.ColumnSampleStrategy): 1
- DefaultRow (org.knime.core.data.def.DefaultRow): 1
- BufferedDataContainer (org.knime.core.node.BufferedDataContainer): 1