Search in sources:

Example 1 with TreeTargetNumericColumnData

Use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData in project knime-core by knime.

In class TreeNominalColumnData, method calcBestSplitRegression:

/**
 * {@inheritDoc}
 */
@Override
public SplitCandidate calcBestSplitRegression(final DataMemberships dataMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final RandomData rd) {
    final NominalValueRepresentation[] nomVals = getMetaData().getValues();
    final ColumnMemberships columnMemberships = dataMemberships.getColumnMemberships(getMetaData().getAttributeIndex());
    final boolean useBinaryNominalSplits = getConfiguration().isUseBinaryNominalSplits();
    if (useBinaryNominalSplits) {
        return calcBestSplitRegressionBinaryBreiman(columnMemberships, targetPriors, targetColumn, nomVals, rd);
    } else {
        return calcBestSplitRegressionMultiway(columnMemberships, targetPriors, targetColumn, nomVals, rd);
    }
}
Also used : ColumnMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships)
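
Read alongside the dispatch above: a minimal, self-contained sketch (plain Java, not the KNIME API; the color values and the left/right assignment are invented) of the two split shapes the configuration flag chooses between. A multiway split creates one child per nominal value, while a binary split assigns every value to one of two children via a bit mask over the value indices, which is also how the NominalBinarySplitCandidate examples below encode their partitions.

import java.math.BigInteger;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * Illustrative sketch only (not KNIME code): contrasts the two nominal split
 * shapes that calcBestSplitRegression chooses between. The values and the
 * left/right grouping are made up.
 */
public class NominalSplitShapes {

    public static void main(String[] args) {
        List<String> nomVals = List.of("red", "green", "blue", "yellow");

        // Multiway split: one child branch per nominal value.
        Map<String, Integer> multiwayChild = new TreeMap<>();
        for (int i = 0; i < nomVals.size(); i++) {
            multiwayChild.put(nomVals.get(i), i); // value i goes to child i
        }
        System.out.println("multiway children: " + multiwayChild);

        // Binary split: every value is sent to one of two children, encoded as a
        // bit mask over the value indices (set bit = right branch).
        BigInteger partitionMask = BigInteger.ZERO
            .setBit(0)   // "red"  -> right branch (invented assignment)
            .setBit(2);  // "blue" -> right branch (invented assignment)
        for (int i = 0; i < nomVals.size(); i++) {
            String side = partitionMask.testBit(i) ? "right" : "left";
            System.out.println("binary: " + nomVals.get(i) + " -> " + side);
        }
    }
}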

Example 2 with TreeTargetNumericColumnData

Use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData in project knime-core by knime.

In class TreeNominalColumnData, method calcBestSplitRegressionBinaryBreiman:

/**
 * If an attribute value does not appear in the current branch, there is no guarantee which child branch that
 * value will be sent to. (This should not be a problem, since we cannot make any assumptions about such an
 * attribute value anyway.)
 *
 * @param columnMemberships the column memberships of the rows in the current branch
 * @param targetPriors the regression priors of the target column
 * @param targetColumn the numeric target column
 * @param nomVals the nominal values of this column
 * @param rd the random data generator used for tie breaking
 * @return the best split candidate, or null if there is no split with positive gain or the child nodes would be too small
 */
private NominalBinarySplitCandidate calcBestSplitRegressionBinaryBreiman(final ColumnMemberships columnMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final NominalValueRepresentation[] nomVals, final RandomData rd) {
    final int minChildSize = getConfiguration().getMinChildSize();
    double sumYTotal = targetPriors.getYSum();
    double sumWeightTotal = targetPriors.getNrRecords();
    final boolean useXGBoostMissingValueHandling = getConfiguration().getMissingValueHandling() == MissingValueHandling.XGBoost;
    boolean branchContainsMissingValues = containsMissingValues();
    double missingWeight = 0.0;
    double missingY = 0.0;
    if (branchContainsMissingValues) {
        columnMemberships.goToLast();
        while (columnMemberships.getIndexInColumn() >= m_idxOfFirstMissing) {
            final double weight = columnMemberships.getRowWeight();
            missingWeight += weight;
            missingY += weight * targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            if (!columnMemberships.previous()) {
                break;
            }
        }
        sumYTotal -= missingY;
        sumWeightTotal -= missingWeight;
        branchContainsMissingValues = missingWeight > 0.0;
        columnMemberships.reset();
    }
    final double criterionTotal;
    if (useXGBoostMissingValueHandling) {
        criterionTotal = (sumYTotal + missingY) * (sumYTotal + missingY) / (sumWeightTotal + missingWeight);
    } else {
        criterionTotal = sumYTotal * sumYTotal / sumWeightTotal;
    }
    final ArrayList<AttValTupleRegression> attValList = Lists.newArrayList();
    columnMemberships.next();
    int start = 0;
    final int lengthNonMissing = containsMissingValues() ? nomVals.length - 1 : nomVals.length;
    for (int att = 0; att < lengthNonMissing; att++) {
        double sumY = 0.0;
        double sumWeight = 0.0;
        int end = start + m_nominalValueCounts[att];
        boolean reachedEnd = false;
        for (int index = columnMemberships.getIndexInColumn(); index < end; index = columnMemberships.getIndexInColumn()) {
            double weight = columnMemberships.getRowWeight();
            assert weight > EPSILON : "Instances in columnMemberships must have weights larger than EPSILON.";
            sumY += weight * targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            sumWeight += weight;
            if (!columnMemberships.next()) {
                reachedEnd = true;
                break;
            }
        }
        start = end;
        if (sumWeight < EPSILON) {
            // we cannot make any assumptions about this attribute value
            continue;
        }
        attValList.add(new AttValTupleRegression(sumY, sumWeight, sumY / sumWeight, nomVals[att]));
        if (reachedEnd) {
            break;
        }
    }
    assert sumWeights(attValList) == sumWeightTotal : "The weights of the attribute values do not sum up to the total weight";
    // sort attribute values according to their mean Y value
    attValList.sort(null);
    BigInteger bestPartitionMask = null;
    boolean isBestSplitValid = false;
    double bestPartitionGain = Double.NEGATIVE_INFINITY;
    final int highestBitPosition = containsMissingValues() ? nomVals.length - 2 : nomVals.length - 1;
    double sumYPartition = 0.0;
    double sumWeightPartition = 0.0;
    BigInteger partitionMask = BigInteger.ZERO;
    double sumYRemaining = sumYTotal;
    double sumWeightRemaining = sumWeightTotal;
    boolean missingsGoLeft = true;
    // no need to iterate over full list because at least one value must remain on the other side of the split
    for (int i = 0; i < attValList.size() - 1; i++) {
        AttValTupleRegression attVal = attValList.get(i);
        sumYPartition += attVal.m_sumY;
        sumWeightPartition += attVal.m_sumWeight;
        sumYRemaining -= attVal.m_sumY;
        sumWeightRemaining -= attVal.m_sumWeight;
        assert AbsIsSmallerEpsilon(sumWeightTotal - sumWeightRemaining - sumWeightPartition) : "The weights left and right of the split do not add up to the total weight.";
        assert sumWeightPartition > 0.0 : "The weight of the partition is zero.";
        assert sumWeightRemaining > 0.0 : "The weight of the remaining is zero.";
        partitionMask = partitionMask.or(attVal.m_bitMask);
        double gain;
        boolean isValidSplit;
        boolean tempMissingsGoLeft = true;
        if (branchContainsMissingValues && useXGBoostMissingValueHandling) {
            boolean isValidSplitPartitionWithMissing = sumWeightPartition + missingWeight >= minChildSize && sumWeightRemaining >= minChildSize;
            double sumYMissingWithPartition = sumYPartition + missingY;
            double gainMissingWithPartition = sumYMissingWithPartition * sumYMissingWithPartition / (sumWeightPartition + missingWeight) + sumYRemaining * sumYRemaining / sumWeightRemaining - criterionTotal;
            boolean isValidSplitRemainingWithMissing = sumWeightPartition >= minChildSize && sumWeightRemaining + missingWeight >= minChildSize;
            double sumYMissingWithRemaining = sumYRemaining + missingY;
            double gainMissingWithRemaining = sumYPartition * sumYPartition / sumWeightPartition + sumYMissingWithRemaining * sumYMissingWithRemaining / (sumWeightRemaining + missingWeight) - criterionTotal;
            if (gainMissingWithPartition >= gainMissingWithRemaining) {
                gain = gainMissingWithPartition;
                isValidSplit = isValidSplitPartitionWithMissing;
                tempMissingsGoLeft = !partitionMask.testBit(highestBitPosition);
            } else {
                gain = gainMissingWithRemaining;
                isValidSplit = isValidSplitRemainingWithMissing;
                tempMissingsGoLeft = partitionMask.testBit(highestBitPosition);
            }
        } else {
            isValidSplit = sumWeightPartition >= minChildSize && sumWeightRemaining >= minChildSize;
            gain = sumYPartition * sumYPartition / sumWeightPartition + sumYRemaining * sumYRemaining / sumWeightRemaining - criterionTotal;
        }
        // use random tie breaker if gains are equal
        boolean randomTieBreaker = gain == bestPartitionGain ? rd.nextInt(0, 1) == 1 : false;
        // store if better than before or first valid split
        if (gain > bestPartitionGain || (!isBestSplitValid && isValidSplit) || randomTieBreaker) {
            if (isValidSplit || !isBestSplitValid) {
                bestPartitionGain = gain;
                // right branch must by convention always contain the nominal value
                // with the highest assigned integer
                bestPartitionMask = partitionMask.testBit(highestBitPosition) ? partitionMask : BigInteger.ZERO.setBit(highestBitPosition + 1).subtract(BigInteger.ONE).xor(partitionMask);
                isBestSplitValid = isValidSplit;
                if (branchContainsMissingValues) {
                    missingsGoLeft = tempMissingsGoLeft;
                } else {
                    // no missings in this branch, but we still have to provide a direction for missing values
                    // send missings in the direction the most records in the node are sent to
                    boolean sendWithPartition = sumWeightPartition >= sumWeightRemaining;
                    missingsGoLeft = sendWithPartition ? !partitionMask.testBit(highestBitPosition) : partitionMask.testBit(highestBitPosition);
                }
            }
        }
    }
    if (bestPartitionGain > 0.0 && isBestSplitValid) {
        if (useXGBoostMissingValueHandling) {
            return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, NO_MISSED_ROWS, missingsGoLeft ? NominalBinarySplitCandidate.MISSINGS_GO_LEFT : NominalBinarySplitCandidate.MISSINGS_GO_RIGHT);
        }
        return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, getMissedRows(columnMemberships), NominalBinarySplitCandidate.NO_MISSINGS);
    }
    return null;
}
Also used : BigInteger(java.math.BigInteger) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)
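
The method above relies on Breiman's shortcut for regression targets: after sorting the nominal values by their mean target value, only the n-1 "prefix" partitions of that ordering need to be evaluated, each scored with the same S^2/W criterion that criterionTotal is built from. The following standalone sketch (not KNIME code; the per-value sums and weights are invented, and minimum child size, missing-value handling and the random tie breaker are omitted) walks through exactly that loop.

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/**
 * Minimal sketch (not KNIME code) of the Breiman shortcut used in
 * calcBestSplitRegressionBinaryBreiman: sort the nominal values by their mean
 * target value, then evaluate only the n-1 prefix partitions of that ordering.
 * Gain is S_L^2/W_L + S_R^2/W_R - S^2/W, matching criterionTotal above.
 * The per-value sums and weights are invented; min child size, missing values
 * and the random tie breaker are left out.
 */
public class BreimanRegressionSplitSketch {

    record AttVal(String value, double sumY, double sumWeight) {
        double mean() {
            return sumY / sumWeight;
        }
    }

    public static void main(String[] args) {
        List<AttVal> attVals = new ArrayList<>(List.of(
            new AttVal("red", 10.0, 4.0),
            new AttVal("green", 2.0, 5.0),
            new AttVal("blue", 9.0, 3.0)));

        double sumYTotal = attVals.stream().mapToDouble(AttVal::sumY).sum();
        double sumWeightTotal = attVals.stream().mapToDouble(AttVal::sumWeight).sum();
        double criterionTotal = sumYTotal * sumYTotal / sumWeightTotal;

        // Corresponds to attValList.sort(null) above, assuming the natural order of
        // AttValTupleRegression is the mean target value.
        attVals.sort(Comparator.comparingDouble(AttVal::mean));

        double sumYPartition = 0.0;
        double sumWeightPartition = 0.0;
        double bestGain = Double.NEGATIVE_INFINITY;
        int bestPrefixEnd = -1;
        // At least one value must remain on the other side, hence size() - 1.
        for (int i = 0; i < attVals.size() - 1; i++) {
            sumYPartition += attVals.get(i).sumY();
            sumWeightPartition += attVals.get(i).sumWeight();
            double sumYRemaining = sumYTotal - sumYPartition;
            double sumWeightRemaining = sumWeightTotal - sumWeightPartition;
            double gain = sumYPartition * sumYPartition / sumWeightPartition
                + sumYRemaining * sumYRemaining / sumWeightRemaining
                - criterionTotal;
            if (gain > bestGain) {
                bestGain = gain;
                bestPrefixEnd = i;
            }
        }
        System.out.println("best partition: sorted values 0.." + bestPrefixEnd
            + " vs. the rest, gain = " + bestGain);
    }
}

In the KNIME method above, the winning prefix is additionally re-encoded as a BigInteger mask and, if necessary, complemented so that the right branch always contains the nominal value with the highest assigned integer.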

Example 3 with TreeTargetNumericColumnData

Use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData in project knime-core by knime.

In class TreeNominalColumnData, method calcBestSplitRegressionBinary:

private NominalBinarySplitCandidate calcBestSplitRegressionBinary(final ColumnMemberships columnMemberships, final RegressionPriors targetPriors, final TreeTargetNumericColumnData targetColumn, final NominalValueRepresentation[] nomVals, final RandomData rd) {
    final int minChildSize = getConfiguration().getMinChildSize();
    final double ySumTotal = targetPriors.getYSum();
    final double nrRecordsTotal = targetPriors.getNrRecords();
    final double criterionTotal = ySumTotal * ySumTotal / nrRecordsTotal;
    final double[] ySums = new double[nomVals.length];
    final double[] sumWeightsAttributes = new double[nomVals.length];
    columnMemberships.next();
    int start = 0;
    for (int att = 0; att < nomVals.length; att++) {
        int end = start + m_nominalValueCounts[att];
        double weightSum = 0.0;
        double ySum = 0.0;
        boolean reachedEnd = false;
        for (int index = columnMemberships.getIndexInColumn(); index < end; index = columnMemberships.getIndexInColumn()) {
            final double weight = columnMemberships.getRowWeight();
            assert weight > EPSILON : "Instances in columnMemberships must have weights larger than EPSILON.";
            ySum += weight * targetColumn.getValueFor(columnMemberships.getOriginalIndex());
            weightSum += weight;
            if (!columnMemberships.next()) {
                // reached end of columnMemberships
                reachedEnd = true;
                break;
            }
        }
        sumWeightsAttributes[att] = weightSum;
        ySums[att] = ySum;
        start = end;
        if (reachedEnd) {
            break;
        }
    }
    BinarySplitEnumeration splitEnumeration;
    if (nomVals.length <= 10) {
        splitEnumeration = new FullBinarySplitEnumeration(nomVals.length);
    } else {
        int maxSearch = (1 << 10) - 2;
        splitEnumeration = new RandomBinarySplitEnumeration(nomVals.length, maxSearch, rd);
    }
    BigInteger bestPartitionMask = null;
    boolean isBestSplitValid = false;
    double bestPartitionGain = Double.NEGATIVE_INFINITY;
    do {
        double weightLeft = 0.0;
        double ySumLeft = 0.0;
        double weightRight = 0.0;
        double ySumRight = 0.0;
        for (int i = 0; i < nomVals.length; i++) {
            final boolean isAttributeInRightBranch = splitEnumeration.isInRightBranch(i);
            if (isAttributeInRightBranch) {
                weightRight += sumWeightsAttributes[i];
                ySumRight += ySums[i];
            } else {
                weightLeft += sumWeightsAttributes[i];
                ySumLeft += ySums[i];
            }
        }
        final boolean isValidSplit = weightRight >= minChildSize && weightLeft >= minChildSize;
        double gain = ySumRight * ySumRight / weightRight + ySumLeft * ySumLeft / weightLeft - criterionTotal;
        // use random tie breaker if gains are equal
        boolean randomTieBreaker = gain == bestPartitionGain ? rd.nextInt(0, 1) == 1 : false;
        // store if better than before or first valid split
        if (gain > bestPartitionGain || (!isBestSplitValid && isValidSplit) || randomTieBreaker) {
            if (isValidSplit || !isBestSplitValid) {
                bestPartitionGain = gain;
                bestPartitionMask = splitEnumeration.getValueMask();
                isBestSplitValid = isValidSplit;
            }
        }
    } while (splitEnumeration.next());
    if (bestPartitionGain > 0.0) {
        return new NominalBinarySplitCandidate(this, bestPartitionGain, bestPartitionMask, getMissedRows(columnMemberships), NominalBinarySplitCandidate.NO_MISSINGS);
    }
    return null;
}
Also used : BigInteger(java.math.BigInteger) NominalBinarySplitCandidate(org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate)
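
A short note on the enumeration choice above: for n nominal values there are 2^(n-1) - 1 distinct non-trivial binary partitions, because a mask and its complement describe the same split and the empty/full masks are not splits. With up to 10 values that set is enumerated exhaustively, and for more values a bounded number of random masks is sampled instead. The sketch below (plain Java, not the KNIME enumeration classes; n is an arbitrary small example) lists each distinct partition exactly once by pinning the last value to the left branch.

/**
 * Illustrative sketch, not FullBinarySplitEnumeration/RandomBinarySplitEnumeration:
 * enumerates every distinct binary partition of n nominal values exactly once by
 * fixing value n-1 to the left branch (its bit is never set). For n = 4 this prints
 * the 2^(4-1) - 1 = 7 possible right-branch sets.
 */
public class BinarySplitEnumerationSketch {

    public static void main(String[] args) {
        int n = 4; // small example; the code above switches to random sampling for n > 10
        int distinctSplits = (1 << (n - 1)) - 1;
        System.out.println("distinct binary splits for n = " + n + ": " + distinctSplits);

        for (int mask = 1; mask < (1 << (n - 1)); mask++) {
            StringBuilder rightBranch = new StringBuilder();
            for (int i = 0; i < n; i++) {
                if (((mask >> i) & 1) == 1) {
                    rightBranch.append(i).append(' ');
                }
            }
            System.out.println("right branch values: " + rightBranch.toString().trim());
        }
    }
}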

Example 4 with TreeTargetNumericColumnData

Use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData in project knime-core by knime.

In class TreeLearnerRegression, method learnSingleTree:

/**
 * {@inheritDoc}
 */
@Override
public TreeModelRegression learnSingleTree(final ExecutionMonitor exec, final RandomData rd) throws CanceledExecutionException {
    final TreeTargetNumericColumnData targetColumn = getTargetData();
    final TreeData data = getData();
    final RowSample rowSampling = getRowSampling();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final IDataIndexManager indexManager = getIndexManager();
    DataMemberships rootDataMemberships = new RootDataMemberships(rowSampling, data, indexManager);
    RegressionPriors targetPriors = targetColumn.getPriors(rootDataMemberships, config);
    BitSet forbiddenColumnSet = new BitSet(data.getNrAttributes());
    boolean isGradientBoosting = config instanceof GradientBoostingLearnerConfiguration;
    if (isGradientBoosting) {
        m_leafs = new ArrayList<TreeNodeRegression>();
    }
    final TreeNodeSignature rootSignature = TreeNodeSignature.ROOT_SIGNATURE;
    final ColumnSample rootColumnSample = getColSamplingStrategy().getColumnSampleForTreeNode(rootSignature);
    TreeNodeRegression rootNode = buildTreeNode(exec, 0, rootDataMemberships, rootColumnSample, getSignatureFactory().getRootSignature(), targetPriors, forbiddenColumnSet);
    assert forbiddenColumnSet.cardinality() == 0;
    rootNode.setTreeNodeCondition(TreeNodeTrueCondition.INSTANCE);
    if (isGradientBoosting) {
        return new TreeModelRegression(rootNode, m_leafs);
    }
    return new TreeModelRegression(rootNode);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) ColumnSample(org.knime.base.node.mine.treeensemble2.sample.column.ColumnSample) RegressionPriors(org.knime.base.node.mine.treeensemble2.data.RegressionPriors) BitSet(java.util.BitSet) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) IDataIndexManager(org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) TreeNodeSignature(org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) TreeNodeRegression(org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) DataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) RootDataMemberships(org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) TreeModelRegression(org.knime.base.node.mine.treeensemble2.model.TreeModelRegression) GradientBoostingLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) RowSample(org.knime.base.node.mine.treeensemble2.sample.row.RowSample)
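
For orientation, here is a schematic, heavily simplified sketch of the recursion that learnSingleTree and buildTreeNode drive: compute the node's prediction from the target statistics (here simply the mean), decide whether to split, and recurse on the partitioned rows, collecting leaves much like m_leafs is collected above for the gradient boosting case. It is not the KNIME implementation: row/column sampling, node signatures, nominal columns and the real split search are all omitted, and the single numeric attribute in main() is invented.

import java.util.ArrayList;
import java.util.List;

/**
 * Schematic sketch (not the KNIME implementation) of the recursion driven by
 * learnSingleTree/buildTreeNode: derive the node's prediction, decide whether
 * to split, and recurse on the partitioned rows. The leaf list mirrors the
 * m_leafs collection kept when the configuration is a
 * GradientBoostingLearnerConfiguration. The split search is a crude stand-in.
 */
public class RegressionTreeSkeleton {

    static final class Node {
        double prediction;
        Double threshold;  // null for a leaf
        Node left, right;
    }

    static Node build(double[] x, double[] y, int depth, int maxDepth, List<Node> leafs) {
        Node node = new Node();
        double sum = 0.0;
        for (double v : y) {
            sum += v;
        }
        node.prediction = sum / y.length;

        // Stand-in for the split search: split at the mean of x, if that actually
        // separates the rows and the depth limit has not been reached.
        double mean = 0.0;
        for (double v : x) {
            mean += v;
        }
        mean /= x.length;
        List<Integer> left = new ArrayList<>();
        List<Integer> right = new ArrayList<>();
        for (int i = 0; i < x.length; i++) {
            (x[i] <= mean ? left : right).add(i);
        }
        if (depth >= maxDepth || left.isEmpty() || right.isEmpty()) {
            leafs.add(node); // leaf node, collected like m_leafs above
            return node;
        }
        node.threshold = mean;
        node.left = build(subset(x, left), subset(y, left), depth + 1, maxDepth, leafs);
        node.right = build(subset(x, right), subset(y, right), depth + 1, maxDepth, leafs);
        return node;
    }

    static double[] subset(double[] a, List<Integer> idx) {
        double[] out = new double[idx.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = a[idx.get(i)];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0, 10.0, 11.0, 12.0};
        double[] y = {1.1, 0.9, 1.0, 5.2, 4.8, 5.0};
        List<Node> leafs = new ArrayList<>();
        Node root = build(x, y, 0, 2, leafs);
        System.out.println("root prediction: " + root.prediction + ", number of leafs: " + leafs.size());
    }
}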

Example 5 with TreeTargetNumericColumnData

Use of org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData in project knime-core by knime.

In class TreeLearnerRegression, method findBestSplitsRegression:

private SplitCandidate[] findBestSplitsRegression(final int currentDepth, final DataMemberships dataMemberships, final ColumnSample columnSample, final RegressionPriors targetPriors, final BitSet forbiddenColumnSet) {
    final TreeData data = getData();
    final RandomData rd = getRandomData();
    final TreeEnsembleLearnerConfiguration config = getConfig();
    final int maxLevels = config.getMaxLevels();
    if (maxLevels != TreeEnsembleLearnerConfiguration.MAX_LEVEL_INFINITE && currentDepth >= maxLevels) {
        return null;
    }
    final int minNodeSize = config.getMinNodeSize();
    if (minNodeSize != TreeEnsembleLearnerConfiguration.MIN_NODE_SIZE_UNDEFINED) {
        if (targetPriors.getNrRecords() < minNodeSize) {
            return null;
        }
    }
    final double priorSquaredDeviation = targetPriors.getSumSquaredDeviation();
    if (priorSquaredDeviation < TreeColumnData.EPSILON) {
        return null;
    }
    final TreeTargetNumericColumnData targetColumn = getTargetData();
    ArrayList<SplitCandidate> splitCandidates = null;
    if (currentDepth == 0 && config.getHardCodedRootColumn() != null) {
        final TreeAttributeColumnData rootColumn = data.getColumn(config.getHardCodedRootColumn());
        return new SplitCandidate[] { rootColumn.calcBestSplitRegression(dataMemberships, targetPriors, targetColumn, rd) };
    } else {
        splitCandidates = new ArrayList<SplitCandidate>(columnSample.getNumCols());
        for (TreeAttributeColumnData col : columnSample) {
            if (forbiddenColumnSet.get(col.getMetaData().getAttributeIndex())) {
                continue;
            }
            SplitCandidate currentColSplit = col.calcBestSplitRegression(dataMemberships, targetPriors, targetColumn, rd);
            if (currentColSplit != null) {
                splitCandidates.add(currentColSplit);
            }
        }
    }
    Comparator<SplitCandidate> comp = new Comparator<SplitCandidate>() {

        @Override
        public int compare(final SplitCandidate arg0, final SplitCandidate arg1) {
            int compareDouble = -Double.compare(arg0.getGainValue(), arg1.getGainValue());
            return compareDouble;
        }
    };
    if (splitCandidates.isEmpty()) {
        return null;
    }
    splitCandidates.sort(comp);
    return splitCandidates.toArray(new SplitCandidate[splitCandidates.size()]);
}
Also used : TreeEnsembleLearnerConfiguration(org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) RandomData(org.apache.commons.math.random.RandomData) TreeAttributeColumnData(org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) TreeTargetNumericColumnData(org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) TreeData(org.knime.base.node.mine.treeensemble2.data.TreeData) Comparator(java.util.Comparator)
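
The anonymous Comparator above orders split candidates by descending gain. The same ordering can be written with the Java 8 comparator factories; the sketch below demonstrates it on GainOnly, a stand-in record for the KNIME SplitCandidate type, which for this purpose only needs a getGainValue() accessor.

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/**
 * Small sketch showing the same descending-by-gain ordering as the anonymous
 * Comparator in findBestSplitsRegression, written with the Java 8 comparator
 * factories. GainOnly is only a stand-in for the KNIME SplitCandidate type.
 */
public class SplitCandidateOrderingSketch {

    record GainOnly(String column, double gainValue) {
        double getGainValue() {
            return gainValue;
        }
    }

    public static void main(String[] args) {
        List<GainOnly> candidates = new ArrayList<>(List.of(
            new GainOnly("colA", 0.12),
            new GainOnly("colB", 0.48),
            new GainOnly("colC", 0.03)));

        // Equivalent to: -Double.compare(arg0.getGainValue(), arg1.getGainValue())
        candidates.sort(Comparator.comparingDouble(GainOnly::getGainValue).reversed());

        System.out.println(candidates); // best candidate (highest gain) first
    }
}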

Aggregations

TreeEnsembleLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.learner.TreeEnsembleLearnerConfiguration) 11
TreeTargetNumericColumnData (org.knime.base.node.mine.treeensemble2.data.TreeTargetNumericColumnData) 8
DataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.DataMemberships) 8
RootDataMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.RootDataMemberships) 8
RandomData (org.apache.commons.math.random.RandomData) 7
TreeData (org.knime.base.node.mine.treeensemble2.data.TreeData) 7
Test (org.junit.Test) 6
NominalBinarySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalBinarySplitCandidate) 6
SplitCandidate (org.knime.base.node.mine.treeensemble2.learner.SplitCandidate) 5
BitSet (java.util.BitSet) 4
DefaultDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.DefaultDataIndexManager) 4
NominalMultiwaySplitCandidate (org.knime.base.node.mine.treeensemble2.learner.NominalMultiwaySplitCandidate) 4
TreeNodeSignature (org.knime.base.node.mine.treeensemble2.model.TreeNodeSignature) 4
TreeAttributeColumnData (org.knime.base.node.mine.treeensemble2.data.TreeAttributeColumnData) 3
ColumnMemberships (org.knime.base.node.mine.treeensemble2.data.memberships.ColumnMemberships) 3
IDataIndexManager (org.knime.base.node.mine.treeensemble2.data.memberships.IDataIndexManager) 3
TreeNodeRegression (org.knime.base.node.mine.treeensemble2.model.TreeNodeRegression) 3
GradientBoostingLearnerConfiguration (org.knime.base.node.mine.treeensemble2.node.gradientboosting.learner.GradientBoostingLearnerConfiguration) 3
BigInteger (java.math.BigInteger) 2
HashMap (java.util.HashMap) 2