Search in sources :

Example 16 with TreeNodeCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

The class TreeModelImporter, method createNodeFromPMML.

/**
 * Recursively converts a PMML node, together with all of its descendants, into a tree node of type {@code N}.
 * <p>
 * The children are converted first (depth-first, in the order of the PMML node list), each with a child
 * signature derived from {@code signature} and the child's position. Afterwards the condition of
 * {@code pmmlNode} is parsed, the node itself is created from the converted children, and the condition is
 * attached to it.
 *
 * @param pmmlNode the PMML node to convert
 * @param signature the signature assigned to the node created for {@code pmmlNode}
 * @return the created node with its condition set
 */
private N createNodeFromPMML(final Node pmmlNode, final TreeNodeSignature signature) {
    final List<N> convertedChildren = new ArrayList<>();
    // Position of the child among its siblings, in iteration order of the PMML node list.
    byte childPosition = 0;
    for (final Node pmmlChild : pmmlNode.getNodeList()) {
        final TreeNodeSignature signatureOfChild =
            m_signatureFactory.getChildSignatureFor(signature, childPosition++);
        convertedChildren.add(createNodeFromPMML(pmmlChild, signatureOfChild));
    }
    final TreeNodeCondition parsedCondition = m_conditionParser.parseCondition(pmmlNode);
    final N created =
        m_contentParser.createNode(pmmlNode, m_metaDataMapper.getTargetColumnHelper(), signature, convertedChildren);
    created.setTreeNodeCondition(parsedCondition);
    return created;
}
Also used : AbstractTreeNode(org.knime.base.node.mine.treeensemble2.model.AbstractTreeNode) Node(org.dmg.pmml.NodeDocument.Node) ArrayList(java.util.ArrayList) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)

Example 17 with TreeNodeCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

The class Surrogates, method learnSurrogates.

/**
 * Learns surrogate splits on the columns of <b>colSample</b> other than the column used by <b>bestSplit</b>.
 * The direction (left or right) that <b>bestSplit</b> assigns to each row is used as the new target against
 * which the remaining columns are evaluated.
 *
 * @param dataMemberships provides information which rows are in the current branch
 * @param bestSplit the best split for the current node; it is always placed first in the candidate array
 * @param oldData the TreeData object that contains all attributes and the target
 * @param colSample provides information which columns are to be considered as surrogates
 * @param config the configuration
 * @param rd random data handed on to the split search of every candidate column
 * @return a SurrogateSplit that contains the conditions for both children
 */
public static SurrogateSplit learnSurrogates(final DataMemberships dataMemberships, final SplitCandidate bestSplit, final TreeData oldData, final ColumnSample colSample, final TreeEnsembleLearnerConfiguration config, final RandomData rd) {
    final TreeAttributeColumnData primaryColumn = bestSplit.getColumnData();
    final TreeNodeCondition[] primaryConditions = bestSplit.getChildConditions();

    // The rows that the best split sends to its first and second child form the new target.
    final BitSet leftRows = primaryColumn.updateChildMemberships(primaryConditions[0], dataMemberships);
    final BitSet rightRows = primaryColumn.updateChildMemberships(primaryConditions[1], dataMemberships);

    // Restrict the surrogate search to the rows that the best split assigns to one of its children.
    final BitSet assignedRows = (BitSet) leftRows.clone();
    assignedRows.or(rightRows);
    final DataMemberships assignedMemberships = dataMemberships.createChildMemberships(assignedRows);
    final TreeTargetNominalColumnData directionTarget =
        createNewTargetColumn(leftRows, rightRows, oldData.getNrRows(), assignedMemberships);

    // Search every other sampled column for its best split on the new target.
    final ArrayList<SplitCandidate> surrogateCandidates = new ArrayList<>();
    final ClassificationPriors directionPriors = directionTarget.getDistribution(assignedMemberships, config);
    for (final TreeAttributeColumnData column : colSample) {
        if (column == primaryColumn) {
            // the column of the best split cannot be its own surrogate
            continue;
        }
        final SplitCandidate found =
            column.calcBestSplitClassification(assignedMemberships, directionPriors, directionTarget, rd);
        if (found != null) {
            surrogateCandidates.add(found);
        }
    }

    // Slot 0 holds the best split, followed by the surrogate candidates in the order they were found.
    final SplitCandidate[] bestFirst = new SplitCandidate[surrogateCandidates.size() + 1];
    bestFirst[0] = bestSplit;
    int slot = 1;
    for (final SplitCandidate surrogate : surrogateCandidates) {
        bestFirst[slot++] = surrogate;
    }
    return calculateSurrogates(dataMemberships, bestFirst);
}
Also used : TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) BitSet(java.util.BitSet) ArrayList(java.util.ArrayList) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) ClassificationPriors(org.knime.base.node.mine.treeensemble2.data.ClassificationPriors)

