Search in sources:

Example 6 with Node

use of com.alibaba.alink.operator.common.tree.Node in project Alink by alibaba.

Class DecisionTree, method split.

/**
 * Splits {@code node} according to {@code best}: creates one empty child per partition returned
 * by {@code best.split(all)} and enqueues each child together with its partition of splitters.
 *
 * @param all  the feature splitters associated with {@code node}
 * @param node the node being split; its next nodes are set to the newly created children
 * @param best the splitter that selected the split
 */
private void split(SequentialFeatureSplitter[] all, Node node, SequentialFeatureSplitter best) {
    // One group of splitters per child produced by the chosen split.
    SequentialFeatureSplitter[][] perChild = (SequentialFeatureSplitter[][]) best.split(all);
    int childCount = perChild.length;
    Node[] children = new Node[childCount];
    for (int c = 0; c < childCount; c++) {
        Node child = new Node();
        children[c] = child;
        // Queue the child with its own splitters so it is fitted in a later iteration.
        queue.add(Tuple2.of(child, perChild[c]));
    }
    node.setNextNodes(children);
}
Also used : Node(com.alibaba.alink.operator.common.tree.Node)

Example 7 with Node

use of com.alibaba.alink.operator.common.tree.Node in project Alink by alibaba.

Class DecisionTree, method fit.

/**
 * Grows the tree breadth-wise from a fresh root by repeatedly taking a pending node from the
 * queue, fitting it, and either turning it into a leaf or splitting it into queued children.
 *
 * @return the root of the fitted tree
 */
public Node fit() {
    Node root = new Node();
    init(root);
    while (!queue.isEmpty()) {
        Tuple2<Node, SequentialFeatureSplitter[]> pending = queue.poll();
        Node current = pending.f0;
        SequentialFeatureSplitter[] splitters = pending.f1;
        // bagging(...) is evaluated before queue.size(), matching left-to-right argument order.
        SequentialFeatureSplitter best = fitNode(bagging(splitters), queue.size());
        best.fillNode(current);
        if (best.canSplit()) {
            split(splitters, current, best);
        } else {
            current.makeLeaf();
            current.makeLeafProb();
        }
    }
    return root;
}
Also used : Node(com.alibaba.alink.operator.common.tree.Node)

Example 8 with Node

use of com.alibaba.alink.operator.common.tree.Node in project Alink by alibaba.

Class IForest, method fit.

/**
 * Builds one isolation tree over all {@code nSamples} samples. Nodes are taken from
 * {@code nodeQueue} and fitted; leaves are counted and non-leaves are split, until the queue is
 * exhausted or the leaf budget is passed.
 *
 * @return the root of the fitted tree
 */
public Node fit() {
    Node root = new Node();
    nodeQueue.push(new QueueItem(root, 1, new SlicedPartition(0, nSamples)));
    int leaves = 0;
    // NOTE(review): "<=" lets the loop continue once leaves == MAX_LEAVES; confirm whether the
    // limit is meant to be inclusive.
    while (leaves <= params.get(HasMaxLeaves.MAX_LEAVES) && !nodeQueue.isEmpty()) {
        QueueItem current = nodeQueue.poll();
        fitNode(current, leaves + nodeQueue.size());
        if (current.node.isLeaf()) {
            leaves++;
        } else {
            split(current);
        }
    }
    return root;
}
Also used : Node(com.alibaba.alink.operator.common.tree.Node)

Example 9 with Node

use of com.alibaba.alink.operator.common.tree.Node in project Alink by alibaba.

Class CalcFeatureGain, method calc.

/**
 * Evaluates candidate splits for the nodes currently in {@code tree.queue}. For each node, it
 * stores the best split found among the features assigned to this task.
 *
 * <p>On step 1 this also allocates the {@code "best"} array (size {@code tree.maxNodeSize}) and
 * creates one {@code HistogramFeatureSplitter} per feature. Results are written to the context
 * objects {@code "best"} (one {@link Node} per processed node; an entry stays {@code null} when
 * none of that node's features fall in this task's range) and {@code "bestLength"} (the number
 * of entries written).
 *
 * @param context the iteration context holding {@code "boostingObjs"}, {@code "tree"} and
 *     {@code "histogram"}
 */
