Search in sources:

Example 66 with TreeEnsembleLearnerConfiguration

Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by KNIME.

Method calcBestSplitClassificationBinaryTwoClass of class TreeNominalColumnData.

/**
 * Finds the best binary split of this nominal column for a two-class classification target.
 * <p>
 * The method makes a single pass over {@code columnMemberships}. It assumes the rows are grouped by nominal
 * value in the order given by {@code m_nominalValueCounts}. For every non-missing nominal value it
 * accumulates the total row weight and the weight of the first target class. Rows located after the last
 * non-missing value are counted as missing values.
 * <p>
 * The resulting per-value pairs are sorted by their natural ordering (presumably ascending first-class
 * probability -- verify {@code NomValProbabilityPair#compareTo}). Every proper prefix of that ordering is
 * then evaluated as a candidate partition.
 * <p>
 * With XGBoost missing value handling, the missing rows are tentatively sent to both sides of each
 * candidate and the direction with the higher gain is kept. Otherwise, missing rows are excluded from the
 * gain computation and reported via {@code getMissedRows(columnMemberships)}.
 *
 * @param columnMemberships the rows of the current node; must have been reset and must not be empty
 * @param targetPriors class distribution and prior impurity of the current node
 * @param targetColumn the nominal target column; must have exactly two possible values
 * @param impCriterion the impurity criterion used to compute partition impurities and the gain
 * @param nomVals the nominal values of this column (the missing value, if present, is counted in its length)
 * @param targetVals not used by this method
 * @param rd source of randomness used to break ties between equally good partitions
 * @return the best valid split with positive gain, or {@code null} if there is no such split
 * @throws IllegalArgumentException if the target column does not have exactly two values
 * @throws IllegalStateException if {@code columnMemberships} is empty or has not been reset
 */
