Search in sources :

Example 16 with SplitEntry

use of com.tencent.angel.ml.GBDT.algo.tree.SplitEntry in project angel by Tencent.

the class GradHistHelper method findBestSplitHelper.

/**
 * Finds the best split of a tree node by scanning its gradient histogram feature by feature.
 *
 * <p>The feature count and the split count are read from the worker configuration. The
 * histogram is scanned only when its dimension equals {@code featureNum * 2 * splitNum};
 * otherwise a message is logged and a freshly constructed {@link SplitEntry} is returned.
 *
 * @param histogram gradient histogram of one tree node
 * @return the split entry after it has been offered every per-feature candidate through
 *         {@code update}, or an untouched new entry when the histogram size does not match
 * @throws InterruptedException declared by the signature; no statement visible in this method
 *         throws it
 */
public static SplitEntry findBestSplitHelper(TDoubleVector histogram) throws InterruptedException {
    LOG.info(String.format("------To find the best split of histogram size[%d]------", histogram.getDimension()));
    SplitEntry best = new SplitEntry();
    LOG.info(String.format("The best split before looping the histogram: fid[%d], fvalue[%f]", best.fid, best.fvalue));

    int featureNum = WorkerContext.get().getConf().getInt(MLConf.ML_FEATURE_NUM(), MLConf.DEFAULT_ML_FEATURE_NUM());
    int splitNum = WorkerContext.get().getConf().getInt(MLConf.ML_GBDT_SPLIT_NUM(), MLConf.DEFAULT_ML_GBDT_SPLIT_NUM());

    if (histogram.getDimension() == featureNum * 2 * splitNum) {
        // every feature owns 2 * splitNum consecutive histogram slots
        final int slotsPerFeature = 2 * splitNum;
        for (int fid = 0; fid < featureNum; fid++) {
            // offer the best candidate of this feature to the running best split
            best.update(findBestSplitOfOneFeatureHelper(fid, histogram, slotsPerFeature * fid));
        }
        LOG.info(String.format("The best split after looping the histogram: fid[%d], fvalue[%f], loss gain[%f]", best.fid, best.fvalue, best.lossChg));
    } else {
        LOG.info("The size of histogram is not equal to 2 * featureNum*splitNum.");
    }
    return best;
}
Also used : SplitEntry(com.tencent.angel.ml.GBDT.algo.tree.SplitEntry)

Example 17 with SplitEntry

use of com.tencent.angel.ml.GBDT.algo.tree.SplitEntry in project angel by Tencent.

the class GradHistHelper method findSplitOfFeature.

/**
 * Finds the best split of one feature from a server row; used by the PS.
 *
 * <p>Bins are accumulated from left to right starting at {@code startIdx}. For a bin at
 * {@code histIdx} the gradient is read at {@code histIdx} and the hessian at
 * {@code histIdx + param.numSplit}. A bin is a split candidate only when the hessian sums of
 * both the left part and the right part (root minus left) reach {@code param.minChildWeight}.
 * The split value recorded in the entry is the bin index relative to {@code startIdx}; the
 * task later uses that index to find the feature value.
 *
 * <p>The range check is performed before any histogram access. When
 * {@code startIdx + 2 * param.numSplit} exceeds {@code row.getEndCol()}, an error is logged and
 * the entry is returned with only its feature id set.
 *
 * @param fid      feature id stored in the returned entry
 * @param row      server row holding the gradient histogram
 * @param startIdx index in the row data where this feature's histogram begins
 * @param param    GBDT parameters; {@code numSplit} and {@code minChildWeight} are read here
 * @return the best split found for this feature, with the left and right gradient statistics
 *         attached when the range check passes
 */
