Usage of com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster in the Alink project by Alibaba.
Class Boosting, method calc.
/**
 * Executes one boosting step for this task, unless {@code inWeakLearner} is set on the shared
 * boosting objects, in which case the call returns immediately without logging.
 *
 * <p>On the first superstep a {@link Booster} is built and cached in the context under
 * {@code BOOSTER}. Every (non-skipped) call then invokes the cached booster's {@code boosting}
 * with the labels and the current predictions, and increments {@code numBoosting}.
 *
 * @param context the iteration context holding the shared {@code BoostingObjs}
 */
@Override
public void calc(ComContext context) {
    BoostingObjs objs = context.getObj(InitBoostingObjs.BOOSTING_OBJS);
    // Skip the boosting step entirely while the weak-learner flag is raised.
    if (objs.inWeakLearner) {
        return;
    }

    LOG.info("taskId: {}, {} start", context.getTaskId(), Boosting.class.getSimpleName());

    if (context.getStepNo() == 1) {
        context.putObj(BOOSTER, buildBooster(objs));
    }

    Booster cached = context.getObj(BOOSTER);
    cached.boosting(objs, objs.data.getLabels(), objs.pred);
    objs.numBoosting++;

    LOG.info("taskId: {}, {} end", context.getTaskId(), Boosting.class.getSimpleName());
}

/**
 * Creates the booster matching the configured loss.
 *
 * <p>When {@code LossUtils.isRanking} reports the configured {@code LOSS_TYPE} as a ranking loss,
 * a ranking booster is created from the ranking loss, the query-id offsets, the weights, a slice
 * {@code [0, queryIdOffset.length - 1)} and a row slice {@code [0, getM())}. Otherwise a plain
 * booster is created from the loss, the weights and the row slice.
 * NOTE(review): the exact Slice bound semantics (inclusive/exclusive) are defined elsewhere — confirm.
 *
 * @param objs the shared boosting objects supplying params, loss and data
 * @return the newly created booster
 */
private static Booster buildBooster(BoostingObjs objs) {
    if (!LossUtils.isRanking(objs.params.get(LossUtils.LOSS_TYPE))) {
        return BoosterFactory.createBooster(
            objs.params.get(BoosterType.BOOSTER_TYPE),
            objs.loss,
            objs.data.getWeights(),
            new Slice(0, objs.data.getM())
        );
    }
    return BoosterFactory.createRankingBooster(
        objs.params.get(BoosterType.BOOSTER_TYPE),
        objs.rankingLoss,
        objs.data.getQueryIdOffset(),
        objs.data.getWeights(),
        new Slice(0, objs.data.getQueryIdOffset().length - 1),
        new Slice(0, objs.data.getM())
    );
}
Usage of com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster in the Alink project by Alibaba.
Class ConstructLocalHistogram, method calc.
/**
 * Builds the local histogram for the nodes currently queued in the tree and computes, per task,
 * how many histogram values fall into that task's share of the features.
 *
 * <p>On the first superstep this allocates the reusable buffers: the {@code "histogram"} and
 * {@code "recvcnts"} arrays (stored in the context) and the fields {@code featureSplitHistogram},
 * {@code featureValid}, {@code aligned}, {@code validFeatureOffset} and {@code results}. If
 * {@code boostingObjs.inWeakLearner} is not yet set, the tree is initialised via
 * {@code tree.initPerTree} and the flag is raised.
 *
 * <p>The bagging features of every node pair in {@code tree.queue} (small node first, then the big
 * node when present) are counted consecutively and assigned to tasks in contiguous ranges given by
 * {@code distributedInfo} over {@code sumFeatureCount} features; {@code recvcnts[t]} accumulates
 * {@code categoricalSize * STEP} for every feature assigned to task {@code t}.
 * NOTE(review): presumably these are the receive counts of a later reduce-scatter of
 * {@code "histogram"} — confirm against the consuming step.
 *
 * @param context the iteration context holding {@code "boostingObjs"}, {@code "booster"} and
 *                {@code "tree"}
 * @throws IllegalStateException if a histogram buffer length does not fit in a Java array
 */