private NominalBinarySplitCandidate calcBestSplitClassificationBinaryTwoClass(final ColumnMemberships columnMemberships, final ClassificationPriors targetPriors, final TreeTargetNominalColumnData targetColumn, final IImpurity impCriterion, final NominalValueRepresentation[] nomVals, final NominalValueRepresentation[] targetVals, final RandomData rd) {
    if (targetColumn.getMetaData().getValues().length != 2) {
        throw new IllegalArgumentException("This method can only be used for two class problems.");
    }
    final TreeEnsembleLearnerConfiguration config = getConfiguration();
    final int minChildSize = config.getMinChildSize();
    final boolean useXGBoostMissingValueHandling = config.getMissingValueHandling() == MissingValueHandling.XGBoost;
    int start = 0;
    final int firstClass = targetColumn.getMetaData().getValues()[0].getAssignedInteger();
    // totalWeight / totalFirstClassWeight only cover rows with non-missing values
    double totalWeight = 0.0;
    double totalFirstClassWeight = 0.0;
    final ArrayList<NomValProbabilityPair> nomValProbabilities = new ArrayList<>();
    if (!columnMemberships.next()) {
        throw new IllegalStateException("The columnMemberships has not been reset or is empty.");
    }
    // if the column has missing values, the last nominal value represents them and is handled separately below
    final int lengthNonMissing = containsMissingValues() ? nomVals.length - 1 : nomVals.length;
    // starts as "the column has missing values"; forced to true below if this branch has missing weight
    // (if the column has no missing values this is false, as required for the last nominal value)
    boolean branchContainsMissingValues = containsMissingValues();
    // calculate probabilities for first class in each nominal value
    for (int att = 0; att < lengthNonMissing; att++) {
        // rows of nominal value 'att' occupy column indices [start, end)
        final int end = start + m_nominalValueCounts[att];
        double attFirstClassWeight = 0;
        double attWeight = 0;
        boolean reachedEnd = false;
        for (int index = columnMemberships.getIndexInColumn(); index < end; index = columnMemberships.getIndexInColumn()) {
            double weight = columnMemberships.getRowWeight();
            assert weight > EPSILON : "Instances in columnMemberships must have weights larger than EPSILON.";
            final int instanceClass = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            if (instanceClass == firstClass) {
                attFirstClassWeight += weight;
                totalFirstClassWeight += weight;
            }
            attWeight += weight;
            totalWeight += weight;
            if (!columnMemberships.next()) {
                // reached end of columnMemberships
                reachedEnd = true;
                break;
            }
        }
        // nominal values without any rows in this node are not candidates
        if (attWeight > 0) {
            final double firstClassProbability = attFirstClassWeight / attWeight;
            final NominalValueRepresentation nomVal = getMetaData().getValues()[att];
            nomValProbabilities.add(new NomValProbabilityPair(nomVal, firstClassProbability, attWeight, attFirstClassWeight));
        }
        start = end;
        if (reachedEnd) {
            break;
        }
    }
    // account for missing values and their weight
    double missingWeight = 0.0;
    double missingWeightFirstClass = 0.0;
    // If the loop above stopped because the memberships were exhausted, the current index is below 'start'
    // and there are no rows left. Otherwise the current row (and all following ones) are missing values.
    if (columnMemberships.getIndexInColumn() >= start) {
        do {
            final double recordWeight = columnMemberships.getRowWeight();
            missingWeight += recordWeight;
            final int recordClass = targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            if (recordClass == firstClass) {
                missingWeightFirstClass += recordWeight;
            }
        } while (columnMemberships.next());
    }
    if (missingWeight > EPSILON) {
        branchContainsMissingValues = true;
    }
    // natural ordering of NomValProbabilityPair
    nomValProbabilities.sort(null);
    // highest bit that can belong to a non-missing nominal value
    int highestBitPosition = getMetaData().getValues().length - 1;
    if (containsMissingValues()) {
        highestBitPosition--;
    }
    // index 1 = current prefix partition, index 0 = the remaining values
    final double[] targetCountsSplitPartition = new double[2];
    final double[] targetCountsSplitRemaining = new double[2];
    final double[] binaryImpurityValues = new double[2];
    final double[] binaryPartitionWeights = new double[2];
    BigInteger partitionMask = BigInteger.ZERO;
    double bestPartitionGain = Double.NEGATIVE_INFINITY;
    BigInteger bestPartitionMask = null;
    boolean isBestSplitValid = false;
    double sumWeightsPartitionTotal = 0.0;
    double sumWeightsPartitionFirstClass = 0.0;
    boolean missingsGoLeft = false;
    // without XGBoost handling the missing rows are removed from the prior distribution before computing the impurity
    final double priorImpurity = useXGBoostMissingValueHandling ? targetPriors.getPriorImpurity() : impCriterion.getPartitionImpurity(subtractMissingClassCounts(targetPriors.getDistribution(), createMissingClassCountsTwoClass(missingWeight, missingWeightFirstClass)), totalWeight);
    // we don't need to iterate over the full list because we always need some value on the other side
    for (int i = 0; i < nomValProbabilities.size() - 1; i++) {
        NomValProbabilityPair nomVal = nomValProbabilities.get(i);
        sumWeightsPartitionTotal += nomVal.m_sumWeights;
        sumWeightsPartitionFirstClass += nomVal.m_firstClassSumWeights;
        partitionMask = partitionMask.or(nomVal.m_bitMask);
        // check if split represented by currentSplitList is in the right branch
        // by convention a split goes towards the right branch if the highest possible bit is set to 1
        final boolean isRightBranch = partitionMask.testBit(highestBitPosition);
        double gain;
        boolean isValidSplit;
        boolean tempMissingsGoLeft = true;
        if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
            // send missing values both ways and take the better direction
            // send missings left (i.e. add them to the current prefix partition)
            targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass + missingWeightFirstClass;
            targetCountsSplitPartition[1] = sumWeightsPartitionTotal + missingWeight - targetCountsSplitPartition[0];
            binaryPartitionWeights[1] = sumWeightsPartitionTotal + missingWeight;
            // totalFirstClassWeight and totalWeight only include non missing values
            targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass;
            targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal - targetCountsSplitRemaining[0];
            binaryPartitionWeights[0] = totalWeight - sumWeightsPartitionTotal;
            boolean isValidSplitLeft = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
            binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
            binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
            double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight + missingWeight);
            double gainLeft = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight + missingWeight);
            // send missings right (i.e. add them to the remaining values)
            targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass;
            targetCountsSplitPartition[1] = sumWeightsPartitionTotal - sumWeightsPartitionFirstClass;
            binaryPartitionWeights[1] = sumWeightsPartitionTotal;
            targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass + missingWeightFirstClass;
            targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal + missingWeight - targetCountsSplitRemaining[0];
            binaryPartitionWeights[0] = totalWeight + missingWeight - sumWeightsPartitionTotal;
            boolean isValidSplitRight = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
            binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
            binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
            postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight + missingWeight);
            double gainRight = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight + missingWeight);
            // decide which is better (better gain)
            if (gainLeft >= gainRight) {
                gain = gainLeft;
                isValidSplit = isValidSplitLeft;
                tempMissingsGoLeft = true;
            } else {
                gain = gainRight;
                isValidSplit = isValidSplitRight;
                tempMissingsGoLeft = false;
            }
        } else {
            // assign weights to branches
            targetCountsSplitPartition[0] = sumWeightsPartitionFirstClass;
            targetCountsSplitPartition[1] = sumWeightsPartitionTotal - sumWeightsPartitionFirstClass;
            binaryPartitionWeights[1] = sumWeightsPartitionTotal;
            targetCountsSplitRemaining[0] = totalFirstClassWeight - sumWeightsPartitionFirstClass;
            targetCountsSplitRemaining[1] = totalWeight - sumWeightsPartitionTotal - targetCountsSplitRemaining[0];
            binaryPartitionWeights[0] = totalWeight - sumWeightsPartitionTotal;
            isValidSplit = binaryPartitionWeights[0] >= minChildSize && binaryPartitionWeights[1] >= minChildSize;
            binaryImpurityValues[0] = impCriterion.getPartitionImpurity(targetCountsSplitRemaining, binaryPartitionWeights[0]);
            binaryImpurityValues[1] = impCriterion.getPartitionImpurity(targetCountsSplitPartition, binaryPartitionWeights[1]);
            double postSplitImpurity = impCriterion.getPostSplitImpurity(binaryImpurityValues, binaryPartitionWeights, totalWeight);
            gain = impCriterion.getGain(priorImpurity, postSplitImpurity, binaryPartitionWeights, totalWeight);
        }
        // use random tie breaker if gains are equal (the RNG is only consulted on an exact tie)
        final boolean randomTieBreaker = gain == bestPartitionGain && rd.nextInt(0, 1) == 1;
        // store if better than before or first valid split
        if (gain > bestPartitionGain || (!isBestSplitValid && isValidSplit) || randomTieBreaker) {
            // never replace a valid best split with an invalid one
            if (isValidSplit || !isBestSplitValid) {
                bestPartitionGain = gain;
                // store the mask of the side that has the highest bit set: either the prefix itself or its
                // complement within bits [0, highestBitPosition]
                bestPartitionMask = isRightBranch ? partitionMask : BigInteger.ZERO.setBit(highestBitPosition + 1).subtract(BigInteger.ONE).xor(partitionMask);
                isBestSplitValid = isValidSplit;
                // missingsGoLeft is only used later on if XGBoost Missing Value Handling is used
                if (branchContainsMissingValues) {
                    // NOTE(review): tempMissingsGoLeft refers to the prefix partition, while
                    // bestPartitionMask is flipped when !isRightBranch. Confirm whether the direction should
                    // also depend on isRightBranch (compare the majority rule in the else-branch).
                    missingsGoLeft = tempMissingsGoLeft;
                } else {
                    // no missing values in this branch
                    // send missing values with the majority
                    missingsGoLeft = isRightBranch ? sumWeightsPartitionTotal < 0.5 * totalWeight : sumWeightsPartitionTotal >= 0.5 * totalWeight;
                }
            }
        }
    }
    if (isBestSplitValid && bestPartitionGain > 0.0) {
        if (useXGBoostMissingValueHandling) {
            return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, NO_MISSED_ROWS, missingsGoLeft ? NominalBinarySplitCandidate.MISSINGS_GO_LEFT : NominalBinarySplitCandidate.MISSINGS_GO_RIGHT);
        }
        return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, getMissedRows(columnMemberships), NominalBinarySplitCandidate.NO_MISSINGS);
    }
    return null;
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) ArrayList(java.util.ArrayList) BigInteger(java.math.BigInteger) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)