Example 18 with TreeNodeCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

The class TreeLearnerClassification, method buildTreeNode.

/**
 * Recursively builds the classification (sub)tree whose root is identified by {@code treeNodeSignature}.
 * <p>
 * If missing values are handled with surrogates, the node is split into exactly two children based on the
 * child markers of the learned surrogate split. Otherwise one child is created for every child condition of
 * the best split. If no split is found, a leaf node is returned.
 *
 * @param exec used to check for cancellation before any work is done for this node
 * @param currentDepth depth of the node that is being built; children are built with {@code currentDepth + 1}
 * @param dataMemberships the rows that belong to this node
 * @param columnSample the columns that are considered for splitting this node
 * @param treeNodeSignature signature of the node that is being built
 * @param targetPriors target distribution of the rows in this node
 * @param forbiddenColumnSet attribute indices that must not be split on; in the non-surrogate case the bit of
 *            the split column is set while the children are built if the column cannot be split further, and
 *            cleared again afterwards
 * @return the node built for {@code treeNodeSignature}, including all of its descendants
 * @throws CanceledExecutionException if the execution was canceled
 */
private TreeNodeClassification buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData trainingData = getData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final boolean surrogatesEnabled = getConfig().getMissingValueHandling() == MissingValueHandling.Surrogate;
    final TreeTargetNominalColumnData target = (TreeTargetNominalColumnData) trainingData.getTargetColumn();
    final int childDepth = currentDepth + 1;

    final TreeNodeClassification[] children;
    // Both values are only changed by the non-surrogate branch.
    boolean releaseSplitColumn = false;
    int splitColumnIndex = -1;

    if (surrogatesEnabled) {
        final SplitCandidate[] rankedSplits = findBestSplitsClassification(currentDepth, dataMemberships,
            columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (rankedSplits == null) {
            // no split available: this node becomes a leaf
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        final SurrogateSplit surrogates = Surrogates.learnSurrogates(dataMemberships, rankedSplits[0],
            trainingData, columnSample, config, getRandomData());
        final TreeNodeCondition[] conditions = surrogates.getChildConditions();
        final BitSet[] markers = surrogates.getChildMarkers();
        // a surrogate split is always treated as binary here
        children = new TreeNodeClassification[2];
        for (int pos = 0; pos < children.length; pos++) {
            final DataMemberships rowsOfChild = dataMemberships.createChildMemberships(markers[pos]);
            final ClassificationPriors priorsOfChild = target.getDistribution(rowsOfChild, config);
            final TreeNodeSignature signatureOfChild =
                getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) pos);
            final ColumnSample columnsOfChild =
                getColSamplingStrategy().getColumnSampleForTreeNode(signatureOfChild);
            final TreeNodeClassification child = buildTreeNode(exec, childDepth, rowsOfChild, columnsOfChild,
                signatureOfChild, priorsOfChild, forbiddenColumnSet);
            child.setTreeNodeCondition(conditions[pos]);
            children[pos] = child;
        }
    } else {
        final SplitCandidate winner = findBestSplitClassification(currentDepth, dataMemberships, columnSample,
            treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (winner == null) {
            // no split available: this node becomes a leaf
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        final TreeAttributeColumnData splitColumn = winner.getColumnData();
        splitColumnIndex = splitColumn.getMetaData().getAttributeIndex();
        releaseSplitColumn = !winner.canColumnBeSplitFurther();
        // Sets the bit when the column is exhausted; otherwise this explicitly clears it.
        forbiddenColumnSet.set(splitColumnIndex, releaseSplitColumn);
        final TreeNodeCondition[] conditions = winner.getChildConditions();
        children = new TreeNodeClassification[conditions.length];
        // NOTE(review): the limit is Short.MAX_VALUE although the child index is cast to byte below -
        // confirm that this is intended.
        if (conditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting attribute " + winner.getColumnData()
                + " (maximum supported: " + Short.MAX_VALUE + "): " + conditions.length);
        }
        for (int pos = 0; pos < conditions.length; pos++) {
            final TreeNodeCondition condition = conditions[pos];
            final BitSet markerOfChild = splitColumn.updateChildMemberships(condition, dataMemberships);
            final DataMemberships rowsOfChild = dataMemberships.createChildMemberships(markerOfChild);
            final ClassificationPriors priorsOfChild = target.getDistribution(rowsOfChild, config);
            final TreeNodeSignature signatureOfChild = treeNodeSignature.createChildSignature((byte) pos);
            final ColumnSample columnsOfChild =
                getColSamplingStrategy().getColumnSampleForTreeNode(signatureOfChild);
            final TreeNodeClassification child = buildTreeNode(exec, childDepth, rowsOfChild, columnsOfChild,
                signatureOfChild, priorsOfChild, forbiddenColumnSet);
            child.setTreeNodeCondition(condition);
            children[pos] = child;
        }
    }

    if (releaseSplitColumn) {
        // the subtree is complete, so the split column is no longer marked as forbidden
        forbiddenColumnSet.clear(splitColumnIndex);
    }
    return new TreeNodeClassification(treeNodeSignature, targetPriors, children, getConfig());
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) TreeNodeClassification(org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) BitSet(java.util.BitSet) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) ClassificationPriors(org.knime.base.node.mine.treeensemble2.data.ClassificationPriors)

