Search in sources:

Example 1 with Slice

Use of com.alibaba.alink.operator.common.tree.parallelcart.data.Slice in project Alink by alibaba.

Class HistogramBaseTreeObjs, method initPerTree.

void initPerTree(BoostingObjs boostingObjs, EpsilonApproQuantile.WQSummary[] summaries) {
    assert queue.isEmpty();
    Node root = new Node();
    queue.push(new NodeInfoPair(
        new NodeInfoPair.NodeInfo(
            root,
            new Slice(0, boostingObjs.numBaggingInstances),
            new Slice(boostingObjs.numBaggingInstances, boostingObjs.data.getM()),
            1, null),
        null));
    roots.add(root);
    leaves.clear();
    this.summaries = summaries;
}
Also used: Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice), Node (com.alibaba.alink.operator.common.tree.Node)
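
The Slice arguments above read as half-open index ranges over the instance array: the root's first slice covers rows [0, numBaggingInstances) and the second covers the remaining rows up to data.getM(), which lines up with the in-bag / out-of-bag pairing seen in Example 4. A minimal sketch of such a range type, assuming only the start and end fields visible in these snippets (not Alink's actual implementation):

// Illustrative sketch only: a half-open index range [start, end),
// matching the start/end fields used in these examples.
public class SliceSketch {
    public final int start;
    public final int end;

    public SliceSketch(int start, int end) {
        this.start = start;
        this.end = end;
    }

    public int length() {
        return end - start;
    }
}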

Example 2 with Slice

Use of com.alibaba.alink.operator.common.tree.parallelcart.data.Slice in project Alink by alibaba.

Class Boosting, method calc.

@Override
public void calc(ComContext context) {
    BoostingObjs boostingObjs = context.getObj(InitBoostingObjs.BOOSTING_OBJS);
    if (boostingObjs.inWeakLearner) {
        return;
    }
    LOG.info("taskId: {}, {} start", context.getTaskId(), Boosting.class.getSimpleName());
    if (context.getStepNo() == 1) {
        Booster booster;
        if (LossUtils.isRanking(boostingObjs.params.get(LossUtils.LOSS_TYPE))) {
            booster = BoosterFactory.createRankingBooster(
                boostingObjs.params.get(BoosterType.BOOSTER_TYPE),
                boostingObjs.rankingLoss,
                boostingObjs.data.getQueryIdOffset(),
                boostingObjs.data.getWeights(),
                new Slice(0, boostingObjs.data.getQueryIdOffset().length - 1),
                new Slice(0, boostingObjs.data.getM()));
        } else {
            booster = BoosterFactory.createBooster(
                boostingObjs.params.get(BoosterType.BOOSTER_TYPE),
                boostingObjs.loss,
                boostingObjs.data.getWeights(),
                new Slice(0, boostingObjs.data.getM()));
        }
        context.putObj(BOOSTER, booster);
    }
    context.<Booster>getObj(BOOSTER).boosting(boostingObjs, boostingObjs.data.getLabels(), boostingObjs.pred);
    boostingObjs.numBoosting++;
    LOG.info("taskId: {}, {} end", context.getTaskId(), Boosting.class.getSimpleName());
}
Also used: Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice), Booster (com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster)
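
In the ranking branch the booster gets one slice of length getQueryIdOffset().length - 1 and one over all getM() instances. That is consistent with queryIdOffset being a CSR-style offsets array, one entry per query plus a trailing end offset, so query q owns the instance rows [offset[q], offset[q + 1]); this layout is an assumption inferred from the length - 1, not taken from Alink's source. A small self-contained sketch of walking such an offsets array:

// Assumed CSR-style layout (illustrative only): offsets has one entry per query
// plus a trailing end offset, so query q owns instance rows [offsets[q], offsets[q + 1]).
public class QueryOffsetsSketch {
    public static void main(String[] args) {
        int[] offsets = {0, 3, 5, 9};      // 3 queries over 9 instances
        int numQueries = offsets.length - 1;
        for (int q = 0; q < numQueries; q++) {
            System.out.printf("query %d covers rows [%d, %d)%n", q, offsets[q], offsets[q + 1]);
        }
    }
}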

Example 3 with Slice

Use of com.alibaba.alink.operator.common.tree.parallelcart.data.Slice in project Alink by alibaba.

Class CalcFeatureGain, method calc.

@Override
public void calc(ComContext context) {
    LOG.info("taskId: {}, {} start", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
    BoostingObjs boostingObjs = context.getObj("boostingObjs");
    HistogramBaseTreeObjs tree = context.getObj("tree");
    double[] histogram = context.getObj("histogram");
    if (context.getStepNo() == 1) {
        context.putObj("best", new Node[tree.maxNodeSize]);
        featureSplitters = new HistogramFeatureSplitter[boostingObjs.data.getN()];
        for (int i = 0; i < boostingObjs.data.getN(); ++i) {
            featureSplitters[i] = createFeatureSplitter(
                boostingObjs.data.getFeatureMetas()[i].getType() == FeatureMeta.FeatureType.CATEGORICAL,
                boostingObjs.params,
                boostingObjs.data.getFeatureMetas()[i],
                tree.compareIndex4Categorical);
        }
    }
    int sumFeatureCount = 0;
    for (NodeInfoPair item : tree.queue) {
        sumFeatureCount += boostingObjs.numBaggingFeatures;
        if (item.big != null) {
            sumFeatureCount += boostingObjs.numBaggingFeatures;
        }
    }
    DistributedInfo distributedInfo = new DefaultDistributedInfo();
    int start = (int) distributedInfo.startPos(context.getTaskId(), context.getNumTask(), sumFeatureCount);
    int cnt = (int) distributedInfo.localRowCnt(context.getTaskId(), context.getNumTask(), sumFeatureCount);
    int end = start + cnt;
    int featureCnt = 0;
    int featureBinCnt = 0;
    Node[] best = context.getObj("best");
    int index = 0;
    for (NodeInfoPair item : tree.queue) {
        best[index] = null;
        final int[] smallBaggingFeatures = item.small.baggingFeatures;
        for (int smallBaggingFeature : smallBaggingFeatures) {
            if (featureCnt >= start && featureCnt < end) {
                featureSplitters[smallBaggingFeature].reset(
                    histogram,
                    new Slice(
                        featureBinCnt,
                        featureBinCnt + DataUtil.getFeatureCategoricalSize(
                            boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing)),
                    item.small.depth);
                double gain = featureSplitters[smallBaggingFeature].bestSplit(tree.leaves.size());
                if (best[index] == null
                    || (featureSplitters[smallBaggingFeature].canSplit() && gain > best[index].getGain())) {
                    best[index] = new Node();
                    featureSplitters[smallBaggingFeature].fillNode(best[index]);
                }
                featureBinCnt += DataUtil.getFeatureCategoricalSize(
                    boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing);
            }
            featureCnt++;
        }
        index++;
        if (item.big != null) {
            best[index] = null;
            final int[] bigBaggingFeatures = item.big.baggingFeatures;
            for (int bigBaggingFeature : bigBaggingFeatures) {
                if (featureCnt >= start && featureCnt < end) {
                    featureSplitters[bigBaggingFeature].reset(
                        histogram,
                        new Slice(
                            featureBinCnt,
                            featureBinCnt + DataUtil.getFeatureCategoricalSize(
                                boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing)),
                        item.big.depth);
                    double gain = featureSplitters[bigBaggingFeature].bestSplit(tree.leaves.size());
                    if (best[index] == null
                        || (featureSplitters[bigBaggingFeature].canSplit() && gain > best[index].getGain())) {
                        best[index] = new Node();
                        featureSplitters[bigBaggingFeature].fillNode(best[index]);
                    }
                    featureBinCnt += DataUtil.getFeatureCategoricalSize(
                        boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing);
                }
                featureCnt++;
            }
            index++;
        }
    }
    context.putObj("bestLength", index);
    LOG.info("taskId: {}, {} end", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
}
Also used: DefaultDistributedInfo (com.alibaba.alink.common.io.directreader.DefaultDistributedInfo), DistributedInfo (com.alibaba.alink.common.io.directreader.DistributedInfo), Node (com.alibaba.alink.operator.common.tree.Node), Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice)
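
The startPos and localRowCnt calls carve the sumFeatureCount splitter evaluations into one contiguous [start, end) block per task, so each worker only scores its share of (node, feature) pairs. A minimal sketch of an even block partition with the remainder spread over the first tasks; this is the usual scheme for such interfaces, not necessarily DefaultDistributedInfo's exact arithmetic:

// Illustrative sketch only: task t gets a contiguous block of the total work items,
// with the first (total % numTasks) tasks taking one extra item.
public class EvenPartitionSketch {
    static long startPos(int taskId, int numTasks, long total) {
        long base = total / numTasks;
        long remainder = total % numTasks;
        return taskId * base + Math.min(taskId, remainder);
    }

    static long localCnt(int taskId, int numTasks, long total) {
        long base = total / numTasks;
        return base + (taskId < total % numTasks ? 1 : 0);
    }

    public static void main(String[] args) {
        long total = 10;
        for (int t = 0; t < 4; t++) {
            long start = startPos(t, 4, total);
            // prints [0, 3), [3, 6), [6, 8), [8, 10)
            System.out.printf("task %d: [%d, %d)%n", t, start, start + localCnt(t, 4, total));
        }
    }
}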

Example 4 with Slice

Use of com.alibaba.alink.operator.common.tree.parallelcart.data.Slice in project Alink by alibaba.

Class SplitInstances, method split.

public static NodeInfoPair split(NodeInfoPair.NodeInfo nodeInfo, EpsilonApproQuantile.WQSummary summary, int[] indices, Data data) {
    int mid = data.splitInstances(nodeInfo.node, summary, indices, nodeInfo.slice);
    int oobMid = data.splitInstances(nodeInfo.node, summary, indices, nodeInfo.oob);
    nodeInfo.node.setNextNodes(new Node[] { new Node(), new Node() });
    return new NodeInfoPair(
        // left child: [slice.start, mid) and [oob.start, oobMid)
        new NodeInfoPair.NodeInfo(nodeInfo.node.getNextNodes()[0], new Slice(nodeInfo.slice.start, mid),
            new Slice(nodeInfo.oob.start, oobMid), nodeInfo.depth + 1, nodeInfo.baggingFeatures),
        // right child: [mid, slice.end) and [oobMid, oob.end)
        new NodeInfoPair.NodeInfo(nodeInfo.node.getNextNodes()[1], new Slice(mid, nodeInfo.slice.end),
            new Slice(oobMid, nodeInfo.oob.end), nodeInfo.depth + 1, null));
}
Also used: Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice), Node (com.alibaba.alink.operator.common.tree.Node)
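
splitInstances evidently reorders the indices inside each half-open slice so that rows routed to the left child come first and returns the boundary mid, once for the training slice and once for the out-of-bag slice. A self-contained sketch of that kind of in-place partition, with a hypothetical goesLeft predicate standing in for the node's split test:

// Illustrative sketch only: partition indices[start, end) so that rows satisfying
// goesLeft come first, and return the boundary between the two halves.
public class PartitionSketch {
    interface SplitTest {
        boolean goesLeft(int row);   // hypothetical stand-in for the node's split test
    }

    static int partition(int[] indices, int start, int end, SplitTest test) {
        int mid = start;
        for (int i = start; i < end; i++) {
            if (test.goesLeft(indices[i])) {
                int tmp = indices[i];
                indices[i] = indices[mid];
                indices[mid] = tmp;
                mid++;
            }
        }
        // rows [start, mid) go to the left child, [mid, end) to the right child
        return mid;
    }
}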

Example 5 with Slice

Use of com.alibaba.alink.operator.common.tree.parallelcart.data.Slice in project Alink by alibaba.

Class BoostingTest, method testGradient.

@Test
public void testGradient() {
    GradientBaseBooster booster = new GradientBaseBooster(new LogLoss(2.0, 1.0), new double[] { 1.0, 1.0, 1.0 }, new Slice(0, 3));
    booster.boosting(null, new double[] { 1, 1, 0 }, new double[] { 0.5, 0.5, 0.1 });
    Assert.assertArrayEquals(new double[] { 0.3775406687981454, 0.3775406687981454, -0.52497918747894 }, booster.getGradients(), 1e-6);
    Assert.assertArrayEquals(new double[] { 0.1425369565965509, 0.1425369565965509, 0.27560314728604807 }, booster.getGradientsSqr(), 1e-6);
    Assert.assertNull(booster.getHessions());
    Assert.assertArrayEquals(new double[] { 1.0, 1.0, 1.0 }, booster.getWeights(), 1e-6);
}
Also used: LogLoss (com.alibaba.alink.operator.common.tree.parallelcart.loss.LogLoss), Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice), GradientBaseBooster (com.alibaba.alink.operator.common.tree.parallelcart.booster.GradientBaseBooster), Test (org.junit.Test)
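
The expected arrays are consistent with this log loss storing gradient = label - sigmoid(prediction) and gradientSqr = gradient squared; that convention is inferred from the numbers above, not from Alink's source. A small sketch that reproduces the values:

// Sketch reproducing the expected test values under the assumed convention
// gradient = label - sigmoid(prediction), gradientSqr = gradient * gradient.
public class LogLossGradientCheck {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    public static void main(String[] args) {
        double[] labels = {1, 1, 0};
        double[] pred = {0.5, 0.5, 0.1};
        for (int i = 0; i < labels.length; i++) {
            double g = labels[i] - sigmoid(pred[i]);
            // prints values matching the expected arrays above
            // (about 0.3775406688 / 0.1425369566 twice, then -0.5249791875 / 0.2756031473)
            System.out.printf("gradient=%.16f, gradientSqr=%.16f%n", g, g * g);
        }
    }
}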

Aggregations

Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice): 5 usages
Node (com.alibaba.alink.operator.common.tree.Node): 3 usages
DefaultDistributedInfo (com.alibaba.alink.common.io.directreader.DefaultDistributedInfo): 1 usage
DistributedInfo (com.alibaba.alink.common.io.directreader.DistributedInfo): 1 usage
Booster (com.alibaba.alink.operator.common.tree.parallelcart.booster.Booster): 1 usage
GradientBaseBooster (com.alibaba.alink.operator.common.tree.parallelcart.booster.GradientBaseBooster): 1 usage
LogLoss (com.alibaba.alink.operator.common.tree.parallelcart.loss.LogLoss): 1 usage
Test (org.junit.Test): 1 usage