Usage of com.alibaba.alink.operator.common.tree.Node in the Alink project (by Alibaba):
the split method of the DecisionTree class.
/**
 * Expands {@code node} into children according to the chosen best split.
 *
 * <p>The best splitter partitions this node's feature splitters into one
 * group per child; each new child node is enqueued together with its group
 * so the main fit loop will process it later.
 *
 * @param all  splitters covering all features at this node
 * @param node the tree node being split
 * @param best the splitter selected as the best split for this node
 */
private void split(SequentialFeatureSplitter[] all, Node node, SequentialFeatureSplitter best) {
	// Partition the splitters among the children produced by the best split.
	SequentialFeatureSplitter[][] partitioned = (SequentialFeatureSplitter[][]) best.split(all);
	final int childCount = partitioned.length;
	Node[] children = new Node[childCount];
	for (int childId = 0; childId < childCount; ++childId) {
		Node child = new Node();
		children[childId] = child;
		// Defer fitting of the child to the breadth-first work queue.
		queue.add(Tuple2.of(child, partitioned[childId]));
	}
	node.setNextNodes(children);
}
Usage of com.alibaba.alink.operator.common.tree.Node in the Alink project (by Alibaba):
the fit method of the DecisionTree class.
/**
 * Grows the decision tree breadth-first and returns its root.
 *
 * <p>Starts from a freshly initialized root, then repeatedly drains the work
 * queue: for each pending node, a best splitter is fitted on a bagged subset
 * of its feature splitters and written into the node. Splittable nodes are
 * expanded via {@link #split}; the rest are finalized as leaves.
 *
 * @return the root node of the fitted tree
 */
public Node fit() {
	Node root = new Node();
	init(root);
	while (!queue.isEmpty()) {
		Tuple2<Node, SequentialFeatureSplitter[]> current = queue.poll();
		Node node = current.f0;
		SequentialFeatureSplitter[] splitters = current.f1;
		// Queue size is passed along — presumably as a progress/capacity hint; confirm in fitNode.
		SequentialFeatureSplitter best = fitNode(bagging(splitters), queue.size());
		best.fillNode(node);
		if (best.canSplit()) {
			split(splitters, node, best);
		} else {
			// No usable split: finalize this node as a leaf.
			node.makeLeaf();
			node.makeLeafProb();
		}
	}
	return root;
}
Usage of com.alibaba.alink.operator.common.tree.Node in the Alink project (by Alibaba):
the fit method of the IForest class.
/**
 * Builds a single isolation tree and returns its root.
 *
 * <p>Processes nodes from {@code nodeQueue} until either the queue is empty
 * or the number of finalized leaves exceeds the configured maximum. Each
 * polled node is fitted; nodes that come out as leaves are counted, all
 * others are split and their children re-enqueued.
 *
 * @return the root node of the fitted isolation tree
 */
