Search in sources:

Example 41 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeNumericColumnData, method calcBestSplitRegression.

/**
 * Finds the best binary split on this numeric column for a numeric (regression) target.
 * <p>
 * Rows of the current branch are visited in ascending order of this column's sorted values. At every boundary between
 * two distinct values the criterion {@code ySumLeft^2/nLeft + ySumRight^2/nRight - ySumTotal^2/nTotal} is evaluated,
 * and the boundary with the largest positive improvement is kept.
 * <p>
 * Rows whose value in this column is missing are stored at the end of the column (index
 * {@code >= getLengthNonMissing()}). They are never used as split points. With {@code MissingValueHandling.XGBoost}
 * they are tentatively added to either child and the better direction is recorded in the returned candidate.
 * Otherwise they are excluded from the statistics and passed on via {@code getMissedRows(columnMemberships)}.
 *
 * @param dataMemberships the rows of the current branch and their weights
 * @param targetPriors target sum and record count of the current branch (assumed to be weight-based, matching
 *            the weighted updates in the split loop)
 * @param targetColumn the numeric target column
 * @param rd random source used to break ties between equally good splits
 * @return the best split candidate, or {@code null} if all values in the branch are missing or no split yields a
 *         positive improvement
 * @throws UnsupportedOperationException if a row of the branch carries a fractional weight
 */
@Override
public SplitCandidate calcBestSplitRegression(final DataMemberships dataMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final RandomData rd) {
    final TreeEnsembleLearnerConfiguration config = getConfiguration();
    final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
    final int minChildNodeSize = config.getMinChildSize();
    final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    final int lengthNonMissing = getLengthNonMissing();
    final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
    // does the column as a whole contain missing values? (narrowed down to the current branch below)
    boolean branchContainsMissingValues = containsMissingValues();
    boolean missingsGoLeft = true;
    double missingWeight = 0.0;
    double missingY = 0.0;
    if (branchContainsMissingValues) {
        // missing values are stored at the end of the column: walk backwards until the first non-missing row
        columnMemberships.goToLast();
        while (columnMemberships.getIndexInColumn() >= lengthNonMissing) {
            final double weight = columnMemberships.getRowWeight();
            missingWeight += weight;
            // weight the target like the split loop does (ySum += weight * y); an unweighted sum corrupts
            // ySumTotal for every missing row whose weight is not exactly 1 (including zero-weight rows)
            missingY += weight * targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            if (!columnMemberships.previous()) {
                break;
            }
        }
        columnMemberships.reset();
        // the column may contain missing values without any of them being part of the current branch
        branchContainsMissingValues = missingWeight > 0.0;
    }
    final double ySumTotal = targetPriors.getYSum() - missingY;
    final double nrRecordsTotal = targetPriors.getNrRecords() - missingWeight;
    if (nrRecordsTotal == 0) {
        // all values in the current branch are missing: it is impossible to determine a split
        return null;
    }
    // with XGBoost handling every candidate split contains the missing rows, so the parent criterion must as well
    final double criterionTotal = useXGBoostMissingValueHandling
        ? (ySumTotal + missingY) * (ySumTotal + missingY) / (nrRecordsTotal + missingWeight)
        : ySumTotal * ySumTotal / nrRecordsTotal;
    double ySumLeft = 0.0;
    double nrRecordsLeft = 0.0;
    double ySumRight = ySumTotal;
    double nrRecordsRight = nrRecordsTotal;
    double bestSplit = Double.NEGATIVE_INFINITY;
    double bestImprovement = 0.0;
    double lastSeenY = Double.NaN;
    double lastSeenValue = Double.NEGATIVE_INFINITY;
    double lastSeenWeight = -1.0;
    // A row is moved to the left child only when the next row is read, so each candidate threshold lies between
    // lastSeenValue and value.
    while (columnMemberships.next()) {
        final double weight = columnMemberships.getRowWeight();
        if (weight < EPSILON) {
            // ignore record: not in current branch or not in sample
            continue;
        } else if (Math.floor(weight) != weight) {
            throw new UnsupportedOperationException("weighted records (missing values?) not supported, " + "weight is " + weight);
        }
        final int indexInColumn = columnMemberships.getIndexInColumn();
        if (indexInColumn >= lengthNonMissing) {
            // missing values come last; they were accounted for above and are never split points
            break;
        }
        final double value = getSorted(indexInColumn);
        if (lastSeenWeight > 0.0) {
            ySumLeft += lastSeenWeight * lastSeenY;
            ySumRight -= lastSeenWeight * lastSeenY;
            nrRecordsLeft += lastSeenWeight;
            nrRecordsRight -= lastSeenWeight;
            if (nrRecordsLeft >= minChildNodeSize && nrRecordsRight >= minChildNodeSize && lastSeenValue < value) {
                boolean tempMissingsGoLeft = true;
                final double childrenSquaredSum;
                if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
                    // try sending all missing rows left, then right, and keep the better direction
                    final double missingsLeftSum = ((ySumLeft + missingY) * (ySumLeft + missingY) / (nrRecordsLeft + missingWeight)) + (ySumRight * ySumRight / nrRecordsRight);
                    final double missingsRightSum = (ySumLeft * ySumLeft / nrRecordsLeft) + ((ySumRight + missingY) * (ySumRight + missingY) / (nrRecordsRight + missingWeight));
                    tempMissingsGoLeft = missingsLeftSum >= missingsRightSum;
                    childrenSquaredSum = tempMissingsGoLeft ? missingsLeftSum : missingsRightSum;
                } else {
                    childrenSquaredSum = (ySumLeft * ySumLeft / nrRecordsLeft) + (ySumRight * ySumRight / nrRecordsRight);
                }
                final double criterion = childrenSquaredSum - criterionTotal;
                // rd is only consulted on an exact tie
                final boolean randomTieBreaker = criterion == bestImprovement && rd.nextInt(0, 1) == 1;
                if (criterion > bestImprovement || randomTieBreaker) {
                    bestImprovement = criterion;
                    bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
                    // if there are no missing values go with majority
                    missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : nrRecordsLeft >= nrRecordsRight;
                }
            }
        }
        lastSeenY = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
        lastSeenValue = value;
        lastSeenWeight = weight;
    }
    if (bestImprovement > 0.0) {
        if (useXGBoostMissingValueHandling) {
            return new NumericSplitCandidate(this, bestSplit, bestImprovement, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
        }
        return new NumericSplitCandidate(this, bestSplit, bestImprovement, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
    } else {
        return null;
    }
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) BitSet(java.util.BitSet) ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)