Example 67 with TreeEnsembleLearnerConfiguration

Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by KNIME.

Method buildTreeNode of class TreeLearnerRegression.

/**
 * Recursively grows the regression subtree rooted at the node identified by {@code treeNodeSignature}.
 * <p>
 * If no split is found for the node's rows, a leaf is returned. For gradient boosting configurations, that
 * leaf also records the original row indices and is registered via {@code addToLeafList}. Otherwise the
 * node is split, either through surrogate splits (two children) or through the split candidate's own child
 * conditions. Each child is built recursively and receives its condition.
 *
 * @param exec monitor checked for cancellation before each node is processed
 * @param currentDepth depth of the node being built (0 for the root)
 * @param dataMemberships the rows that fall into this node
 * @param columnSample the columns that may be considered for splitting this node
 * @param treeNodeSignature signature of the node being built
 * @param targetPriors target statistics of the rows in this node
 * @param forbiddenColumnSet columns that must not be split on; temporarily modified during recursion
 * @return the built node (a leaf or an inner node with its children)
 * @throws CanceledExecutionException if execution was canceled
 */
private TreeNodeRegression buildTreeNode(final ExecutionMonitor exec, final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final TreeNodeSignature treeNodeSignature, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) throws CanceledExecutionException {
    final TreeData treeData = getData();
    final RandomData random = getRandomData();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();
    exec.checkCanceled();
    final SplitCandidate split = findBestSplitRegression(currentDepth, dataMemberships, columnSample, targetPriors, forbiddenColumnSet);

    // no admissible split: this node becomes a leaf
    if (split == null) {
        if (!(learnerConfig instanceof GradientBoostingLearnerConfiguration)) {
            return new TreeNodeRegression(treeNodeSignature, targetPriors);
        }
        // boosting leaves keep their row indices and are collected for later processing
        final TreeNodeRegression boostingLeaf = new TreeNodeRegression(treeNodeSignature, targetPriors, dataMemberships.getOriginalIndices());
        addToLeafList(boostingLeaf);
        return boostingLeaf;
    }

    final TreeTargetNumericColumnData numericTarget = (TreeTargetNumericColumnData) treeData.getTargetColumn();
    final TreeNodeRegression[] children;
    if (learnerConfig.getMissingValueHandling() == MissingValueHandling.Surrogate) {
        final SurrogateSplit surrogateSplit = Surrogates.learnSurrogates(dataMemberships, split, treeData, columnSample, learnerConfig, random);
        final TreeNodeCondition[] conditions = surrogateSplit.getChildConditions();
        final BitSet[] markers = surrogateSplit.getChildMarkers();
        assert markers[0].cardinality() + markers[1].cardinality() == dataMemberships.getRowCount() : "Sum of rows in children does not add up to number of rows in parent.";
        children = new TreeNodeRegression[2];
        for (int childIdx = 0; childIdx < children.length; childIdx++) {
            final DataMemberships childMemberships = dataMemberships.createChildMemberships(markers[childIdx]);
            final TreeNodeSignature childSignature = getSignatureFactory().getChildSignatureFor(treeNodeSignature, (byte) childIdx);
            final ColumnSample childColumns = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            final RegressionPriors childPriors = numericTarget.getPriors(childMemberships, learnerConfig);
            final TreeNodeRegression child = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumns, childSignature, childPriors, forbiddenColumnSet);
            child.setTreeNodeCondition(conditions[childIdx]);
            children[childIdx] = child;
        }
    } else {
        final TreeAttributeColumnData splitColumn = split.getColumnData();
        final int attributeIndex = splitColumn.getMetaData().getAttributeIndex();
        // a column that cannot be split any further is excluded while building the subtree below
        final boolean forbidDuringRecursion = !split.canColumnBeSplitFurther();
        forbiddenColumnSet.set(attributeIndex, forbidDuringRecursion);
        final TreeNodeCondition[] conditions = split.getChildConditions();
        if (conditions.length > Short.MAX_VALUE) {
            throw new RuntimeException("Too many children when splitting attribute " + split.getColumnData() + " (maximum supported: " + Short.MAX_VALUE + "): " + conditions.length);
        }
        children = new TreeNodeRegression[conditions.length];
        int childIdx = 0;
        for (final TreeNodeCondition condition : conditions) {
            final DataMemberships childMemberships = dataMemberships.createChildMemberships(splitColumn.updateChildMemberships(condition, dataMemberships));
            final RegressionPriors childPriors = numericTarget.getPriors(childMemberships, learnerConfig);
            // NOTE(review): the child index is cast to byte although the limit above is Short.MAX_VALUE;
            // confirm how createChildSignature interprets indices above Byte.MAX_VALUE
            final TreeNodeSignature childSignature = treeNodeSignature.createChildSignature((byte) childIdx);
            final ColumnSample childColumns = getColSamplingStrategy().getColumnSampleForTreeNode(childSignature);
            final TreeNodeRegression child = buildTreeNode(exec, currentDepth + 1, childMemberships, childColumns, childSignature, childPriors, forbiddenColumnSet);
            child.setTreeNodeCondition(condition);
            children[childIdx++] = child;
        }
        // re-enable the column for sibling subtrees
        if (forbidDuringRecursion) {
            forbiddenColumnSet.clear(attributeIndex);
        }
    }
    return new TreeNodeRegression(treeNodeSignature, targetPriors, children);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) TreeNodeCondition(org.knime.base.node.mine.treeensemble2.model.TreeNodeCondition)

