Search in sources:

Example 1 with Booster

Use of com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster in the Alink project by Alibaba.

Source: method calc of class Boosting.

/**
 * Performs one boosting step, unless a weak learner is currently in progress.
 *
 * <p>On the first superstep this creates the {@link Booster} and stores it in the context under
 * {@code BOOSTER}. A ranking loss gets a ranking booster built over the query-id offsets; any
 * other loss gets a plain booster. Every call then runs {@code boosting} with the labels and
 * current predictions, and increments {@code numBoosting}.
 *
 * @param context the communication context holding the shared boosting objects
 */
@Override
public void calc(ComContext context) {
    final BoostingObjs objs = context.getObj(InitBoostingObjs.BOOSTING_OBJS);
    // Nothing to do while inWeakLearner is set.
    if (objs.inWeakLearner) {
        return;
    }

    final String stageName = Boosting.class.getSimpleName();
    LOG.info("taskId: {}, {} start", context.getTaskId(), stageName);

    if (context.getStepNo() == 1) {
        final boolean rankingLoss = LossUtils.isRanking(objs.params.get(LossUtils.LOSS_TYPE));
        final Booster created = rankingLoss
            ? BoosterFactory.createRankingBooster(
                objs.params.get(BoosterType.BOOSTER_TYPE),
                objs.rankingLoss,
                objs.data.getQueryIdOffset(),
                objs.data.getWeights(),
                // Offsets array has one more entry than the number of slots covered here.
                new Slice(0, objs.data.getQueryIdOffset().length - 1),
                new Slice(0, objs.data.getM()))
            : BoosterFactory.createBooster(
                objs.params.get(BoosterType.BOOSTER_TYPE),
                objs.loss,
                objs.data.getWeights(),
                new Slice(0, objs.data.getM()));
        context.putObj(BOOSTER, created);
    }

    final Booster booster = context.getObj(BOOSTER);
    booster.boosting(objs, objs.data.getLabels(), objs.pred);
    objs.numBoosting += 1;

    LOG.info("taskId: {}, {} end", context.getTaskId(), stageName);
}
Also used: Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice), Booster (com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster)

Example 2 with Booster

Use of com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster in the Alink project by Alibaba.

Source: method calc of class ConstructLocalHistogram.

/**
 * Builds the local histogram for the nodes in {@code tree.queue}. Also computes, per task, how
 * many histogram entries that task receives when the histogram is split by feature.
 *
 * <p>On the first superstep this allocates the shared {@code "histogram"} and {@code "recvcnts"}
 * buffers plus the per-feature and per-row work arrays. When no weak learner is in progress, it
 * initialises the per-tree state from the sketch and sets {@code inWeakLearner}.
 *
 * <p>NOTE(review): {@code recvcnts} presumably feeds a later per-task scatter/reduce of the
 * histogram. Confirm against the next step in the iteration.
 *
 * @param context the communication context holding "boostingObjs", "booster", "tree" and the
 *                buffers allocated on step 1
 * @throws ArithmeticException if a histogram size does not fit in an {@code int}
 */
@Override
public void calc(ComContext context) {
    LOG.info("taskId: {}, {} start", context.getTaskId(), ConstructLocalHistogram.class.getSimpleName());
    BoostingObjs boostingObjs = context.getObj("boostingObjs");
    Booster booster = context.getObj("booster");
    HistogramBaseTreeObjs tree = context.getObj("tree");
    if (context.getStepNo() == 1) {
        LOG.info("maxDepth: {}, maxLeaves: {}", boostingObjs.params.get(GbdtRegTrainParams.MAX_DEPTH), boostingObjs.params.get(GbdtRegTrainParams.MAX_LEAVES));
        // Sizes are computed in long: the bagged-feature product has four factors and can exceed
        // Integer.MAX_VALUE on wide data. In int arithmetic it would wrap silently before Math.min
        // and yield a negative or too-small array size.
        long featureSplitLen = (long) tree.maxNodeSize * STEP * tree.allFeatureBins;
        long baggedLen = (long) tree.maxNodeSize * tree.maxFeatureBins * STEP * boostingObjs.numBaggingFeatures;
        // toIntExact fails loudly instead of allocating a wrong-sized array.
        int histogramLen = Math.toIntExact(Math.min(baggedLen, featureSplitLen));
        int featureSplitHistogramLen = Math.toIntExact(featureSplitLen);
        context.putObj("histogram", new double[histogramLen]);
        context.putObj("recvcnts", new int[context.getNumTask()]);
        featureSplitHistogram = new double[featureSplitHistogramLen];
        featureValid = new BitSet(boostingObjs.data.getN());
        aligned = new int[boostingObjs.data.getM()];
        validFeatureOffset = new int[boostingObjs.data.getN()];
        results = new Future[boostingObjs.data.getN()];
        useInstanceCount = LossUtils.useInstanceCount(boostingObjs.params.get(LossUtils.LOSS_TYPE));
    }
    // Entering a new weak learner: set up the per-tree state once.
    if (!boostingObjs.inWeakLearner) {
        tree.initPerTree(boostingObjs, context.getObj(BuildLocalSketch.SKETCH));
        boostingObjs.inWeakLearner = true;
    }
    double[] histogram = context.getObj("histogram");
    int sumFeatureCount = calcWithNodeIdCache(context, boostingObjs, booster, tree, histogram);
    int[] recvcnts = context.getObj("recvcnts");
    Arrays.fill(recvcnts, 0);

    // Split the histogram by feature. Features are counted in queue order: each node pair's
    // small-node bagged features, then its big-node bagged features (if any). Feature number
    // featureCnt (1-based) belongs to task taskPos while featureCnt <= next. Here
    // next = startPos + localRowCnt for that task over sumFeatureCount (assumed to be the
    // task's exclusive end offset).
    final int numTask = context.getNumTask();
    int taskPos = 0;
    int featureCnt = 0;
    int next = (int) (distributedInfo.startPos(taskPos, numTask, sumFeatureCount) + distributedInfo.localRowCnt(taskPos, numTask, sumFeatureCount));
    for (NodeInfoPair item : tree.queue) {
        final int[][] featureGroups = item.big == null
            ? new int[][] {item.small.baggingFeatures}
            : new int[][] {item.small.baggingFeatures, item.big.baggingFeatures};
        for (int[] group : featureGroups) {
            for (int feature : group) {
                featureCnt++;
                // Advance past tasks whose range ends before this feature (skips empty ranges too).
                while (featureCnt > next) {
                    taskPos++;
                    next = (int) (distributedInfo.startPos(taskPos, numTask, sumFeatureCount) + distributedInfo.localRowCnt(taskPos, numTask, sumFeatureCount));
                }
                recvcnts[taskPos] += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[feature], tree.useMissing) * STEP;
            }
        }
    }
    LOG.info("taskId: {}, {} end", context.getTaskId(), ConstructLocalHistogram.class.getSimpleName());
}
Also used: Booster (com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster), BitSet (java.util.BitSet)

Aggregations

Booster (com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster): 2 uses; Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice): 1 use; BitSet (java.util.BitSet): 1 use