Example 42 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeNumericColumnData, method calcBestSplitClassification.

/**
 * Finds the best binary split on this numeric column for a nominal (classification) target.
 * <p>
 * Rows of the current branch are visited in ascending order of this column's sorted values. Their class weights are
 * moved from the right to the left partition as the scan proceeds. Candidate thresholds lie between consecutive values
 * that differ by at least {@code EPSILON}.
 * <p>
 * Rows whose value is missing sit at the end of the column (index {@code >= getLengthNonMissing()}). With
 * {@code MissingValueHandling.XGBoost} they are tentatively added to either side and the better direction is stored
 * in the returned candidate. Otherwise they are left out of the partition counts and
 * {@code getMissedRows(columnMemberships)} is handed to the candidate.
 *
 * @param dataMemberships the rows of the current branch and their weights
 * @param targetPriors class distribution, record count, impurity criterion and prior impurity of the branch
 * @param targetColumn the nominal target column
 * @param rd random source used to break ties between splits with equal gain
 * @return the best split candidate, or {@code null} if all values in the branch are missing, no split reduced the
 *         impurity, or the resulting gain value is negative
 */
@Override
public NumericSplitCandidate calcBestSplitClassification(final DataMemberships dataMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final RandomData rd) {
    final TreeEnsembleLearnerConfiguration config = getConfiguration();
    final NominalValueRepresentation[] targetVals = targetColumn.getMetaData().getValues();
    final boolean useAverageSplitPoints = config.isUseAverageSplitPoints();
    final int minChildNodeSize = config.getMinChildSize();
    // number of target classes; the arrays below hold per-class weight sums
    final int targetCounts = targetVals.length;
    final double[] targetCountsLeftOfSplit = new double[targetCounts];
    // the right partition starts with the full branch distribution; rows are moved left during the scan
    final double[] targetCountsRightOfSplit = targetPriors.getDistribution().clone();
    assert targetCountsRightOfSplit.length == targetCounts;
    final double totalSumWeight = targetPriors.getNrRecords();
    final IImpurity impurityCriterion = targetPriors.getImpurityCriterion();
    final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
    // get columnMemberships
    final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    // missing value handling
    boolean branchContainsMissingValues = containsMissingValues();
    boolean missingsGoLeft = true;
    final int lengthNonMissing = getLengthNonMissing();
    final double[] missingTargetCounts = new double[targetCounts];
    int lastValidSplitPosition = -1;
    double missingWeight = 0;
    // Backward scan from the last row. Missing rows (index >= lengthNonMissing) are met first in this direction;
    // they are removed from the right partition and collected in missingTargetCounts / missingWeight.
    // For non-missing rows, lastValidSplitPosition ends up at the lowest index of the trailing run of values that
    // differ by less than EPSILON from their successor. The scan stops at the first clearly smaller value.
    columnMemberships.goToLast();
    do {
        final int indexInColumn = columnMemberships.getIndexInColumn();
        if (indexInColumn >= lengthNonMissing) {
            final double weight = columnMemberships.getRowWeight();
            final int classIdx = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            targetCountsRightOfSplit[classIdx] -= weight;
            missingTargetCounts[classIdx] += weight;
            missingWeight += weight;
        } else {
            if (lastValidSplitPosition < 0) {
                lastValidSplitPosition = indexInColumn;
            } else if ((getSorted(lastValidSplitPosition) - getSorted(indexInColumn)) >= EPSILON) {
                break;
            } else {
                lastValidSplitPosition = indexInColumn;
            }
        }
    } while (columnMemberships.previous());
    // it is possible that the column contains missing values but in the current branch there are no missing values
    branchContainsMissingValues = missingWeight > 0.0;
    columnMemberships.reset();
    double sumWeightsLeftOfSplit = 0.0;
    double sumWeightsRightOfSplit = totalSumWeight - missingWeight;
    // XGBoost (or no missings): the branch prior impurity is used as is; otherwise it is recomputed from the
    // distribution without the missing rows
    final double priorImpurity = useXGBoostMissingValueHandling || !branchContainsMissingValues ? targetPriors.getPriorImpurity() : impurityCriterion.getPartitionImpurity(TreeNominalColumnData.subtractMissingClassCounts(targetPriors.getDistribution(), missingTargetCounts), sumWeightsRightOfSplit);
    // all values in branch are missing
    if (sumWeightsRightOfSplit == 0) {
        // it is impossible to determine a split
        return null;
    }
    double bestSplit = Double.NEGATIVE_INFINITY;
    // gain for best split point, unnormalized (not using info gain ratio)
    double bestGain = Double.NEGATIVE_INFINITY;
    // gain for best split, normalized by attribute entropy when
    // info gain ratio is used.
    double bestGainValueForSplit = Double.NEGATIVE_INFINITY;
    // tempArray1: impurities of the two partitions; tempArray2: weights of the two partitions
    final double[] tempArray1 = new double[2];
    double[] tempArray2 = new double[2];
    double lastSeenValue = Double.NEGATIVE_INFINITY;
    boolean mustTestOnNextValueChange = false;
    boolean testSplitOnStart = true;
    boolean firstIteration = true;
    int lastSeenTarget = -1;
    int indexInCol = -1;
    // We iterate over the instances in the sample/branch instead of the whole data set
    // (the loop ends at the first missing row, i.e. once indexInCol >= lengthNonMissing)
    while (columnMemberships.next() && (indexInCol = columnMemberships.getIndexInColumn()) < lengthNonMissing) {
        final double weight = columnMemberships.getRowWeight();
        assert weight >= EPSILON : "Rows with zero row weight should never be seen!";
        final double value = getSorted(indexInCol);
        final int target = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
        final boolean hasValueChanged = (value - lastSeenValue) >= EPSILON;
        // entering the trailing run of equal values counts as a target change, so the boundary in front of the
        // largest value is also evaluated (subject to the minimum child size)
        final boolean hasTargetChanged = lastSeenTarget != target || indexInCol == lastValidSplitPosition;
        if (hasTargetChanged && !firstIteration) {
            mustTestOnNextValueChange = true;
            testSplitOnStart = false;
        }
        // A threshold is evaluated only where the value changes and both children meet the minimum size.
        // Until the first target change (testSplitOnStart) every value change is tested; afterwards only value
        // changes at or after a target change since the last test (mustTestOnNextValueChange).
        if (!firstIteration && hasValueChanged && (mustTestOnNextValueChange || testSplitOnStart) && sumWeightsLeftOfSplit >= minChildNodeSize && sumWeightsRightOfSplit >= minChildNodeSize) {
            double postSplitImpurity;
            boolean tempMissingsGoLeft = false;
            // missing value handling
            if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
                // evaluate sending all missing rows to either side and keep the lower post-split impurity
                final double[] targetCountsLeftPlusMissing = new double[targetCounts];
                final double[] targetCountsRightPlusMissing = new double[targetCounts];
                for (int i = 0; i < targetCounts; i++) {
                    targetCountsLeftPlusMissing[i] = targetCountsLeftOfSplit[i] + missingTargetCounts[i];
                    targetCountsRightPlusMissing[i] = targetCountsRightOfSplit[i] + missingTargetCounts[i];
                }
                final double[][] temp = new double[2][2];
                final double[] postSplitImpurities = new double[2];
                // send all missing values left
                tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftPlusMissing, sumWeightsLeftOfSplit + missingWeight);
                tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
                temp[0][0] = sumWeightsLeftOfSplit + missingWeight;
                temp[0][1] = sumWeightsRightOfSplit;
                postSplitImpurities[0] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[0], totalSumWeight);
                // send all missing values right
                tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
                tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightPlusMissing, sumWeightsRightOfSplit + missingWeight);
                temp[1][0] = sumWeightsLeftOfSplit;
                temp[1][1] = sumWeightsRightOfSplit + missingWeight;
                postSplitImpurities[1] = impurityCriterion.getPostSplitImpurity(tempArray1, temp[1], totalSumWeight);
                // take better split
                // (tempArray2 is re-pointed to the partition weights of the chosen variant; getGain reads it below)
                if (postSplitImpurities[0] < postSplitImpurities[1]) {
                    postSplitImpurity = postSplitImpurities[0];
                    tempArray2 = temp[0];
                    tempMissingsGoLeft = true;
                // TODO random tie breaker
                } else {
                    postSplitImpurity = postSplitImpurities[1];
                    tempArray2 = temp[1];
                    tempMissingsGoLeft = false;
                }
            } else {
                // NOTE(review): with missing rows in the branch (non-XGBoost handling) the partition weights exclude
                // them while totalSumWeight includes them; confirm this is what getPostSplitImpurity expects.
                tempArray1[0] = impurityCriterion.getPartitionImpurity(targetCountsLeftOfSplit, sumWeightsLeftOfSplit);
                tempArray1[1] = impurityCriterion.getPartitionImpurity(targetCountsRightOfSplit, sumWeightsRightOfSplit);
                tempArray2[0] = sumWeightsLeftOfSplit;
                tempArray2[1] = sumWeightsRightOfSplit;
                postSplitImpurity = impurityCriterion.getPostSplitImpurity(tempArray1, tempArray2, totalSumWeight);
            }
            if (postSplitImpurity < priorImpurity) {
                // Use absolute gain (IG) for split calculation even
                // if the split criterion is information gain ratio (IGR).
                // IGR wouldn't work as it favors extreme unfair splits,
                // i.e. 1:9999 would have an attribute entropy
                // (IGR denominator) of
                // 9999/10000*log(9999/10000) + 1/10000*log(1/10000)
                // which is ~0.00148
                double gain = (priorImpurity - postSplitImpurity);
                // rd is only consulted on an exact tie
                boolean randomTieBreaker = gain == bestGain ? rd.nextInt(0, 1) == 1 : false;
                if (gain > bestGain || randomTieBreaker) {
                    bestGainValueForSplit = impurityCriterion.getGain(priorImpurity, postSplitImpurity, tempArray2, totalSumWeight);
                    bestGain = gain;
                    bestSplit = useAverageSplitPoints ? getCenter(lastSeenValue, value) : lastSeenValue;
                    // Go with the majority if there are no missing values during training this is because we should
                    // still provide a missing direction for the case that there are missing values during prediction
                    // NOTE(review): ties go right here, whereas calcBestSplitRegression uses >= (ties go left)
                    missingsGoLeft = branchContainsMissingValues ? tempMissingsGoLeft : sumWeightsLeftOfSplit > sumWeightsRightOfSplit;
                }
            }
            mustTestOnNextValueChange = false;
        }
        // move the current row from the right to the left partition
        targetCountsLeftOfSplit[target] += weight;
        sumWeightsLeftOfSplit += weight;
        targetCountsRightOfSplit[target] -= weight;
        sumWeightsRightOfSplit -= weight;
        lastSeenTarget = target;
        lastSeenValue = value;
        firstIteration = false;
    }
    columnMemberships.reset();
    // no split found (still NEGATIVE_INFINITY) or negative gain value
    if (bestGainValueForSplit < 0.0) {
        // (see info gain ratio implementation)
        return null;
    }
    if (useXGBoostMissingValueHandling) {
        // return new NumericMissingSplitCandidate(this, bestSplit, bestGainValueForSplit, missingsGoLeft);
        return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, new BitSet(), missingsGoLeft ? NumericSplitCandidate.MISSINGS_GO_LEFT : NumericSplitCandidate.MISSINGS_GO_RIGHT);
    }
    return new NumericSplitCandidate(this, bestSplit, bestGainValueForSplit, getMissedRows(columnMemberships), NumericSplitCandidate.NO_MISSINGS);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) NumericSplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate) BitSet(java.util.BitSet) ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships) IImpurity(org.knime.base.node.mine.treeensemble2.learner.IImpurity)

