Usage of com.alibaba.alink.operator.common.tree.Node in project Alink by Alibaba.
In class IForest, method split:
/**
 * Splits the partition held by {@code item} in two and enqueues a fresh child node for each half.
 *
 * <p>{@code splitAndSwapData} is called with the node's feature index and continuous split value; the
 * index it returns becomes the boundary between the left slice {@code [start, mid)} and the right
 * slice {@code [mid, end)}. Both children are attached to {@code item.node} and queued one level deeper.
 *
 * @param item the queued node together with its depth and the data partition it owns
 */
public void split(QueueItem item) {
    SlicedPartition parent = item.partition;
    Node splitNode = item.node;
    int boundary = splitAndSwapData(parent, splitNode.getFeatureIndex(), splitNode.getContinuousSplit());

    SlicedPartition lowerHalf = new SlicedPartition(parent.start, boundary);
    SlicedPartition upperHalf = new SlicedPartition(boundary, parent.end);

    Node[] children = {new Node(), new Node()};
    splitNode.setNextNodes(children);

    int childDepth = item.depth + 1;
    nodeQueue.add(new QueueItem(children[0], childDepth, lowerHalf));
    nodeQueue.add(new QueueItem(children[1], childDepth, upperHalf));
}
Usage of com.alibaba.alink.operator.common.tree.Node in project Alink by Alibaba.
In class HistogramBaseTreeObjs, method initPerTree:
/**
 * Prepares per-tree state before growing a new tree.
 *
 * <p>Creates a root node, pushes it onto the (expected-empty) work queue with two slices,
 * {@code [0, numBaggingInstances)} and {@code [numBaggingInstances, data.getM())}, plus depth
 * argument {@code 1}. It then records the root, clears the leaf list and stores the quantile summaries.
 *
 * @param boostingObjs shared boosting state providing the bagging count and the data
 * @param summaries quantile summaries kept for this tree
 */
void initPerTree(BoostingObjs boostingObjs, EpsilonApproQuantile.WQSummary[] summaries) {
    assert queue.isEmpty();

    Node root = new Node();
    int baggingCount = boostingObjs.numBaggingInstances;
    Slice firstSlice = new Slice(0, baggingCount);
    Slice secondSlice = new Slice(baggingCount, boostingObjs.data.getM());
    NodeInfoPair.NodeInfo rootInfo = new NodeInfoPair.NodeInfo(root, firstSlice, secondSlice, 1, null);
    queue.push(new NodeInfoPair(rootInfo, null));

    roots.add(root);
    leaves.clear();
    this.summaries = summaries;
}
Usage of com.alibaba.alink.operator.common.tree.Node in project Alink by Alibaba.
In class TreeModelViz, method addChildren:
/**
 * Appends a position record for every child of the node at {@code idx} to {@code nodelist}.
 *
 * <p>Each appended record points back to {@code idx} as its parent and sits one level below it. The
 * list indices of the new records are stored in the parent record's {@code childrenIdx}. Leaf nodes
 * are left unchanged.
 *
 * @param nodelist flat list of node position records; grows by the number of children
 * @param idx index in {@code nodelist} of the node whose children should be added
 * @throws IndexOutOfBoundsException if {@code idx} is not a valid index of {@code nodelist}
 */