Example 68 with TreeEnsembleLearnerConfiguration

Use of org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration in project knime-core by KNIME.

Method findBestSplitRegression of class TreeLearnerRegression.

/**
 * Determines the split with the highest positive gain for the rows of the current node.
 * <p>
 * Returns {@code null} in the following cases:
 * <ul>
 * <li>the configured maximum depth has been reached;</li>
 * <li>the node has fewer records than the configured minimum node size;</li>
 * <li>the target's sum of squared deviations is below {@code TreeColumnData.EPSILON}.</li>
 * </ul>
 * At the root, a hard-coded root column (if configured) is used unconditionally. Otherwise every sampled
 * column that is not in {@code forbiddenColumnSet} is evaluated.
 *
 * @param currentDepth depth of the node (0 for the root)
 * @param dataMemberships the rows of the node
 * @param columnSample the candidate columns
 * @param targetPriors target statistics of the node's rows
 * @param forbiddenColumnSet attribute indices that must be skipped
 * @return the best split candidate, or {@code null} if the node should not or cannot be split
 */
private SplitCandidate findBestSplitRegression(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData treeData = getData();
    final RandomData random = getRandomData();
    final TreeEnsembleLearnerConfiguration learnerConfig = getConfig();

    final int maxLevels = learnerConfig.getMaxLevels();
    final boolean depthLimitReached = maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels;
    if (depthLimitReached) {
        return null;
    }
    final int minNodeSize = learnerConfig.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED && targetPriors.getNrRecords() < minNodeSize) {
        return null;
    }
    if (targetPriors.getSumSquaredDeviation() < TreeColumnData.EPSILON) {
        // the target is (numerically) constant within this node
        return null;
    }

    final TreeTargetNumericColumnData numericTarget = getTargetData();
    if (currentDepth == 0 && learnerConfig.getHardCodedRootColumn() != null) {
        // the root split is forced onto a user-specified column
        final TreeAttributeColumnData rootColumn = treeData.getColumn(learnerConfig.getHardCodedRootColumn());
        return rootColumn.calcBestSplitRegression(dataMemberships, targetPriors, numericTarget, random);
    }

    // only splits with strictly positive gain are accepted; ties keep the earlier column
    SplitCandidate best = null;
    double bestGain = 0.0;
    for (TreeAttributeColumnData column : columnSample) {
        if (forbiddenColumnSet.get(column.getMetaData().getAttributeIndex())) {
            continue;
        }
        final SplitCandidate candidate = column.calcBestSplitRegression(dataMemberships, targetPriors, numericTarget, random);
        if (candidate == null) {
            continue;
        }
        final double gain = candidate.getGainValue();
        if (gain > bestGain) {
            bestGain = gain;
            best = candidate;
        }
    }
    return best;
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData)

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration)62 Test (org.junit.Test)29 DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships)27 RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships)26 SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate)19 RandomData (org.apache.commons.math.random.RandomData)17 BitSet (java.util.BitSet)16 DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager)15 NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)15 IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager)13 NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate)13 TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData)10 TreeNodeNominalBinaryCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalBinaryCondition)10 TestDataGenerator (org.knime.base.node.mine.treeensemble2.data.TestDataGenerator)9 TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData)8 NumericSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericSplitCandidate)8 TreeNodeNumericCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNumericCondition)7 NumericMissingSplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NumericMissingSplitCandidate)6 TreeNodeNominalCondition (org.knime.base.node.mine.treeensemble2.model.TreeNodeNominalCondition)6 TreeTargetNominalColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNominalColumnData)5