Example 43 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class Surrogates, method learnSurrogates.

/**
 * Searches the remaining columns of <b>colSample</b> for surrogate splits of <b>bestSplit</b>.
 * <p>
 * The direction (left or right) that <b>bestSplit</b> assigns to each row is used as a new target. Every other
 * sampled column is then asked for its best classification split on that target. Only rows that <b>bestSplit</b>
 * actually sends to one of its two children take part in this search.
 *
 * @param dataMemberships the rows that belong to the current branch
 * @param bestSplit the best (primary) split found for the current node
 * @param oldData the TreeData object holding all attributes and the original target
 * @param colSample the columns that may serve as surrogates
 * @param config the learner configuration
 * @param rd random source forwarded to each column's split search
 * @return a SurrogateSplit holding the conditions for both children
 */
public static SurrogateSplit learnSurrogates(final DataMemberships dataMemberships, final SplitCandidate bestSplit, final TreeData oldData, final ColumnSample colSample, final TreeEnsembleLearnerConfiguration config, final RandomData rd) {
    final TreeAttributeColumnData primaryColumn = bestSplit.getColumnData();
    final TreeNodeCondition[] primaryConditions = bestSplit.getChildConditions();
    // rows routed left / right by the primary split
    final BitSet goesLeft = primaryColumn.updateChildMemberships(primaryConditions[0], dataMemberships);
    final BitSet goesRight = primaryColumn.updateChildMemberships(primaryConditions[1], dataMemberships);
    // keep only rows that the primary split assigns to one of its children
    final BitSet assignedRows = (BitSet) goesLeft.clone();
    assignedRows.or(goesRight);
    final DataMemberships assignedMemberships = dataMemberships.createChildMemberships(assignedRows);
    final TreeTargetNominalColumnData directionTarget = createNewTargetColumn(goesLeft, goesRight, oldData.getNrRows(), assignedMemberships);
    final ClassificationPriors directionPriors = directionTarget.getDistribution(assignedMemberships, config);
    // primary split first, then each surrogate candidate in column-sample order
    final ArrayList<SplitCandidate> ranked = new ArrayList<SplitCandidate>();
    ranked.add(bestSplit);
    for (TreeAttributeColumnData column : colSample) {
        if (column == primaryColumn) {
            continue;
        }
        final SplitCandidate surrogate = column.calcBestSplitClassification(assignedMemberships, directionPriors, directionTarget, rd);
        if (surrogate != null) {
            ranked.add(surrogate);
        }
    }
    return calculateSurrogates(dataMemberships, ranked.toArray(new SplitCandidate[ranked.size()]));
}
Also used : TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) BitSet(java.util.BitSet) ArrayList(java.util.ArrayList) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) ClassificationPriors(org.knime.base.node.mine.treeensemble2.data.ClassificationPriors)

