Search in sources:

Example 1 with TreeNodeCondition

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

The class RandomForestClassificationTreeNodeWidget, method getConnectorLabelBelow.

/**
 * {@inheritDoc}
 * <p>
 * The label is taken from the condition of the first child of this widget's node. For a column condition, its
 * attribute name is returned. For a surrogate condition, the attribute name of its first (head) condition is
 * returned if that head is a column condition. In every other case, including a node without children,
 * {@code null} is returned.
 */
@Override
public String getConnectorLabelBelow() {
    final TreeNodeClassification treeNode = (TreeNodeClassification) getUserObject();
    if (treeNode.getNrChildren() == 0) {
        // leaf: nothing is drawn below it
        return null;
    }
    final TreeNodeCondition firstChildCondition = treeNode.getChild(0).getCondition();
    // check the plain column condition first, exactly as before, in case a type matches both checks
    if (firstChildCondition instanceof TreeNodeColumnCondition) {
        return ((TreeNodeColumnCondition) firstChildCondition).getAttributeName();
    }
    if (!(firstChildCondition instanceof TreeNodeSurrogateCondition)) {
        return null;
    }
    final TreeNodeCondition head = ((TreeNodeSurrogateCondition) firstChildCondition).getFirstCondition();
    return head instanceof TreeNodeColumnCondition ? ((TreeNodeColumnCondition) head).getAttributeName() : null;
}
Also used : TreeNodeClassification(org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification) TreeNodeSurrogateCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeSurrogateCondition) TreeNodeColumnCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Example 2 with TreeNodeCondition

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

The class AbstractTreeModelExporter, method recursivelyCheckUsedFields.

/**
 * Walks the subtree rooted at {@code node} in pre-order and passes each node's condition to
 * {@code addAllFieldsInCondition}, which is expected to add the referenced fields to {@code usedLearningFields}.
 * A node is skipped (and not descended into) once the set already holds {@code numLearnCols} entries.
 *
 * @param node root of the subtree to inspect
 * @param numLearnCols number of learning columns; reaching this set size ends the search
 * @param usedLearningFields accumulator for the field names found so far (mutated in place)
 */
private void recursivelyCheckUsedFields(final AbstractTreeNode node, final int numLearnCols, final Set<String> usedLearningFields) {
    final boolean allFieldsFound = usedLearningFields.size() == numLearnCols;
    if (!allFieldsFound) {
        final TreeNodeCondition nodeCondition = node.getCondition();
        addAllFieldsInCondition(nodeCondition, usedLearningFields);
        node.getChildren().forEach(child -> recursivelyCheckUsedFields(child, numLearnCols, usedLearningFields));
    }
}
Also used : TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Example 3 with TreeNodeCondition

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

The class BitSplitCandidate, method getChildConditions.

/**
 * {@inheritDoc}
 * <p>
 * Always yields exactly two bit conditions on this candidate's column. The condition at index 0 is constructed
 * with {@code true}, the one at index 1 with {@code false}.
 */
@Override
public TreeNodeCondition[] getChildConditions() {
    final TreeBitColumnMetaData bitColumnMeta = getColumnData().getMetaData();
    // array initializer elements are evaluated left to right, so construction order is unchanged
    return new TreeNodeCondition[]{
        new TreeNodeBitCondition(bitColumnMeta, true),
        new TreeNodeBitCondition(bitColumnMeta, false)};
}
Also used : TreeBitColumnMetaData(org.knime.base.node.mine.treeensemble2.data.TreeBitColumnMetaData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) TreeNodeBitCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeBitCondition)

Example 4 with TreeNodeCondition

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

The class NominalMultiwaySplitCandidate, method getChildConditions.

/**
 * {@inheritDoc}
 * <p>
 * Creates one nominal condition per non-missing value index {@code i} whose accumulated weight
 * {@code m_sumWeightsAttributes[i]} is at least {@link TreeColumnData#EPSILON}. If the last entry of the column's
 * value list is the missing-value representation, that entry gets no condition of its own. The third constructor
 * argument is {@code true} only for the index equal to {@code m_missingsGoToChildIdx} (presumably the child that
 * receives missing values — confirm against TreeNodeNominalCondition).
 *
 * @return the conditions in ascending value-index order; an empty array if the column has no values
 */