private static void addChildren(ArrayList<Node4CalcPos> nodelist, int idx) {
    if (idx < 0 || idx >= nodelist.size()) {
        throw new IndexOutOfBoundsException(
            "Node index " + idx + " is out of range [0, " + nodelist.size() + ")");
    }
    Node4CalcPos curTreeNode = nodelist.get(idx);
    Node curNode = curTreeNode.node;
    if (curNode.isLeaf()) {
        return;
    }

    Node[] children = curNode.getNextNodes();
    int level = curTreeNode.level + 1;
    int[] idxChildren = new int[children.length];
    for (int i = 0; i < children.length; i++) {
        Node4CalcPos tnode = new Node4CalcPos();
        tnode.node = children[i];
        tnode.parentIdx = idx;
        tnode.level = level;
        // The child is appended at the current end of the list, so its index is the pre-add size.
        idxChildren[i] = nodelist.size();
        nodelist.add(tnode);
    }
    curTreeNode.childrenIdx = idxChildren;
}
Usage of com.alibaba.alink.operator.common.tree.Node in project Alink by Alibaba.
In class RandomForestModelMapper, method predictResultDetail:
/**
 * Predicts one sample by accumulating the output of every tree in the model.
 *
 * <p>For classification models the prediction is the label whose accumulated entry is the largest
 * strictly positive value. The detail is a JSON map from each label to its entry. For regression
 * models the prediction is the first entry of the accumulated distribution and the detail is
 * {@code null}.
 *
 * @param selection the selected input columns of the sample
 * @return (prediction, detail JSON). Both are {@code null} when the model has no trees. For
 *     classification the prediction is {@code null} (and a warning is logged) when no entry is
 *     strictly positive.
 * @throws Exception propagated from the input transform or per-tree prediction
 */
@Override
protected Tuple2<Object, String> predictResultDetail(SlicedSelectedSample selection) throws Exception {
    Node[] root = treeModel.roots;
    Row inputBuffer = inputBufferThreadLocal.get();
    selection.fillRow(inputBuffer);
    transform(inputBuffer);

    Object result = null;
    Map<String, Double> detail = null;
    if (root.length > 0) {
        // One shared counter, sized like the first root's distribution, accumulates every tree.
        LabelCounter labelCounter = new LabelCounter(
            0, 0, new double[root[0].getCounter().getDistributions().length]);
        for (Node tree : root) {
            predict(inputBuffer, tree, labelCounter, 1.0);
        }
        labelCounter.normWithWeight();

        if (!Criteria.isRegression(treeModel.meta.get(TreeUtil.TREE_TYPE))) {
            detail = new HashMap<>();
            double[] probability = labelCounter.getDistributions();
            double max = 0.0;
            int maxIndex = -1;
            for (int i = 0; i < probability.length; ++i) {
                detail.put(String.valueOf(treeModel.labels[i]), probability[i]);
                if (max < probability[i]) {
                    max = probability[i];
                    maxIndex = i;
                }
            }
            if (maxIndex == -1) {
                // No strictly positive entry (all zero, NaN, or empty). Indexing labels[-1] here would
                // throw ArrayIndexOutOfBoundsException, so leave the prediction null instead.
                LOG.warn("Can not find the probability: {}", JsonConverter.toJson(probability));
            } else {
                result = treeModel.labels[maxIndex];
            }
        } else {
            result = labelCounter.getDistributions()[0];
        }
    }
    return new Tuple2<>(result, detail == null ? null : JsonConverter.toJson(detail));
}
Usage of com.alibaba.alink.operator.common.tree.Node in project Alink by Alibaba.
In class TreeModelEncoderModelMapper, method selectMaxWeightedCriteriaOfChild:
/**
 * Chooses a child index of {@code node}.
 *
 * <p>If the node's missing-split array is non-null and has exactly one entry, that entry is
 * returned. Otherwise the result is the index of the child whose counter has the largest weight
 * sum. Children without a counter are skipped, ties keep the earliest child, and 0 is returned when
 * no child has a positive weight sum.
 *
 * @param node a node with child nodes
 * @return the selected child index
 */
private int selectMaxWeightedCriteriaOfChild(Node node) {
    if (node.getMissingSplit() != null && node.getMissingSplit().length == 1) {
        return node.getMissingSplit()[0];
    }

    Node[] children = node.getNextNodes();
    int bestChild = 0;
    double bestWeight = 0.;
    for (int pos = 0; pos < children.length; pos++) {
        if (children[pos].getCounter() == null) {
            continue;
        }
        double weight = children[pos].getCounter().getWeightSum();
        if (weight > bestWeight) {
            bestWeight = weight;
            bestChild = pos;
        }
    }
    return bestChild;
}
Aggregations