Example 44 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeLearnerClassification, method buildTreeNode.

/**
 * Recursively builds the classification subtree for the given branch.
 * <p>
 * With surrogate missing-value handling, the best candidates are searched and surrogates are learned, yielding
 * exactly two children. Otherwise a single best split is used and one child is built per child condition; a column
 * that cannot be split further is marked in {@code forbiddenColumnSet} while its subtree is built, then released.
 * If no split is found the node becomes a leaf.
 *
 * @param exec monitor checked for cancellation before any work is done
 * @param currentDepth depth of the node being built
 * @param dataMemberships the rows of the current branch
 * @param columnSample the columns considered for splitting at this node
 * @param treeNodeSignature signature of the node being built
 * @param targetPriors target distribution of the current branch
 * @param forbiddenColumnSet columns currently excluded from splitting (temporarily modified, restored on return)
 * @return the built node (leaf or inner node)
 * @throws CanceledExecutionException if the execution was canceled
 */
private TreeNodeClassification buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final ClassificationPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData data = getData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final boolean surrogateMode = config.getMissingValueHandling() == MissingValueHandling.Surrogate;
    final TreeTargetNominalColumnData targetColumn = (TreeTargetNominalColumnData) data.getTargetColumn();
    final TreeNodeClassification[] children;
    if (surrogateMode) {
        final SplitCandidate[] candidates = findBestSplitsClassification(currentDepth, dataMemberships, columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (candidates == null) {
            // no split available: leaf node
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        final SurrogateSplit surrogateSplit = Surrogates.learnSurrogates(dataMemberships, candidates[0], data, columnSample, config, getRandomData());
        final TreeNodeCondition[] conditions = surrogateSplit.getChildConditions();
        final BitSet[] markers = surrogateSplit.getChildMarkers();
        children = new TreeNodeClassification[2];
        for (int child = 0; child < children.length; child++) {
            final DataMemberships memberships = dataMemberships.createChildMemberships(markers[child]);
            final ClassificationPriors priors = targetColumn.getDistribution(memberships, config);
            final TreeNodeSignature signature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) child);
            final ColumnSample sample = getColSamplingStrategy().getColumnSampleForTreeNode(signature);
            children[child] = buildTreeNode(exec, currentDepth + 1, memberships, sample, signature, priors, forbiddenColumnSet);
            children[child].setTreeNodeCondition(conditions[child]);
        }
    } else {
        final SplitCandidate bestSplit = findBestSplitClassification(currentDepth, dataMemberships, columnSample, treeNodeSignature, targetPriors, forbiddenColumnSet);
        if (bestSplit == null) {
            // no split available: leaf node
            return new TreeNodeClassification(treeNodeSignature, targetPriors, config);
        }
        final TreeAttributeColumnData splitColumn = bestSplit.getColumnData();
        final int attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        // a column that cannot be split further stays forbidden only while this subtree is built
        final boolean forbidColumn = !bestSplit.canColumnBeSplitFurther();
        forbiddenColumnSet.set(attributeIndex, forbidColumn);
        final TreeNodeCondition[] conditions = bestSplit.getChildConditions();
        if (conditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting " + "attribute " + bestSplit.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + conditions.length);
        }
        children = new TreeNodeClassification[conditions.length];
        for (int child = 0; child < conditions.length; child++) {
            final TreeNodeCondition condition = conditions[child];
            final DataMemberships memberships = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(condition, dataMemberships));
            final ClassificationPriors priors = targetColumn.getDistribution(memberships, config);
            // NOTE(review): the child index is narrowed to byte; confirm how createChildSignature handles > 127 children
            final TreeNodeSignature signature = treeNodeSignature.createChildSignature((byte) child);
            final ColumnSample sample = getColSamplingStrategy().getColumnSampleForTreeNode(signature);
            children[child] = buildTreeNode(exec, currentDepth + 1, memberships, sample, signature, priors, forbiddenColumnSet);
            children[child].setTreeNodeCondition(condition);
        }
        if (forbidColumn) {
            forbiddenColumnSet.set(attributeIndex, false);
        }
    }
    return new TreeNodeClassification(treeNodeSignature, targetPriors, children, config);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) TreeNodeClassification(org.knime.base.node.mine.treeensemble2.model.TreeNodeClassification) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) BitSet(java.util.BitSet) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition) TreeTargetNominalColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData) ClassificationPriors(org.knime.base.node.mine.treeensemble2.data.ClassificationPriors)