@Override
public TreeNodeNominalCondition[] getChildConditions() {
    final TreeNominalColumnMetaData columnMeta = getColumnData().getMetaData();
    final NominalValueRepresentation[] values = columnMeta.getValues();
    if (values.length == 0) {
        // without this guard the missing-value check below would read values[-1]
        return new TreeNodeNominalCondition[0];
    }
    // only the last entry is checked for the missing-value representation
    final boolean lastIsMissing =
        values[values.length - 1].getNominalValue().equals(NominalValueRepresentation.MISSING_VALUE);
    final int lengthNonMissing = lastIsMissing ? values.length - 1 : values.length;
    // typed list so the final toArray cannot hit an ArrayStoreException with a foreign condition type
    final List<TreeNodeNominalCondition> resultList = new ArrayList<>(lengthNonMissing);
    for (int i = 0; i < lengthNonMissing; i++) {
        // values carrying (practically) no weight do not get a child
        if (m_sumWeightsAttributes[i] >= TreeColumnData.EPSILON) {
            resultList.add(new TreeNodeNominalCondition(columnMeta, i, i == m_missingsGoToChildIdx));
        }
    }
    return resultList.toArray(new TreeNodeNominalCondition[resultList.size()]);
}
Also used : TreeNominalColumnMetaData(org.knime.base.node.mine.treeensemble2.data.TreeNominalColumnMetaData) TreeNodeNominalCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition) ArrayList(java.util.ArrayList) NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Example 5 with TreeNodeCondition

Use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

The class TreeModelPMMLTranslator, method addTreeNode.

/**
 * Recursively writes {@code node} and its whole subtree into {@code pmmlNode}.
 * <p>
 * The PMML node receives the next value of {@code m_nodeIndex} as its id. The score is set as follows:
 * <ul>
 * <li>Classification nodes: the majority class name. They also get a record count equal to the sum of the target
 * distribution, plus one score distribution entry per target class.</li>
 * <li>Regression nodes: the mean, rendered via {@link Double#toString(double)}.</li>
 * </ul>
 * The node's condition is then translated into a PMML predicate, and every child is added as a new nested PMML node.
 *
 * @param pmmlNode the PMML node to populate; child nodes are appended to it via {@code addNewNode()}
 * @param node the tree node to translate
 * @throws IllegalStateException if the node's condition is {@code null} or of an unsupported type
 */
private void addTreeNode(final Node pmmlNode, final AbstractTreeNode node) {
    final int index = m_nodeIndex++;
    pmmlNode.setId(Integer.toString(index));
    if (node instanceof TreeNodeClassification) {
        final TreeNodeClassification clazNode = (TreeNodeClassification) node;
        pmmlNode.setScore(clazNode.getMajorityClassName());
        final float[] targetDistribution = clazNode.getTargetDistribution();
        final NominalValueRepresentation[] targetVals = clazNode.getTargetMetaData().getValues();
        double sum = 0.0;
        // primitive loop variable: iterating a float[] as Float would box every element
        for (final float v : targetDistribution) {
            sum += v;
        }
        pmmlNode.setRecordCount(sum);
        // adding score distribution (class counts); assumes targetVals is index-aligned with targetDistribution
        for (int i = 0; i < targetDistribution.length; i++) {
            final String className = targetVals[i].getNominalValue();
            final double freq = targetDistribution[i];
            final ScoreDistribution pmmlScoreDist = pmmlNode.addNewScoreDistribution();
            pmmlScoreDist.setValue(className);
            pmmlScoreDist.setRecordCount(freq);
        }
    } else if (node instanceof TreeNodeRegression) {
        final TreeNodeRegression regNode = (TreeNodeRegression) node;
        pmmlNode.setScore(Double.toString(regNode.getMean()));
    }
    final TreeNodeCondition condition = node.getCondition();
    if (condition instanceof TreeNodeTrueCondition) {
        pmmlNode.addNewTrue();
    } else if (condition instanceof TreeNodeColumnCondition) {
        final TreeNodeColumnCondition colCondition = (TreeNodeColumnCondition) condition;
        handleColumnCondition(colCondition, pmmlNode);
    } else if (condition instanceof AbstractTreeNodeSurrogateCondition) {
        final AbstractTreeNodeSurrogateCondition surrogateCond = (AbstractTreeNodeSurrogateCondition) condition;
        setValuesFromPMMLCompoundPredicate(pmmlNode.addNewCompoundPredicate(), surrogateCond.toPMMLPredicate());
    } else {
        // a null condition would otherwise surface as a NullPointerException while building this message
        final String conditionType = condition == null ? "null" : condition.getClass().getSimpleName();
        throw new IllegalStateException("Unsupported condition (not " + "implemented): " + conditionType);
    }
    for (int i = 0; i < node.getNrChildren(); i++) {
        addTreeNode(pmmlNode.addNewNode(), node.getChild(i));
    }
}
Also used : NominalValueRepresentation(org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation) ScoreDistribution(org.dmg.pmml.ScoreDistributionDocument.ScoreDistribution)

Aggregations

TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition): 13; BitSet (java.util.BitSet): 9; TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData): 5; DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships): 5; TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration): 5; ArrayList (java.util.ArrayList): 4; RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships): 4; TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition): 4; RandomData (org.apache.commons.math.random.RandomData): 3; Test (org.junit.Test): 3; NominalValueRepresentation (org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation): 3; TreeNodeColumnCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition): 3; TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature): 3; ClassificationPriors (org.knime.base.node.mine.treeensemble2.data.ClassificationPriors): 2; TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData): 2; TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData): 2; ColumnMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships): 2; NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate): 2; NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate): 2; SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate): 2