Example 19 with TreeNodeCondition

use of org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition in project knime-core by knime.

The class TreeLearnerRegression, method buildTreeNode.

/**
 * Recursively builds the regression (sub)tree whose root is identified by {@code treeNodeSignature}.
 * <p>
 * If no split candidate is found, a leaf is returned; for a gradient boosting configuration that leaf is
 * created with the original row indices and registered via {@code addToLeafList}. Otherwise the node is
 * split either into the two children of a surrogate split or into one child per child condition of the
 * candidate.
 *
 * @param exec used to check for cancellation before any work is done for this node
 * @param currentDepth depth of the node that is being built; children are built with {@code currentDepth + 1}
 * @param dataMemberships the rows that belong to this node
 * @param columnSample the columns that are considered for splitting this node
 * @param treeNodeSignature signature of the node that is being built
 * @param targetPriors target priors of the rows in this node
 * @param forbiddenColumnSet attribute indices that must not be split on; in the non-surrogate case the bit of
 *            the split column is set while the children are built if the column cannot be split further, and
 *            cleared again afterwards
 * @return the node built for {@code treeNodeSignature}, including all of its descendants
 * @throws CanceledExecutionException if the execution was canceled
 */
private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData trainingData = getData();
    final RandomData randomData = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();

    final SplitCandidate candidate =
        findBestSplitRegression(currentDepth, dataMemberships, columnSample, targetPriors, forbiddenColumnSet);
    if (candidate == null) {
        // no split available: this node becomes a leaf
        if (!(config instanceof GradientBoostingLearnerConfiguration)) {
            return new TreeNodeRegression(treeNodeSignature, targetPriors);
        }
        final TreeNodeRegression boostingLeaf =
            new TreeNodeRegression(treeNodeSignature, targetPriors, dataMemberships.getOriginalIndices());
        addToLeafList(boostingLeaf);
        return boostingLeaf;
    }

    final TreeTargetNumericColumnData target = (TreeTargetNumericColumnData) trainingData.getTargetColumn();
    final boolean surrogatesEnabled = config.getMissingValueHandling() == MissingValueHandling.Surrogate;
    final int childDepth = currentDepth + 1;
    final TreeNodeRegression[] children;

    if (surrogatesEnabled) {
        final SurrogateSplit surrogates =
            Surrogates.learnSurrogates(dataMemberships, candidate, trainingData, columnSample, config, randomData);
        final TreeNodeCondition[] conditions = surrogates.getChildConditions();
        final BitSet[] markers = surrogates.getChildMarkers();
        assert markers[0].cardinality() + markers[1].cardinality() == dataMemberships.getRowCount() : "Sum of rows in children does not add up to number of rows in parent.";
        // a surrogate split is always treated as binary here
        children = new TreeNodeRegression[2];
        for (int pos = 0; pos < children.length; pos++) {
            final DataMemberships rowsOfChild = dataMemberships.createChildMemberships(markers[pos]);
            final TreeNodeSignature signatureOfChild =
                getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) pos);
            final ColumnSample columnsOfChild =
                getColSamplingStrategy().getColumnSampleForTreeNode(signatureOfChild);
            final RegressionPriors priorsOfChild = target.getPriors(rowsOfChild, config);
            final TreeNodeRegression child = buildTreeNode(exec, childDepth, rowsOfChild, columnsOfChild,
                signatureOfChild, priorsOfChild, forbiddenColumnSet);
            child.setTreeNodeCondition(conditions[pos]);
            children[pos] = child;
        }
    } else {
        final TreeAttributeColumnData splitColumn = candidate.getColumnData();
        final int splitColumnIndex = splitColumn.getMetaData().getAttributeIndex();
        final boolean releaseSplitColumn = !candidate.canColumnBeSplitFurther();
        // Sets the bit when the column is exhausted; otherwise this explicitly clears it.
        forbiddenColumnSet.set(splitColumnIndex, releaseSplitColumn);
        final TreeNodeCondition[] conditions = candidate.getChildConditions();
        // NOTE(review): the limit is Short.MAX_VALUE although the child index is cast to byte below -
        // confirm that this is intended.
        if (conditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting attribute " + candidate.getColumnData()
                + " (maximum supported: " + Short.MAX_VALUE + "): " + conditions.length);
        }
        children = new TreeNodeRegression[conditions.length];
        for (int pos = 0; pos < conditions.length; pos++) {
            final TreeNodeCondition condition = conditions[pos];
            final BitSet markerOfChild = splitColumn.updateChildMemberships(condition, dataMemberships);
            final DataMemberships rowsOfChild = dataMemberships.createChildMemberships(markerOfChild);
            final RegressionPriors priorsOfChild = target.getPriors(rowsOfChild, config);
            final TreeNodeSignature signatureOfChild = treeNodeSignature.createChildSignature((byte) pos);
            final ColumnSample columnsOfChild =
                getColSamplingStrategy().getColumnSampleForTreeNode(signatureOfChild);
            final TreeNodeRegression child = buildTreeNode(exec, childDepth, rowsOfChild, columnsOfChild,
                signatureOfChild, priorsOfChild, forbiddenColumnSet);
            child.setTreeNodeCondition(condition);
            children[pos] = child;
        }
        if (releaseSplitColumn) {
            // the subtree is complete, so the split column is no longer marked as forbidden
            forbiddenColumnSet.clear(splitColumnIndex);
        }
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, children);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Aggregations

TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)13 BitSet (java.util.BitSet)9 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)5 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)5 TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)5 ArrayList (java.util.ArrayList)4 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)4 TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition)4 RandomData (org.apache.commons.math.random.RandomData)3 Test (org.junit.Test)3 NominalValueRepresentation (org.knime.base.node.mine.treeensemble2.data.NominalValueRepresentation)3 TreeNodeColumnCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeColumnCondition)3 TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature)3 ClassificationPriors (org.knime.base.node.mine.treeensemble2.data.ClassificationPriors)2 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)2 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)2 ColumnMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)2 NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate)2 NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate)2 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)2