Use of com.alibaba.alink.operator.common.tree.parallelcart.data.Slice in project Alink by alibaba.
In the class HistogramBaseTreeObjs, method initPerTree.
/**
 * Prepares this object for growing a new tree: creates a root node, enqueues it, records it
 * in {@code roots}, empties {@code leaves}, and stores the given summaries.
 *
 * <p>The work queue is expected to be empty on entry; this is only verified by an
 * {@code assert}, i.e. when assertions are enabled.
 *
 * @param boostingObjs supplies {@code numBaggingInstances} and {@code data.getM()}, which
 *                     bound the two slices attached to the root
 * @param summaries    stored on this object unchanged
 */
void initPerTree(BoostingObjs boostingObjs, EpsilonApproQuantile.WQSummary[] summaries) {
    assert queue.isEmpty();

    final Node treeRoot = new Node();

    // The root carries two adjacent ranges: [0, numBaggingInstances) and
    // [numBaggingInstances, getM()).
    // NOTE(review): presumably the sampled rows followed by the left-out rows — confirm.
    final Slice head = new Slice(0, boostingObjs.numBaggingInstances);
    final Slice tail = new Slice(boostingObjs.numBaggingInstances, boostingObjs.data.getM());

    // Fourth argument 1 and fifth argument null are passed exactly as before.
    final NodeInfoPair.NodeInfo rootInfo = new NodeInfoPair.NodeInfo(treeRoot, head, tail, 1, null);

    // The root has no sibling, so the second half of the pair is null.
    queue.push(new NodeInfoPair(rootInfo, null));
    roots.add(treeRoot);
    leaves.clear();
    this.summaries = summaries;
}
Use of com.alibaba.alink.operator.common.tree.parallelcart.data.Slice in project Alink by alibaba.
In the class Boosting, method calc.
/**
 * Runs one boosting update unless a weak learner is currently in progress.
 *
 * <p>On step 1 a {@link Booster} is created and stored in the context under {@code BOOSTER};
 * on every call the stored booster is applied to the labels and current predictions, and
 * {@code numBoosting} is incremented.
 *
 * @param context the computation context that holds the shared boosting objects and booster
 */
@Override
public void calc(ComContext context) {
    BoostingObjs boostingObjs = context.getObj(InitBoostingObjs.BOOSTING_OBJS);

    // Nothing to do (and nothing logged) while the weak learner is still running.
    if (boostingObjs.inWeakLearner) {
        return;
    }

    LOG.info("taskId: {}, {} start", context.getTaskId(), Boosting.class.getSimpleName());

    if (context.getStepNo() == 1) {
        context.putObj(BOOSTER, buildBooster(boostingObjs));
    }

    Booster booster = context.getObj(BOOSTER);
    booster.boosting(boostingObjs, boostingObjs.data.getLabels(), boostingObjs.pred);
    boostingObjs.numBoosting++;

    LOG.info("taskId: {}, {} end", context.getTaskId(), Boosting.class.getSimpleName());
}

/**
 * Creates the booster for the configured loss: a ranking booster when the loss type is a
 * ranking loss, otherwise a plain booster. Both cover the row range {@code [0, getM())}.
 */
private static Booster buildBooster(BoostingObjs boostingObjs) {
    if (LossUtils.isRanking(boostingObjs.params.get(LossUtils.LOSS_TYPE))) {
        // The query slice spans [0, queryIdOffset.length - 1).
        return BoosterFactory.createRankingBooster(
            boostingObjs.params.get(BoosterType.BOOSTER_TYPE),
            boostingObjs.rankingLoss,
            boostingObjs.data.getQueryIdOffset(),
            boostingObjs.data.getWeights(),
            new Slice(0, boostingObjs.data.getQueryIdOffset().length - 1),
            new Slice(0, boostingObjs.data.getM()));
    }
    return BoosterFactory.createBooster(
        boostingObjs.params.get(BoosterType.BOOSTER_TYPE),
        boostingObjs.loss,
        boostingObjs.data.getWeights(),
        new Slice(0, boostingObjs.data.getM()));
}
Use of com.alibaba.alink.operator.common.tree.parallelcart.data.Slice in project Alink by alibaba.
In the class CalcFeatureGain, method calc.
/**
 * Evaluates, for this task's share of (node, feature) combinations, the best split per queued
 * node and stores the winners in the context array {@code "best"}.
 *
 * <p>On step 1 the {@code "best"} array and one feature splitter per feature are created.
 * On every call the queued node pairs are walked in order; each pair contributes its
 * {@code small} node and, when present, its {@code big} node. The number of result slots
 * visited is stored under {@code "bestLength"}.
 *
 * @param context the computation context providing {@code "boostingObjs"}, {@code "tree"} and
 *                {@code "histogram"}
 */