Example 45 with DataMemberships

use of org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships in project knime-core by knime.

Class TreeLearnerRegression, method buildTreeNode.

/**
 * Recursively builds the regression subtree for the given branch.
 * <p>
 * If no split is found the node becomes a leaf. For gradient boosting configurations the leaf also receives the
 * original row indices and is registered via {@code addToLeafList}. With surrogate missing-value handling the
 * split is turned into a two-way surrogate split. Otherwise one child is built per child condition of the best
 * split, and a column that cannot be split further is marked in {@code forbiddenColumnSet} while its subtree is
 * built.
 *
 * @param exec monitor checked for cancellation before any work is done
 * @param currentDepth depth of the node being built
 * @param dataMemberships the rows of the current branch
 * @param columnSample the columns considered for splitting at this node
 * @param treeNodeSignature signature of the node being built
 * @param targetPriors target statistics of the current branch
 * @param forbiddenColumnSet columns currently excluded from splitting (temporarily modified, restored on return)
 * @return the built node (leaf or inner node)
 * @throws CanceledExecutionException if the execution was canceled
 */
private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    exec.checkCanceled();
    final SplitCandidate split = findBestSplitRegression(currentDepth, dataMemberships, columnSample, targetPriors, forbiddenColumnSet);
    if (split == null) {
        // no split available: leaf node
        if (!(config instanceof GradientBoostingLearnerConfiguration)) {
            return new TreeNodeRegression(treeNodeSignature, targetPriors);
        }
        final TreeNodeRegression boostingLeaf = new TreeNodeRegression(treeNodeSignature, targetPriors, dataMemberships.getOriginalIndices());
        addToLeafList(boostingLeaf);
        return boostingLeaf;
    }
    final TreeTargetNumericColumnData targetColumn = (TreeTargetNumericColumnData) data.getTargetColumn();
    final TreeNodeRegression[] children;
    if (config.getMissingValueHandling() == MissingValueHandling.Surrogate) {
        final SurrogateSplit surrogateSplit = Surrogates.learnSurrogates(dataMemberships, split, data, columnSample, config, rd);
        final TreeNodeCondition[] conditions = surrogateSplit.getChildConditions();
        final BitSet[] markers = surrogateSplit.getChildMarkers();
        assert markers[0].cardinality() + markers[1].cardinality() == dataMemberships.getRowCount() : "Sum of rows in children does not add up to number of rows in parent.";
        children = new TreeNodeRegression[2];
        for (int child = 0; child < children.length; child++) {
            final DataMemberships memberships = dataMemberships.createChildMemberships(markers[child]);
            final TreeNodeSignature signature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) child);
            final ColumnSample sample = getColSamplingStrategy().getColumnSampleForTreeNode(signature);
            final RegressionPriors priors = targetColumn.getPriors(memberships, config);
            children[child] = buildTreeNode(exec, currentDepth + 1, memberships, sample, signature, priors, forbiddenColumnSet);
            children[child].setTreeNodeCondition(conditions[child]);
        }
    } else {
        final TreeAttributeColumnData splitColumn = split.getColumnData();
        final int attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        // a column that cannot be split further stays forbidden only while this subtree is built
        final boolean forbidColumn = !split.canColumnBeSplitFurther();
        forbiddenColumnSet.set(attributeIndex, forbidColumn);
        final TreeNodeCondition[] conditions = split.getChildConditions();
        if (conditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting " + "attribute " + split.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + conditions.length);
        }
        children = new TreeNodeRegression[conditions.length];
        for (int child = 0; child < conditions.length; child++) {
            final TreeNodeCondition condition = conditions[child];
            final DataMemberships memberships = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(condition, dataMemberships));
            final RegressionPriors priors = targetColumn.getPriors(memberships, config);
            // NOTE(review): the child index is narrowed to byte; confirm how createChildSignature handles > 127 children
            final TreeNodeSignature signature = treeNodeSignature.createChildSignature((byte) child);
            final ColumnSample sample = getColSamplingStrategy().getColumnSampleForTreeNode(signature);
            children[child] = buildTreeNode(exec, currentDepth + 1, memberships, sample, signature, priors, forbiddenColumnSet);
            children[child].setTreeNodeCondition(condition);
        }
        if (forbidColumn) {
            forbiddenColumnSet.set(attributeIndex, false);
        }
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, children);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)34 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)26 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)25 BitSet (java.util.BitSet)21 Test (org.junit.Test)21 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)17 RandomData (org.apache.commons.math.random.RandomData)15 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)14 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)13 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)12 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)12 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)10 ColumnMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)10 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)9 TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition)9 NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate)7 TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition)7 TreeNodeCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)6 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)5 NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate)5