Use of com.tencent.angel.ml.GBDT.algo.tree.SplitEntry in project angel by Tencent.
The class GradHistHelper, method findBestSplitHelper.
// find the best split result of the histogram of a tree node
public static SplitEntry findBestSplitHelper(TDoubleVector histogram) throws InterruptedException {
  LOG.info(String.format("------To find the best split of histogram size[%d]------", histogram.getDimension()));
  SplitEntry splitEntry = new SplitEntry();
  LOG.info(String.format("The best split before looping the histogram: fid[%d], fvalue[%f]", splitEntry.fid, splitEntry.fvalue));
  int featureNum = WorkerContext.get().getConf().getInt(MLConf.ML_FEATURE_NUM(), MLConf.DEFAULT_ML_FEATURE_NUM());
  int splitNum = WorkerContext.get().getConf().getInt(MLConf.ML_GBDT_SPLIT_NUM(), MLConf.DEFAULT_ML_GBDT_SPLIT_NUM());
  if (histogram.getDimension() != featureNum * 2 * splitNum) {
    LOG.info("The size of the histogram is not equal to featureNum * 2 * splitNum.");
    return splitEntry;
  }
  for (int fid = 0; fid < featureNum; fid++) {
    // 2.2. get the start index of this feature's histogram
    int startIdx = 2 * splitNum * fid;
    // 2.3. find the best split of the current feature
    SplitEntry curSplit = findBestSplitOfOneFeatureHelper(fid, histogram, startIdx);
    // 2.4. update the best split result if possible
    splitEntry.update(curSplit);
  }
  LOG.info(String.format("The best split after looping the histogram: fid[%d], fvalue[%f], loss gain[%f]", splitEntry.fid, splitEntry.fvalue, splitEntry.lossChg));
  return splitEntry;
}
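For reference, the method above relies on a flat histogram layout in which each feature owns 2 * splitNum consecutive entries: the first splitNum hold the per-bin gradient sums and the next splitNum hold the per-bin hessian sums. The sketch below is not Angel code; the helper names gradOffset and hessOffset are invented for illustration and only print how feature ids map to offsets under that assumed layout.

// Illustrative sketch only: maps feature ids to offsets in the flat histogram
// layout assumed by findBestSplitHelper (grad bins first, then hess bins).
public class HistLayoutSketch {

  // start of the gradient bins of feature fid (hypothetical helper)
  static int gradOffset(int fid, int splitNum) {
    return 2 * splitNum * fid;
  }

  // start of the hessian bins of feature fid (hypothetical helper)
  static int hessOffset(int fid, int splitNum) {
    return 2 * splitNum * fid + splitNum;
  }

  public static void main(String[] args) {
    int splitNum = 4;    // candidate splits per feature
    int featureNum = 3;  // number of features
    // the total dimension that findBestSplitHelper checks against
    System.out.println("expected histogram dimension: " + featureNum * 2 * splitNum);
    for (int fid = 0; fid < featureNum; fid++) {
      System.out.printf("fid=%d: grad bins start at %d, hess bins start at %d%n",
          fid, gradOffset(fid, splitNum), hessOffset(fid, splitNum));
    }
  }
}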
Use of com.tencent.angel.ml.GBDT.algo.tree.SplitEntry in project angel by Tencent.
The class GradHistHelper, method findSplitOfFeature.
// find the best split result of one feature from a server row, used by the PS
public static SplitEntry findSplitOfFeature(int fid, ServerDenseDoubleRow row, int startIdx, GBDTParam param) {
  LOG.info(String.format("Find best split for fid[%d] in histogram size[%d], startIdx[%d]", fid, row.size(), startIdx));
  SplitEntry splitEntry = new SplitEntry();
  // 1. set the feature id
  splitEntry.setFid(fid);
  // 2. create the best left stats and right stats
  GradStats bestLeftStat = new GradStats();
  GradStats bestRightStat = new GradStats();
  GradStats rootStats = calGradStats(row, startIdx, param.numSplit);
  if (startIdx + 2 * param.numSplit <= row.getEndCol()) {
    // 3. the gain of the root node
    float rootGain = rootStats.calcGain(param);
    // 4. create the temp left and right grad stats
    GradStats leftStats = new GradStats();
    GradStats rightStats = new GradStats();
    // 5. loop over all the data in the histogram
    for (int histIdx = startIdx; histIdx < startIdx + param.numSplit; histIdx++) {
      // 5.1. get the grad and hess of the current hist bin
      float grad = (float) row.getData().get(histIdx);
      float hess = (float) row.getData().get(param.numSplit + histIdx);
      leftStats.add(grad, hess);
      // 5.2. check whether we can split with the current left hessian
      if (leftStats.sumHess >= param.minChildWeight) {
        // right = root - left
        rightStats.setSubstract(rootStats, leftStats);
        // 5.3. check whether we can split with the current right hessian
        if (rightStats.sumHess >= param.minChildWeight) {
          // 5.4. calculate the current loss gain
          float lossChg = leftStats.calcGain(param) + rightStats.calcGain(param) - rootGain;
          // 5.5. check whether we should update the split result with the current loss gain
          // split rule: value <= split
          int splitIdx = histIdx - startIdx;
          // the task uses this index to look up the actual fvalue
          if (splitEntry.update(lossChg, fid, splitIdx)) {
            // 5.6. if so, also update the best left and right grad stats
            bestLeftStat.update(leftStats.sumGrad, leftStats.sumHess);
            bestRightStat.update(rightStats.sumGrad, rightStats.sumHess);
          }
        }
      }
    }
    // 6. set the best left and right grad stats
    splitEntry.leftGradStat = bestLeftStat;
    splitEntry.rightGradStat = bestRightStat;
  } else {
    LOG.error("index out of grad histogram size.");
  }
  return splitEntry;
}
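To make the scan above easier to follow in isolation, here is a minimal, self-contained sketch of the same left-to-right pass: accumulate grad/hess into the left child, derive the right child as root minus left, and keep the split with the largest loss change. GradStats and calcGain belong to Angel; the simplified gain g^2 / (h + lambda) used below is an assumption standing in for them.

// A minimal sketch of the left-to-right histogram scan (assumption: the simplified
// gain g*g / (h + lambda) stands in for GradStats.calcGain, which applies further
// regularization terms in the real implementation).
public class SplitScanSketch {

  static float gain(float sumGrad, float sumHess, float lambda) {
    return (sumGrad * sumGrad) / (sumHess + lambda);
  }

  // returns the best split index of one feature, or -1 if no valid split exists
  static int bestSplitIndex(float[] gradBins, float[] hessBins,
                            float lambda, float minChildWeight) {
    float rootGrad = 0f, rootHess = 0f;
    for (int i = 0; i < gradBins.length; i++) {
      rootGrad += gradBins[i];
      rootHess += hessBins[i];
    }
    float rootGain = gain(rootGrad, rootHess, lambda);

    float leftGrad = 0f, leftHess = 0f, bestLossChg = 0f;
    int bestIdx = -1;
    for (int i = 0; i < gradBins.length; i++) {
      leftGrad += gradBins[i];
      leftHess += hessBins[i];
      float rightGrad = rootGrad - leftGrad;   // right = root - left
      float rightHess = rootHess - leftHess;
      if (leftHess >= minChildWeight && rightHess >= minChildWeight) {
        float lossChg = gain(leftGrad, leftHess, lambda)
            + gain(rightGrad, rightHess, lambda) - rootGain;
        if (lossChg > bestLossChg) {
          bestLossChg = lossChg;
          bestIdx = i;
        }
      }
    }
    return bestIdx;
  }

  public static void main(String[] args) {
    float[] grad = {-4f, -1f, 2f, 5f};
    float[] hess = {2f, 2f, 2f, 2f};
    System.out.println("best split index: " + bestSplitIndex(grad, hess, 1f, 1f));
  }
}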
Use of com.tencent.angel.ml.GBDT.algo.tree.SplitEntry in project angel by Tencent.
The class GradHistHelper, method findSplitOfServerRow.
// find the best split result of a server row on the PS
public static SplitEntry findSplitOfServerRow(ServerDenseDoubleRow row, GBDTParam param) {
  LOG.info(String.format("------To find the best split from server row[%d], cols[%d-%d]------", row.getRowId(), row.getStartCol(), row.getEndCol()));
  SplitEntry splitEntry = new SplitEntry();
  splitEntry.leftGradStat = new GradStats();
  splitEntry.rightGradStat = new GradStats();
  LOG.info(String.format("The best split before looping the histogram: fid[%d], fvalue[%f]", splitEntry.fid, splitEntry.fvalue));
  int startFid = (int) row.getStartCol() / (2 * param.numSplit);
  int endFid = ((int) row.getEndCol()) / (2 * param.numSplit) - 1;
  LOG.info(String.format("Row split col[%d-%d), start feature[%d], end feature[%d]", row.getStartCol(), row.getEndCol(), startFid, endFid));
  // 2. the fid here is the index in the sampled feature set, rather than the true feature id
  for (int i = 0; startFid + i <= endFid; i++) {
    // 2.2. get the start index in the histogram of this feature
    int startIdx = 2 * param.numSplit * i;
    // 2.3. find the best split of the current feature
    SplitEntry curSplit = findSplitOfFeature(startFid + i, row, startIdx, param);
    // 2.4. update the best split result if possible
    splitEntry.update(curSplit);
  }
  LOG.info(String.format("The best split after looping the histogram: fid[%d], fvalue[%f], loss gain[%f]", splitEntry.fid, splitEntry.fvalue, splitEntry.lossChg));
  return splitEntry;
}
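The startFid/endFid arithmetic above follows from the same layout: each feature occupies 2 * numSplit consecutive columns, so a server row covering columns [startCol, endCol) holds the histograms of features startCol / (2 * numSplit) through endCol / (2 * numSplit) - 1. A small standalone illustration with hypothetical numbers (not Angel code):

// Illustration of the row-to-feature mapping used by findSplitOfServerRow,
// assuming each feature owns 2 * numSplit consecutive columns.
public class RowRangeSketch {
  public static void main(String[] args) {
    int numSplit = 10;                   // candidate splits per feature
    long startCol = 200, endCol = 400;   // this row covers columns [200, 400)

    int binsPerFeature = 2 * numSplit;
    int startFid = (int) (startCol / binsPerFeature);  // 10
    int endFid = (int) (endCol / binsPerFeature) - 1;  // 19

    for (int i = 0; startFid + i <= endFid; i++) {
      int startIdx = binsPerFeature * i;  // offset of this feature's bins inside the row
      System.out.printf("feature %d -> row offset %d%n", startFid + i, startIdx);
    }
  }
}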
Use of com.tencent.angel.ml.GBDT.algo.tree.SplitEntry in project angel by Tencent.
The class GradHistHelper, method findBestSplitOfOneFeatureHelper.
// find the best split result of one feature
public static SplitEntry findBestSplitOfOneFeatureHelper(int fid, TDoubleVector histogram, int startIdx) {
  LOG.info(String.format("Find best split for fid[%d] in histogram size[%d], startIdx[%d]", fid, histogram.getDimension(), startIdx));
  int splitNum = WorkerContext.get().getConf().getInt(MLConf.ML_GBDT_SPLIT_NUM(), MLConf.DEFAULT_ML_GBDT_SPLIT_NUM());
  SplitEntry splitEntry = new SplitEntry();
  // 1. set the feature id
  // splitEntry.setFid(fid);
  // 2. create the best left stats and right stats
  GradStats bestLeftStat = new GradStats();
  GradStats bestRightStat = new GradStats();
  GradStats rootStats = calGradStats(histogram, startIdx, splitNum);
  GBDTParam param = new GBDTParam();
  if (startIdx + 2 * splitNum <= histogram.getDimension()) {
    // 3. the gain of the root node
    float rootGain = rootStats.calcGain(param);
    LOG.info(String.format("Feature[%d]: sumGrad[%f], sumHess[%f], gain[%f]", fid, rootStats.sumGrad, rootStats.sumHess, rootGain));
    // 4. create the temp left and right grad stats
    GradStats leftStats = new GradStats();
    GradStats rightStats = new GradStats();
    // 5. loop over all the data in the histogram
    for (int histIdx = startIdx; histIdx < startIdx + splitNum - 1; histIdx++) {
      // 5.1. get the grad and hess of the current hist bin
      float grad = (float) histogram.get(histIdx);
      float hess = (float) histogram.get(splitNum + histIdx);
      leftStats.add(grad, hess);
      // 5.2. check whether we can split with the current left hessian
      if (leftStats.sumHess >= param.minChildWeight) {
        // right = root - left
        rightStats.setSubstract(rootStats, leftStats);
        // 5.3. check whether we can split with the current right hessian
        if (rightStats.sumHess >= param.minChildWeight) {
          // 5.4. calculate the current loss gain
          float lossChg = leftStats.calcGain(param) + rightStats.calcGain(param) - rootGain;
          // 5.5. check whether we should update the split result with the current loss gain
          int splitIdx = histIdx - startIdx + 1;
          if (splitEntry.update(lossChg, fid, splitIdx)) {
            // 5.6. if so, also update the best left and right grad stats
            bestLeftStat.update(leftStats.sumGrad, leftStats.sumHess);
            bestRightStat.update(rightStats.sumGrad, rightStats.sumHess);
          }
        }
      }
    }
    // 6. set the best left and right grad stats
    splitEntry.leftGradStat = bestLeftStat;
    splitEntry.rightGradStat = bestRightStat;
    LOG.info(String.format("Find best split for fid[%d], split feature[%d]: split index[%f], lossChg[%f]", fid, splitEntry.fid, splitEntry.fvalue, splitEntry.lossChg));
  } else {
    LOG.error("index out of grad histogram size.");
  }
  return splitEntry;
}
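Note that the SplitEntry produced here carries a split index into the feature's candidate split values, not the raw threshold; the worker task that receives it is expected to translate the index back into a real feature value using the candidates from the quantile-sketch step. A hypothetical illustration of that lookup (candidateSplits is an invented name, not an Angel field):

// Hypothetical lookup of the real threshold from the split index stored in
// SplitEntry.fvalue; candidateSplits stands for the per-feature candidate
// split values produced by the quantile sketch.
public class SplitIndexLookupSketch {
  public static void main(String[] args) {
    float[] candidateSplits = {0.1f, 0.5f, 1.2f, 3.4f, 7.8f};
    int splitIdx = 2;                              // the value a SplitEntry would carry
    float threshold = candidateSplits[splitIdx];   // actual split value: 1.2
    // split rule: an instance goes left if featureValue <= threshold
    System.out.println("split threshold = " + threshold);
  }
}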