public static SplitEntry findSplitOfFeature(int fid, ServerDenseDoubleRow row, int startIdx, GBDTParam param) {
    LOG.info(String.format("Find best split for fid[%d] in histogram size[%d], startIdx[%d]", fid, row.size(), startIdx));
    SplitEntry splitEntry = new SplitEntry();
    // 1. set the feature id
    splitEntry.setFid(fid);
    // 2. validate the histogram range BEFORE reading it; previously the root statistics were
    //    computed ahead of this check, so the guard could not protect that access
    if (startIdx + 2 * param.numSplit > row.getEndCol()) {
        LOG.error("index out of grad histogram size.");
        return splitEntry;
    }
    // 3. create the best left stats and right stats
    GradStats bestLeftStat = new GradStats();
    GradStats bestRightStat = new GradStats();
    // 4. the stats and the gain of the root node
    GradStats rootStats = calGradStats(row, startIdx, param.numSplit);
    float rootGain = rootStats.calcGain(param);
    // 5. create the temp left and right grad stats
    GradStats leftStats = new GradStats();
    GradStats rightStats = new GradStats();
    // 6. loop over all the bins of this feature
    for (int histIdx = startIdx; histIdx < startIdx + param.numSplit; histIdx++) {
        // 6.1. get the grad and hess of current hist bin
        float grad = (float) row.getData().get(histIdx);
        float hess = (float) row.getData().get(param.numSplit + histIdx);
        leftStats.add(grad, hess);
        // 6.2. check whether we can split with current left hessian
        if (leftStats.sumHess >= param.minChildWeight) {
            // right = root - left
            rightStats.setSubstract(rootStats, leftStats);
            // 6.3. check whether we can split with current right hessian
            if (rightStats.sumHess >= param.minChildWeight) {
                // 6.4. calculate the current loss gain
                float lossChg = leftStats.calcGain(param) + rightStats.calcGain(param) - rootGain;
                // 6.5. split rule: value <= split; the task uses this index to find fvalue
                int splitIdx = histIdx - startIdx;
                if (splitEntry.update(lossChg, fid, splitIdx)) {
                    // 6.6. the entry accepted this candidate, so remember its left/right stats
                    bestLeftStat.update(leftStats.sumGrad, leftStats.sumHess);
                    bestRightStat.update(rightStats.sumGrad, rightStats.sumHess);
                }
            }
        }
    }
    // 7. set the best left and right grad stats
    splitEntry.leftGradStat = bestLeftStat;
    splitEntry.rightGradStat = bestRightStat;
    return splitEntry;
}
Also used : SplitEntry(com.tencent.angel.ml.GBDT.algo.tree.SplitEntry)

Example 18 with SplitEntry

use of com.tencent.angel.ml.GBDT.algo.tree.SplitEntry in project angel by Tencent.

the class GradHistHelper method findSplitOfServerRow.

/**
 * Finds the best split among the features whose histograms are held by one server row on the PS.
 *
 * <p>The feature range is derived from the row's column range: each feature takes
 * {@code 2 * param.numSplit} columns, so the first feature is {@code startCol / (2 * numSplit)}
 * and the last is {@code endCol / (2 * numSplit) - 1}. The feature id used here is the index in
 * the sampled feature set rather than the true feature id. The start index passed for each
 * feature is counted from the first feature of this row, i.e. it begins at 0.
 *
 * @param row   server row holding the gradient histograms
 * @param param GBDT parameters; {@code numSplit} is read here
 * @return the split entry after it has been offered every per-feature candidate through
 *         {@code update}
 */
public static SplitEntry findSplitOfServerRow(ServerDenseDoubleRow row, GBDTParam param) {
    LOG.info(String.format("------To find the best split from server row[%d], cols[%d-%d]------", row.getRowId(), row.getStartCol(), row.getEndCol()));
    SplitEntry best = new SplitEntry();
    best.leftGradStat = new GradStats();
    best.rightGradStat = new GradStats();
    LOG.info(String.format("The best split before looping the histogram: fid[%d], fvalue[%f]", best.fid, best.fvalue));

    // number of histogram columns taken by a single feature
    final int colsPerFeature = 2 * param.numSplit;
    int startFid = (int) row.getStartCol() / colsPerFeature;
    int endFid = ((int) row.getEndCol()) / colsPerFeature - 1;
    LOG.info(String.format("Row split col[%d-%d), start feature[%d], end feature[%d]", row.getStartCol(), row.getEndCol(), startFid, endFid));

    // walk the features of this row; startIdx advances by one feature's worth of columns
    for (int fid = startFid, startIdx = 0; fid <= endFid; fid++, startIdx += colsPerFeature) {
        // offer the best candidate of this feature to the running best split
        best.update(findSplitOfFeature(fid, row, startIdx, param));
    }

    LOG.info(String.format("The best split after looping the histogram: fid[%d], fvalue[%f], loss gain[%f]", best.fid, best.fvalue, best.lossChg));
    return best;
}
Also used : SplitEntry(com.tencent.angel.ml.GBDT.algo.tree.SplitEntry)

Example 19 with SplitEntry

use of com.tencent.angel.ml.GBDT.algo.tree.SplitEntry in project angel by Tencent.

the class GradHistHelper method findBestSplitOfOneFeatureHelper.

/**
 * Finds the best split of one feature from a node's gradient histogram.
 *
 * <p>The split count is read from the worker configuration and a default-constructed
 * {@link GBDTParam} supplies {@code minChildWeight} and the gain calculation. Bins are
 * accumulated from left to right starting at {@code startIdx}; the gradient of a bin is read at
 * {@code histIdx} and its hessian at {@code histIdx + splitNum}. Only the first
 * {@code splitNum - 1} bins are tried as the end of the left part, and the recorded split index
 * is the bin offset plus one.
 *
 * <p>The range check is performed before any histogram access. When
 * {@code startIdx + 2 * splitNum} exceeds the histogram dimension, an error is logged and a
 * freshly constructed entry is returned.
 *
 * @param fid       feature id passed to {@code SplitEntry.update} for accepted candidates
 * @param histogram gradient histogram of one tree node
 * @param startIdx  index in the histogram where this feature's bins begin
 * @return the best split found for this feature, with the left and right gradient statistics
 *         attached when the range check passes
 */
