Search in sources:

Example 1 with Node

use of com.alibaba.alink.operator.common.tree.Node in project Alink by alibaba.

The class IForest, method split.

/**
 * Splits the sample slice owned by {@code item} using the node's feature index and
 * continuous split value, attaches two fresh child nodes to the node, and enqueues
 * both children one level deeper than {@code item}.
 *
 * @param item queue entry holding the node to split, its depth and its sample partition
 */
public void split(QueueItem item) {
    Node current = item.node;
    SlicedPartition parent = item.partition;
    // splitAndSwapData reorders the partition's data and returns the boundary between the two sides.
    int mid = splitAndSwapData(parent, current.getFeatureIndex(), current.getContinuousSplit());
    int childDepth = item.depth + 1;

    Node leftChild = new Node();
    Node rightChild = new Node();
    current.setNextNodes(new Node[] { leftChild, rightChild });

    // Left child receives (start, mid); right child receives (mid, end).
    nodeQueue.add(new QueueItem(leftChild, childDepth, new SlicedPartition(parent.start, mid)));
    nodeQueue.add(new QueueItem(rightChild, childDepth, new SlicedPartition(mid, parent.end)));
}
Also used : Node(com.alibaba.alink.operator.common.tree.Node)

Example 2 with Node

use of com.alibaba.alink.operator.common.tree.Node in project Alink by alibaba.

The class HistogramBaseTreeObjs, method initPerTree.

/**
 * Resets the per-tree state before a new tree is grown.
 *
 * <p>Creates a new root node, pushes it onto the (expected-empty) work queue together
 * with its two instance slices, records it in {@code roots}, clears the collected
 * leaves, and stores the quantile summaries for this tree.
 *
 * @param boostingObjs shared boosting state; supplies the bagging instance count and data size
 * @param summaries    quantile summaries to use while growing this tree
 */
void initPerTree(BoostingObjs boostingObjs, EpsilonApproQuantile.WQSummary[] summaries) {
    assert queue.isEmpty();
    final int bagged = boostingObjs.numBaggingInstances;
    final Node root = new Node();
    // First slice covers [0, numBaggingInstances); second covers [numBaggingInstances, data.getM()).
    Slice baggedSlice = new Slice(0, bagged);
    Slice remainingSlice = new Slice(bagged, boostingObjs.data.getM());
    NodeInfoPair.NodeInfo rootInfo =
        new NodeInfoPair.NodeInfo(root, baggedSlice, remainingSlice, 1, null);
    queue.push(new NodeInfoPair(rootInfo, null));
    roots.add(root);
    leaves.clear();
    this.summaries = summaries;
}
Also used : Slice(com.alibaba.alink.operator.common.tree.parallelcart.data.Slice) Node(com.alibaba.alink.operator.common.tree.Node)

Example 3 with Node

use of com.alibaba.alink.operator.common.tree.Node in project Alink by alibaba.

The class TreeModelViz, method addChildren.

/**
 * Appends the direct children of the node at {@code nodelist.get(idx)} to the end of
 * {@code nodelist}, and records their list positions in the parent's {@code childrenIdx}.
 * Each appended entry gets its parent index and a level one greater than the parent's.
 * Leaf nodes are left untouched.
 *
 * @param nodelist flattened list of visualization nodes; grown in place
 * @param idx      index of the parent entry whose children should be appended
 * @throws IndexOutOfBoundsException if {@code idx} is not a valid index into {@code nodelist}
 */
private static void addChildren(ArrayList<Node4CalcPos> nodelist, int idx) {
    // IndexOutOfBoundsException is a RuntimeException, so existing callers are unaffected,
    // but the message now says which index failed.
    if (idx < 0 || idx >= nodelist.size()) {
        throw new IndexOutOfBoundsException(
            "Node index " + idx + " is out of range for node list of size " + nodelist.size());
    }
    Node4CalcPos parent = nodelist.get(idx);
    Node curNode = parent.node;
    if (curNode.isLeaf()) {
        return;
    }
    // Fetch the children once instead of calling getNextNodes() on every iteration.
    Node[] children = curNode.getNextNodes();
    int[] idxChildren = new int[children.length];
    int level = parent.level + 1;
    for (int i = 0; i < children.length; i++) {
        Node4CalcPos tnode = new Node4CalcPos();
        tnode.node = children[i];
        tnode.parentIdx = idx;
        tnode.level = level;
        // The child is about to be appended, so its index is the current list size.
        idxChildren[i] = nodelist.size();
        nodelist.add(tnode);
    }
    parent.childrenIdx = idxChildren;
}
Also used : Node(com.alibaba.alink.operator.common.tree.Node)

Example 4 with Node

use of com.alibaba.alink.operator.common.tree.Node in project Alink by alibaba.

The class RandomForestModelMapper, method predictResultDetail.