@Override
public void calc(ComContext context) {
    LOG.info("taskId: {}, {} start", context.getTaskId(), CalcFeatureGain.class.getSimpleName());

    BoostingObjs boostingObjs = context.getObj("boostingObjs");
    HistogramBaseTreeObjs tree = context.getObj("tree");
    double[] histogram = context.getObj("histogram");

    // One-time setup: result buffer plus one splitter per feature.
    if (context.getStepNo() == 1) {
        context.putObj("best", new Node[tree.maxNodeSize]);
        final int numFeatures = boostingObjs.data.getN();
        featureSplitters = new HistogramFeatureSplitter[numFeatures];
        for (int f = 0; f < numFeatures; ++f) {
            final FeatureMeta meta = boostingObjs.data.getFeatureMetas()[f];
            featureSplitters[f] = createFeatureSplitter(meta.getType() == FeatureMeta.FeatureType.CATEGORICAL, boostingObjs.params, meta, tree.compareIndex4Categorical);
        }
    }

    // Total work units: numBaggingFeatures per node, where a pair holds one or two nodes.
    int totalWork = 0;
    for (NodeInfoPair pair : tree.queue) {
        totalWork += (pair.big == null ? 1 : 2) * boostingObjs.numBaggingFeatures;
    }

    // This task handles the work units in [begin, end).
    DistributedInfo partitioner = new DefaultDistributedInfo();
    final int begin = (int) partitioner.startPos(context.getTaskId(), context.getNumTask(), totalWork);
    final int end = begin + (int) partitioner.localRowCnt(context.getTaskId(), context.getNumTask(), totalWork);

    Node[] best = context.getObj("best");
    int workIdx = 0;    // running index over all (node, feature) work units
    int binOffset = 0;  // read position inside histogram
    int slot = 0;       // position in best[], one per visited node

    for (NodeInfoPair pair : tree.queue) {
        final NodeInfoPair.NodeInfo[] sides = pair.big == null
            ? new NodeInfoPair.NodeInfo[] {pair.small}
            : new NodeInfoPair.NodeInfo[] {pair.small, pair.big};

        for (NodeInfoPair.NodeInfo side : sides) {
            best[slot] = null;
            for (int feature : side.baggingFeatures) {
                if (begin <= workIdx && workIdx < end) {
                    final HistogramFeatureSplitter splitter = featureSplitters[feature];
                    final int binCount = DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[feature], tree.useMissing);

                    splitter.reset(histogram, new Slice(binOffset, binOffset + binCount), side.depth);
                    final double gain = splitter.bestSplit(tree.leaves.size());

                    // The first owned feature always seeds the slot; later ones must be
                    // splittable and strictly better.
                    if (best[slot] == null || (splitter.canSplit() && gain > best[slot].getGain())) {
                        best[slot] = new Node();
                        splitter.fillNode(best[slot]);
                    }

                    // The offset advances only for features this task owns.
                    // NOTE(review): suggests histogram holds just this task's bins — confirm.
                    binOffset += binCount;
                }
                workIdx++;
            }
            slot++;
        }
    }

    context.putObj("bestLength", slot);
    LOG.info("taskId: {}, {} end", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
}
Use of com.alibaba.alink.operator.common.tree.parallelcart.data.Slice in project Alink by alibaba.
In the class SplitInstances, method split.
/**
 * Splits the instances of a node into two children and returns them as a pair.
 *
 * <p>Both of the node's ranges ({@code slice} and {@code oob}) are partitioned through
 * {@code data.splitInstances}; the returned positions become the boundary between the left
 * and right child ranges. The node receives two new child nodes, each one level deeper.
 *
 * @param nodeInfo the node to split, with its two ranges, depth and bagging features
 * @param summary  passed through to {@code data.splitInstances}
 * @param indices  passed through to {@code data.splitInstances}
 * @param data     performs the actual partitioning
 * @return a pair of (left, right) child infos; the left child keeps the parent's
 *         {@code baggingFeatures}, the right child gets {@code null}
 */
public static NodeInfoPair split(NodeInfoPair.NodeInfo nodeInfo, EpsilonApproQuantile.WQSummary summary, int[] indices, Data data) {
    final Node parent = nodeInfo.node;

    // Partition both ranges; each call yields the boundary position for that range.
    final int mid = data.splitInstances(parent, summary, indices, nodeInfo.slice);
    final int oobMid = data.splitInstances(parent, summary, indices, nodeInfo.oob);

    parent.setNextNodes(new Node[] {new Node(), new Node()});
    final Node[] children = parent.getNextNodes();
    final int childDepth = nodeInfo.depth + 1;

    final NodeInfoPair.NodeInfo left = new NodeInfoPair.NodeInfo(
        children[0],
        new Slice(nodeInfo.slice.start, mid),
        new Slice(nodeInfo.oob.start, oobMid),
        childDepth,
        nodeInfo.baggingFeatures);

    final NodeInfoPair.NodeInfo right = new NodeInfoPair.NodeInfo(
        children[1],
        new Slice(mid, nodeInfo.slice.end),
        new Slice(oobMid, nodeInfo.oob.end),
        childDepth,
        null);

    return new NodeInfoPair(left, right);
}
Use of com.alibaba.alink.operator.common.tree.parallelcart.data.Slice in project Alink by alibaba.
In the class BoostingTest, method testGradient.
/**
 * Runs one boosting pass of a {@code GradientBaseBooster} with {@code LogLoss(2.0, 1.0)} over
 * three unit-weight rows and checks the resulting gradients, their squares, the absence of
 * hessians, and that the weights still read as all ones.
 */
@Test
public void testGradient() {
    final double tolerance = 1e-6;
    final double[] labels = {1, 1, 0};
    final double[] predictions = {0.5, 0.5, 0.1};

    final double[] expectedGradients = {0.3775406687981454, 0.3775406687981454, -0.52497918747894};
    final double[] expectedGradientsSqr = {0.1425369565965509, 0.1425369565965509, 0.27560314728604807};
    // Kept separate from the array handed to the booster so the check is not self-referential.
    final double[] expectedWeights = {1.0, 1.0, 1.0};

    GradientBaseBooster booster = new GradientBaseBooster(new LogLoss(2.0, 1.0), new double[] {1.0, 1.0, 1.0}, new Slice(0, 3));
    booster.boosting(null, labels, predictions);

    Assert.assertArrayEquals(expectedGradients, booster.getGradients(), tolerance);
    Assert.assertArrayEquals(expectedGradientsSqr, booster.getGradientsSqr(), tolerance);
    Assert.assertNull(booster.getHessions());
    Assert.assertArrayEquals(expectedWeights, booster.getWeights(), tolerance);
}
Aggregations