@Override
public void calc(ComContext context) {
    LOG.info("taskId: {}, {} start", context.getTaskId(), ConstructLocalHistogram.class.getSimpleName());
    BoostingObjs boostingObjs = context.getObj("boostingObjs");
    Booster booster = context.getObj("booster");
    HistogramBaseTreeObjs tree = context.getObj("tree");
    if (context.getStepNo() == 1) {
        LOG.info("maxDepth: {}, maxLeaves: {}", boostingObjs.params.get(GbdtRegTrainParams.MAX_DEPTH), boostingObjs.params.get(GbdtRegTrainParams.MAX_LEAVES));
        // Compute buffer sizes in long: the int products overflow silently on wide data, which
        // would yield either a negative length or a truncated (too small) buffer.
        long featureSplitHistogramLen = (long) tree.maxNodeSize * STEP * tree.allFeatureBins;
        long baggedHistogramLen =
            (long) tree.maxNodeSize * tree.maxFeatureBins * STEP * boostingObjs.numBaggingFeatures;
        int histogramLen =
            checkedArrayLength(Math.min(baggedHistogramLen, featureSplitHistogramLen), "histogram");
        context.putObj("histogram", new double[histogramLen]);
        context.putObj("recvcnts", new int[context.getNumTask()]);
        featureSplitHistogram =
            new double[checkedArrayLength(featureSplitHistogramLen, "featureSplitHistogram")];
        featureValid = new BitSet(boostingObjs.data.getN());
        aligned = new int[boostingObjs.data.getM()];
        validFeatureOffset = new int[boostingObjs.data.getN()];
        results = new Future[boostingObjs.data.getN()];
        useInstanceCount = LossUtils.useInstanceCount(boostingObjs.params.get(LossUtils.LOSS_TYPE));
    }
    if (!boostingObjs.inWeakLearner) {
        tree.initPerTree(boostingObjs, context.getObj(BuildLocalSketch.SKETCH));
        boostingObjs.inWeakLearner = true;
    }
    double[] histogram = context.getObj("histogram");
    int sumFeatureCount = calcWithNodeIdCache(context, boostingObjs, booster, tree, histogram);
    int[] recvcnts = context.getObj("recvcnts");
    Arrays.fill(recvcnts, 0);

    // Split the histogram by feature: features are numbered 1..sumFeatureCount in queue order and
    // feature number k goes to the first task whose end position is >= k.
    int numTask = context.getNumTask();
    int taskPos = 0;
    int featureCnt = 0;
    int next = taskEndPos(taskPos, numTask, sumFeatureCount);
    for (NodeInfoPair item : tree.queue) {
        for (int smallBaggingFeature : item.small.baggingFeatures) {
            featureCnt++;
            while (featureCnt > next) {
                taskPos++;
                next = taskEndPos(taskPos, numTask, sumFeatureCount);
            }
            recvcnts[taskPos] += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing) * STEP;
        }
        if (item.big != null) {
            for (int bigBaggingFeature : item.big.baggingFeatures) {
                featureCnt++;
                while (featureCnt > next) {
                    taskPos++;
                    next = taskEndPos(taskPos, numTask, sumFeatureCount);
                }
                recvcnts[taskPos] += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing) * STEP;
            }
        }
    }
    LOG.info("taskId: {}, {} end", context.getTaskId(), ConstructLocalHistogram.class.getSimpleName());
}

/**
 * Returns {@code startPos + localRowCnt} of task {@code taskPos} when {@code total} items are
 * distributed over {@code numTask} tasks by {@code distributedInfo}; the feature-splitting loop
 * uses it as the upper bound of that task's range of 1-based feature numbers.
 */
private int taskEndPos(int taskPos, int numTask, int total) {
    return (int) (distributedInfo.startPos(taskPos, numTask, total)
        + distributedInfo.localRowCnt(taskPos, numTask, total));
}

/**
 * Narrows a buffer length computed in {@code long} to an {@code int} array length.
 *
 * @param length     the requested number of elements
 * @param bufferName name of the buffer, used only in the error message
 * @return {@code length} as an {@code int}
 * @throws IllegalStateException if {@code length} is negative or exceeds {@code Integer.MAX_VALUE}
 */
private static int checkedArrayLength(long length, String bufferName) {
    if (length < 0 || length > Integer.MAX_VALUE) {
        throw new IllegalStateException(
            "Length of " + bufferName + " buffer (" + length + ") does not fit in a Java array.");
    }
    return (int) length;
}