/**
 * Predicts a single sample with every tree of the forest and aggregates the results.
 *
 * <p>Each root is evaluated with weight 1.0 into a shared {@code LabelCounter}, which is
 * then normalized with {@code normWithWeight()}.
 * <ul>
 *   <li>For regression trees, the result is the first entry of the aggregated distribution
 *       and no detail is returned.</li>
 *   <li>For classification trees, the result is the label with the largest positive
 *       probability (first wins on ties). The detail is a JSON map from label to probability.
 *       If no probability is greater than 0, a warning is logged and the result is
 *       {@code null}. Previously this case indexed {@code labels[-1]} and threw.</li>
 * </ul>
 * An empty forest yields {@code (null, null)}.
 *
 * @param selection the sample to predict
 * @return the prediction and, for classification, the JSON-encoded per-label probabilities
 * @throws Exception propagated from input transformation or tree prediction
 */
@Override
protected Tuple2<Object, String> predictResultDetail(SlicedSelectedSample selection) throws Exception {
    Node[] roots = treeModel.roots;
    Row inputBuffer = inputBufferThreadLocal.get();
    selection.fillRow(inputBuffer);
    transform(inputBuffer);

    if (roots.length == 0) {
        return new Tuple2<>(null, null);
    }

    // The distribution size is taken from the first tree's counter.
    LabelCounter labelCounter =
        new LabelCounter(0, 0, new double[roots[0].getCounter().getDistributions().length]);
    for (Node root : roots) {
        predict(inputBuffer, root, labelCounter, 1.0);
    }
    labelCounter.normWithWeight();

    if (Criteria.isRegression(treeModel.meta.get(TreeUtil.TREE_TYPE))) {
        Object regressionResult = labelCounter.getDistributions()[0];
        return new Tuple2<>(regressionResult, null);
    }

    Map<String, Double> detail = new HashMap<>();
    double[] probability = labelCounter.getDistributions();
    double max = 0.0;
    int maxIndex = -1;
    for (int i = 0; i < probability.length; ++i) {
        detail.put(String.valueOf(treeModel.labels[i]), probability[i]);
        if (max < probability[i]) {
            max = probability[i];
            maxIndex = i;
        }
    }

    Object result;
    if (maxIndex == -1) {
        // No strictly positive probability (e.g. all zeros or NaN): labels[-1] would throw,
        // so report the problem and return no label.
        LOG.warn("Can not find the probability: {}", JsonConverter.toJson(probability));
        result = null;
    } else {
        result = treeModel.labels[maxIndex];
    }
    return new Tuple2<>(result, JsonConverter.toJson(detail));
}
Also used : Node(com.alibaba.alink.operator.common.tree.Node) Tuple2(org.apache.flink.api.java.tuple.Tuple2) LabelCounter(com.alibaba.alink.operator.common.tree.LabelCounter) Row(org.apache.flink.types.Row)

Example 5 with Node

use of com.alibaba.alink.operator.common.tree.Node in project Alink by alibaba.

The class TreeModelEncoderModelMapper, method selectMaxWeightedCriteriaOfChild.

/**
 * Chooses a child index of {@code node}.
 *
 * <p>If the node's missing-split array exists and has exactly one entry, that entry is
 * returned. Otherwise, the index of the child whose counter has the largest weight sum is
 * returned. Children without a counter are skipped, and the first child wins on ties. If no
 * child has a positive weight sum, the result is 0.
 *
 * <p>NOTE(review): presumably used to route samples with a missing feature value; confirm
 * against callers.
 *
 * @param node a non-leaf node whose children are inspected
 * @return the selected child index
 */
private int selectMaxWeightedCriteriaOfChild(Node node) {
    boolean singleMissingTarget =
        node.getMissingSplit() != null && node.getMissingSplit().length == 1;
    if (singleMissingTarget) {
        return node.getMissingSplit()[0];
    }
    Node[] children = node.getNextNodes();
    int bestChild = 0;
    double bestWeightSum = 0.;
    for (int childIdx = 0; childIdx < children.length; ++childIdx) {
        Node child = children[childIdx];
        if (child.getCounter() == null) {
            continue;
        }
        double weightSum = child.getCounter().getWeightSum();
        // Strict comparison keeps the earliest child on ties.
        if (weightSum > bestWeightSum) {
            bestWeightSum = weightSum;
            bestChild = childIdx;
        }
    }
    return bestChild;
}
Also used : Node(com.alibaba.alink.operator.common.tree.Node)

Aggregations

Node (com.alibaba.alink.operator.common.tree.Node): 14
Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice): 3
Tuple2 (org.apache.flink.api.java.tuple.Tuple2): 2
Row (org.apache.flink.types.Row): 2
DefaultDistributedInfo (com.alibaba.alink.common.io.directreader.DefaultDistributedInfo): 1
DistributedInfo (com.alibaba.alink.common.io.directreader.DistributedInfo): 1
LabelCounter (com.alibaba.alink.operator.common.tree.LabelCounter): 1