public Node fit() {
	Node root = new Node();
	// Seed the queue with the root covering the full sample range [0, nSamples).
	nodeQueue.push(new QueueItem(root, 1, new SlicedPartition(0, nSamples)));
	final int maxLeaves = params.get(HasMaxLeaves.MAX_LEAVES);
	int leafNodeCount = 0;
	// Same order of checks as before: leaf budget first, then queue emptiness.
	while (leafNodeCount <= maxLeaves && !nodeQueue.isEmpty()) {
		QueueItem current = nodeQueue.poll();
		// leafNodeCount + pending nodes approximates the current leaf total.
		fitNode(current, leafNodeCount + nodeQueue.size());
		if (current.node.isLeaf()) {
			++leafNodeCount;
		} else {
			split(current);
		}
	}
	return root;
}
Usage of com.alibaba.alink.operator.common.tree.Node in the Alink project (by Alibaba):
the calc method of the CalcFeatureGain class.
// Distributed best-split search: each task owns a contiguous range
// [start, end) of the bagged features across all queued node pairs, scans
// only the histogram slices of the features it owns, and records the
// locally best split per queued node slot into the shared "best" array.
// NOTE(review): featureCnt/featureBinCnt advance in lockstep over the same
// feature sequence on every task; only features inside [start, end) do work.
@Override
public void calc(ComContext context) {
LOG.info("taskId: {}, {} start", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
BoostingObjs boostingObjs = context.getObj("boostingObjs");
HistogramBaseTreeObjs tree = context.getObj("tree");
double[] histogram = context.getObj("histogram");
// First step: allocate the per-node result slots and one splitter per feature.
if (context.getStepNo() == 1) {
context.putObj("best", new Node[tree.maxNodeSize]);
featureSplitters = new HistogramFeatureSplitter[boostingObjs.data.getN()];
for (int i = 0; i < boostingObjs.data.getN(); ++i) {
featureSplitters[i] = createFeatureSplitter(boostingObjs.data.getFeatureMetas()[i].getType() == FeatureMeta.FeatureType.CATEGORICAL, boostingObjs.params, boostingObjs.data.getFeatureMetas()[i], tree.compareIndex4Categorical);
}
}
// Total number of bagged features over all queued items; a pair contributes
// twice when its "big" sibling is present.
int sumFeatureCount = 0;
for (NodeInfoPair item : tree.queue) {
sumFeatureCount += boostingObjs.numBaggingFeatures;
if (item.big != null) {
sumFeatureCount += boostingObjs.numBaggingFeatures;
}
}
// Partition the flat feature sequence evenly across tasks; this task owns [start, end).
DistributedInfo distributedInfo = new DefaultDistributedInfo();
int start = (int) distributedInfo.startPos(context.getTaskId(), context.getNumTask(), sumFeatureCount);
int cnt = (int) distributedInfo.localRowCnt(context.getTaskId(), context.getNumTask(), sumFeatureCount);
int end = start + cnt;
int featureCnt = 0;     // position in the global flat feature sequence
int featureBinCnt = 0;  // offset of the current feature's slice in the histogram
Node[] best = context.getObj("best");
int index = 0;          // slot in "best": one per small half, one per big half
for (NodeInfoPair item : tree.queue) {
best[index] = null;
final int[] smallBaggingFeatures = item.small.baggingFeatures;
for (int smallBaggingFeature : smallBaggingFeatures) {
// Only evaluate features this task owns; counters still advance below.
if (featureCnt >= start && featureCnt < end) {
featureSplitters[smallBaggingFeature].reset(histogram, new Slice(featureBinCnt, featureBinCnt + DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing)), item.small.depth);
double gain = featureSplitters[smallBaggingFeature].bestSplit(tree.leaves.size());
// Keep the splitter with the highest gain; the first owned feature always
// seeds the slot, even when its split is not usable (canSplit() false).
if (best[index] == null || (featureSplitters[smallBaggingFeature].canSplit() && gain > best[index].getGain())) {
best[index] = new Node();
featureSplitters[smallBaggingFeature].fillNode(best[index]);
}
// NOTE(review): featureBinCnt only advances for owned features — the
// histogram buffer appears to hold just this task's slices; confirm
// against the histogram-building step.
featureBinCnt += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing);
}
featureCnt++;
}
index++;
// The "big" sibling (when present) gets its own result slot, mirroring the
// small-half loop above.
if (item.big != null) {
best[index] = null;
final int[] bigBaggingFeatures = item.big.baggingFeatures;
for (int bigBaggingFeature : bigBaggingFeatures) {
if (featureCnt >= start && featureCnt < end) {
featureSplitters[bigBaggingFeature].reset(histogram, new Slice(featureBinCnt, featureBinCnt + DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing)), item.big.depth);
double gain = featureSplitters[bigBaggingFeature].bestSplit(tree.leaves.size());
if (best[index] == null || (featureSplitters[bigBaggingFeature].canSplit() && gain > best[index].getGain())) {
best[index] = new Node();
featureSplitters[bigBaggingFeature].fillNode(best[index]);
}
featureBinCnt += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing);
}
featureCnt++;
}
index++;
}
}
// Number of slots actually written this step; consumers read best[0..bestLength).
context.putObj("bestLength", index);
LOG.info("taskId: {}, {} end", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
}
Usage of com.alibaba.alink.operator.common.tree.Node in the Alink project (by Alibaba):
the bestSplit method of the TreeObj class.
// Finds and applies the best split for every buffered node pair.
// For each pair the left ("small") child's histogram is taken directly from
// `hist`; the right ("big") sibling's histogram is derived by subtracting the
// small child's statistics from the cached parent histogram, avoiding a second
// pass over the data. `minusId` indexes slots in `minusHist`; how it advances
// depends on useStatPair() and on whether a big sibling exists — the slot
// layout must stay in sync with bestSplit(Node, int, ...) and split(...).
public final void bestSplit() throws Exception {
double[] stat = hist;
// Pair mode: copy the whole histogram up front instead of per-node below.
if (useStatPair()) {
System.arraycopy(stat, 0, minusHist, 0, histLen());
}
int lenPerNode = lenStatUnit();
int minusId = 0;
for (int i = 0; i < loopBufferSize; ++i) {
int statStart = i * lenPerNode;
// Non-pair mode: stage this node's stats into its minusHist slot.
if (!useStatPair()) {
System.arraycopy(stat, statStart, minusHist, minusId * lenPerNode, lenPerNode);
}
Node left = ofNode(loopBuffer[i], true);
bestSplit(left, minusId, loopBuffer[i]);
split(left, loopBuffer[i], true, minusId);
replaceWithActual(left);
if (loopBuffer[i].big != null) {
if (!useStatPair()) {
// Derive the sibling's histogram: parent minus the small child,
// written into the next minusHist slot.
double[] parent = parentHistPool.get(loopBuffer[i].parentQueueId);
int bigStart = (minusId + 1) * lenPerNode;
for (int j = 0; j < lenPerNode; ++j) {
minusHist[bigStart + j] = parent[j] - stat[statStart + j];
}
// Parent histogram is no longer needed once both children are derived.
parentHistPool.release(loopBuffer[i].parentQueueId);
}
minusId += 1;
Node right = ofNode(loopBuffer[i], false);
bestSplit(right, minusId, loopBuffer[i]);
split(right, loopBuffer[i], false, minusId);
replaceWithActual(right);
} else if (useStatPair()) {
// Pair mode reserves two slots per buffered node even without a sibling.
minusId += 1;
}
minusId += 1;
}
}
Aggregations