@Override
public void calc(ComContext context) {
    LOG.info("taskId: {}, {} start", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
    BoostingObjs boostingObjs = context.getObj("boostingObjs");
    HistogramBaseTreeObjs tree = context.getObj("tree");
    double[] histogram = context.getObj("histogram");
    // First step only: allocate the per-node result buffer and build one splitter per feature.
    if (context.getStepNo() == 1) {
        context.putObj("best", new Node[tree.maxNodeSize]);
        featureSplitters = new HistogramFeatureSplitter[boostingObjs.data.getN()];
        for (int i = 0; i < boostingObjs.data.getN(); ++i) {
            featureSplitters[i] = createFeatureSplitter(boostingObjs.data.getFeatureMetas()[i].getType() == FeatureMeta.FeatureType.CATEGORICAL, boostingObjs.params, boostingObjs.data.getFeatureMetas()[i], tree.compareIndex4Categorical);
        }
    }
    // Total number of (node, bagging feature) pairs across the queue. The "big" sibling of a
    // pair, when present, contributes its own set of features.
    // NOTE(review): assumes every node's baggingFeatures array has length numBaggingFeatures;
    // otherwise this total disagrees with the loops below — confirm.
    int sumFeatureCount = 0;
    for (NodeInfoPair item : tree.queue) {
        sumFeatureCount += boostingObjs.numBaggingFeatures;
        if (item.big != null) {
            sumFeatureCount += boostingObjs.numBaggingFeatures;
        }
    }
    // This task handles the contiguous slice [start, end) of the global (node, feature) sequence.
    DistributedInfo distributedInfo = new DefaultDistributedInfo();
    int start = (int) distributedInfo.startPos(context.getTaskId(), context.getNumTask(), sumFeatureCount);
    int cnt = (int) distributedInfo.localRowCnt(context.getTaskId(), context.getNumTask(), sumFeatureCount);
    int end = start + cnt;
    // featureCnt: position in the global (node, feature) sequence.
    // featureBinCnt: offset into the histogram. It is only advanced for features this task
    // handles, presumably because the histogram holds just this task's bins back to back — verify.
    int featureCnt = 0;
    int featureBinCnt = 0;
    Node[] best = context.getObj("best");
    // index: next slot in "best"; one slot per small node and one per non-null big node.
    int index = 0;
    for (NodeInfoPair item : tree.queue) {
        best[index] = null;
        final int[] smallBaggingFeatures = item.small.baggingFeatures;
        for (int smallBaggingFeature : smallBaggingFeatures) {
            if (featureCnt >= start && featureCnt < end) {
                featureSplitters[smallBaggingFeature].reset(histogram, new Slice(featureBinCnt, featureBinCnt + DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing)), item.small.depth);
                double gain = featureSplitters[smallBaggingFeature].bestSplit(tree.leaves.size());
                // The first evaluated feature is kept unconditionally. After that, it is replaced
                // only by a splittable feature with a strictly larger gain.
                if (best[index] == null || (featureSplitters[smallBaggingFeature].canSplit() && gain > best[index].getGain())) {
                    best[index] = new Node();
                    featureSplitters[smallBaggingFeature].fillNode(best[index]);
                }
                featureBinCnt += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[smallBaggingFeature], tree.useMissing);
            }
            featureCnt++;
        }
        index++;
        // Same evaluation for the "big" sibling, when the pair has one.
        if (item.big != null) {
            best[index] = null;
            final int[] bigBaggingFeatures = item.big.baggingFeatures;
            for (int bigBaggingFeature : bigBaggingFeatures) {
                if (featureCnt >= start && featureCnt < end) {
                    featureSplitters[bigBaggingFeature].reset(histogram, new Slice(featureBinCnt, featureBinCnt + DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing)), item.big.depth);
                    double gain = featureSplitters[bigBaggingFeature].bestSplit(tree.leaves.size());
                    if (best[index] == null || (featureSplitters[bigBaggingFeature].canSplit() && gain > best[index].getGain())) {
                        best[index] = new Node();
                        featureSplitters[bigBaggingFeature].fillNode(best[index]);
                    }
                    featureBinCnt += DataUtil.getFeatureCategoricalSize(boostingObjs.data.getFeatureMetas()[bigBaggingFeature], tree.useMissing);
                }
                featureCnt++;
            }
            index++;
        }
    }
    // Number of entries of "best" written during this step.
    context.putObj("bestLength", index);
    LOG.info("taskId: {}, {} end", context.getTaskId(), CalcFeatureGain.class.getSimpleName());
}
Also used: DefaultDistributedInfo (com.alibaba.alink.common.io.directreader.DefaultDistributedInfo), Node (com.alibaba.alink.operator.common.tree.Node), DistributedInfo (com.alibaba.alink.common.io.directreader.DistributedInfo), Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice)