public static SplitEntry findBestSplitOfOneFeatureHelper(int fid, TDoubleVector histogram, int startIdx) {
    LOG.info(String.format("Find best split for fid[%d] in histogram size[%d], startIdx[%d]", fid, histogram.getDimension(), startIdx));
    int splitNum = WorkerContext.get().getConf().getInt(MLConf.ML_GBDT_SPLIT_NUM(), MLConf.DEFAULT_ML_GBDT_SPLIT_NUM());
    // the feature id is not preset on the entry; it is only handed to update(...) below
    SplitEntry splitEntry = new SplitEntry();
    // 1. validate the histogram range BEFORE reading it; previously the root statistics were
    //    computed ahead of this check, so the guard could not protect that access
    if (startIdx + 2 * splitNum > histogram.getDimension()) {
        LOG.error("index out of grad histogram size.");
        return splitEntry;
    }
    // 2. create the best left stats and right stats
    GradStats bestLeftStat = new GradStats();
    GradStats bestRightStat = new GradStats();
    GradStats rootStats = calGradStats(histogram, startIdx, splitNum);
    GBDTParam param = new GBDTParam();
    // 3. the gain of the root node
    float rootGain = rootStats.calcGain(param);
    LOG.info(String.format("Feature[%d]: sumGrad[%f], sumHess[%f], gain[%f]", fid, rootStats.sumGrad, rootStats.sumHess, rootGain));
    // 4. create the temp left and right grad stats
    GradStats leftStats = new GradStats();
    GradStats rightStats = new GradStats();
    // 5. loop over the bins of this feature, excluding the last one
    for (int histIdx = startIdx; histIdx < startIdx + splitNum - 1; histIdx++) {
        // 5.1. get the grad and hess of current hist bin
        float grad = (float) histogram.get(histIdx);
        float hess = (float) histogram.get(splitNum + histIdx);
        leftStats.add(grad, hess);
        // 5.2. check whether we can split with current left hessian
        if (leftStats.sumHess >= param.minChildWeight) {
            // right = root - left
            rightStats.setSubstract(rootStats, leftStats);
            // 5.3. check whether we can split with current right hessian
            if (rightStats.sumHess >= param.minChildWeight) {
                // 5.4. calculate the current loss gain
                float lossChg = leftStats.calcGain(param) + rightStats.calcGain(param) - rootGain;
                // 5.5. the split index is one past the current bin offset
                // NOTE(review): findSplitOfFeature records the zero-based offset and scans all
                // bins; confirm the differing convention here is intended
                int splitIdx = histIdx - startIdx + 1;
                if (splitEntry.update(lossChg, fid, splitIdx)) {
                    // 5.6. the entry accepted this candidate, so remember its left/right stats
                    bestLeftStat.update(leftStats.sumGrad, leftStats.sumHess);
                    bestRightStat.update(rightStats.sumGrad, rightStats.sumHess);
                }
            }
        }
    }
    // 6. set the best left and right grad stats
    splitEntry.leftGradStat = bestLeftStat;
    splitEntry.rightGradStat = bestRightStat;
    LOG.info(String.format("Find best split for fid[%d], split feature[%d]: split index[%f], lossChg[%f]", fid, splitEntry.fid, splitEntry.fvalue, splitEntry.lossChg));
    return splitEntry;
}
Also used : SplitEntry(com.tencent.angel.ml.GBDT.algo.tree.SplitEntry) GBDTParam(com.tencent.angel.ml.param.GBDTParam)

Aggregations

SplitEntry (com.tencent.angel.ml.GBDT.algo.tree.SplitEntry)19 GradStats (com.tencent.angel.ml.GBDT.algo.RegTree.GradStats)2 RegTNodeStat (com.tencent.angel.ml.GBDT.algo.RegTree.RegTNodeStat)2 TNode (com.tencent.angel.ml.GBDT.algo.tree.TNode)2 GBDTParam (com.tencent.angel.ml.GBDT.param.GBDTParam)2 ServerIntDoubleRow (com.tencent.angel.ps.storage.vector.ServerIntDoubleRow)2 GBDTGradHistGetRowFunc (com.tencent.angel.ml.GBDT.psf.GBDTGradHistGetRowFunc)1 HistAggrParam (com.tencent.angel.ml.GBDT.psf.HistAggrParam)1 IntDoubleDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage)1 IntIntDenseVectorStorage (com.tencent.angel.ml.math2.storage.IntIntDenseVectorStorage)1 IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector)1 IntIntVector (com.tencent.angel.ml.math2.vector.IntIntVector)1 PartitionGetRowResult (com.tencent.angel.ml.matrix.psf.get.getrow.PartitionGetRowResult)1 PSModel (com.tencent.angel.ml.model.PSModel)1 GBDTParam (com.tencent.angel.ml.param.GBDTParam)1 ServerRow (com.tencent.angel.ps.storage.vector.ServerRow)1 ArrayList (java.util.ArrayList)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1