Example 10 with Node

use of com.alibaba.alink.operator.common.tree.Node in project Alink by alibaba.

Class TreeObj, method bestSplit.

/**
 * Finds and applies the best split for every entry in {@code loopBuffer}
 * (the first {@code loopBufferSize} entries).
 *
 * <p>For each entry, the left node ({@code ofNode(..., true)}) is always processed. The right
 * node ({@code ofNode(..., false)}) is processed only when {@code loopBuffer[i].big} is
 * non-null. When {@code useStatPair()} is true, the first {@code histLen()} values of
 * {@code hist} are copied into {@code minusHist} up front. Otherwise, each left node's
 * statistics are copied into its slot, and the right node's statistics are computed as
 * parent minus left.
 *
 * @throws Exception as declared; presumably propagated from the per-node helper calls
 */
public final void bestSplit() throws Exception {
    double[] stat = hist;
    if (useStatPair()) {
        System.arraycopy(stat, 0, minusHist, 0, histLen());
    }
    // Length of one node's statistics block in both stat and minusHist.
    int lenPerNode = lenStatUnit();
    // Slot (in units of lenPerNode) of minusHist for the node currently being split. It advances
    // by two per entry, except for a non-stat-pair entry without a big sibling, where it
    // advances by one.
    int minusId = 0;
    for (int i = 0; i < loopBufferSize; ++i) {
        int statStart = i * lenPerNode;
        if (!useStatPair()) {
            System.arraycopy(stat, statStart, minusHist, minusId * lenPerNode, lenPerNode);
        }
        Node left = ofNode(loopBuffer[i], true);
        bestSplit(left, minusId, loopBuffer[i]);
        split(left, loopBuffer[i], true, minusId);
        replaceWithActual(left);
        if (loopBuffer[i].big != null) {
            if (!useStatPair()) {
                // Right statistics = parent - left, written into the next slot. The parent
                // buffer is then released back to parentHistPool.
                double[] parent = parentHistPool.get(loopBuffer[i].parentQueueId);
                int bigStart = (minusId + 1) * lenPerNode;
                for (int j = 0; j < lenPerNode; ++j) {
                    minusHist[bigStart + j] = parent[j] - stat[statStart + j];
                }
                parentHistPool.release(loopBuffer[i].parentQueueId);
            }
            minusId += 1;
            Node right = ofNode(loopBuffer[i], false);
            bestSplit(right, minusId, loopBuffer[i]);
            split(right, loopBuffer[i], false, minusId);
            replaceWithActual(right);
        } else if (useStatPair()) {
            // In stat-pair mode the right slot is skipped even when there is no big sibling.
            minusId += 1;
        }
        minusId += 1;
    }
}
Also used : Node(com.alibaba.alink.operator.common.tree.Node)

Aggregations

Node (com.alibaba.alink.operator.common.tree.Node): 14, Slice (com.alibaba.alink.operator.common.tree.parallelcart.data.Slice): 3, Tuple2 (org.apache.flink.api.java.tuple.Tuple2): 2, Row (org.apache.flink.types.Row): 2, DefaultDistributedInfo (com.alibaba.alink.common.io.directreader.DefaultDistributedInfo): 1, DistributedInfo (com.alibaba.alink.common.io.directreader.DistributedInfo): 1, LabelCounter (com.alibaba.alink.operator.common.tree.LabelCounter): 1