Search in sources :

Example 1 with LabelCounter

use of com.alibaba.alink.operator.common.tree.LabelCounter in project Alink by alibaba.

the class IForest method fitNode.

/**
 * Attempts to give the queued node a random continuous split.
 *
 * <p>The node's counter is always set from the partition size. The node is
 * turned into a leaf when the leaf budget ({@code leafNodeCount + 2 >= maxLeaves}),
 * the depth limit, or the minimum partition size is hit. Otherwise the features
 * in {@code shuffleBuffer} are visited in a freshly shuffled order and the first
 * one that yields an acceptable split is recorded on the node.
 *
 * <p>If no feature yields an acceptable split, the method returns without
 * calling {@code makeLeaf()} and without setting a split on the node.
 *
 * @param item          queue entry holding the node, its depth and its partition bounds
 * @param leafNodeCount number of leaves counted so far by the caller
 */
public void fitNode(QueueItem item, int leafNodeCount) {
    final int begin = item.partition.start;
    final int end = item.partition.end;
    final int totalSamples = end - begin;
    item.node.setCounter(new LabelCounter(totalSamples, totalSamples, null));

    final boolean stopGrowing = leafNodeCount + 2 >= maxLeaves
        || item.depth >= maxDepth
        || totalSamples <= minSamplesPerLeaf;
    if (stopGrowing) {
        item.node.makeLeaf();
        return;
    }

    Collections.shuffle(shuffleBuffer);
    for (Integer candidate : shuffleBuffer) {
        final double[] column = data[candidate];

        // Value range of this feature over the samples of the partition.
        double lowest = Double.MAX_VALUE;
        double highest = -Double.MAX_VALUE;
        for (int pos = begin; pos < end; pos++) {
            final double current = column[sampleIndices[pos]];
            if (current < lowest) {
                lowest = current;
            }
            if (current > highest) {
                highest = current;
            }
        }
        if (lowest == highest) {
            // Constant feature inside this partition: nothing to split on.
            continue;
        }

        // Draw a fraction from (0, 1); an exact 0.0 is redrawn.
        double fraction = Math.random();
        while (fraction == 0.0) {
            fraction = Math.random();
        }
        final double threshold = lowest + (highest - lowest) * fraction;
        if (threshold == lowest || threshold == highest) {
            // The threshold landed exactly on a range boundary; skip this feature.
            continue;
        }

        int leftCount = 0;
        for (int pos = begin; pos < end; pos++) {
            if (column[sampleIndices[pos]] <= threshold) {
                leftCount++;
            }
        }
        final int rightCount = totalSamples - leftCount;

        if (leftCount < minSamplesPerLeaf || rightCount < minSamplesPerLeaf) {
            continue;
        }
        final boolean leftShareTooLow = ((double) leftCount / (double) totalSamples) < minSampleRatioPerChild;
        final boolean rightShareTooLow = ((double) rightCount / (double) totalSamples) < minSampleRatioPerChild;
        if (leftShareTooLow || rightShareTooLow) {
            continue;
        }

        item.node.setFeatureIndex(candidate);
        item.node.setContinuousSplit(threshold);
        return;
    }
}
Also used : LabelCounter(com.alibaba.alink.operator.common.tree.LabelCounter)

Example 2 with LabelCounter

use of com.alibaba.alink.operator.common.tree.LabelCounter in project Alink by alibaba.

the class RandomForestModelMapper method predictResultDetail.

/**
 * Predicts a single sample with every tree root of the model and returns the
 * prediction together with an optional JSON detail string.
 *
 * <p>The sample is copied into the thread-local input row, passed through
 * {@code transform}, and then every root adds its contribution to one shared
 * {@link LabelCounter}, which is normalised with {@code normWithWeight()}.
 * For non-regression models the detail maps each label to its probability and
 * the result is the label with the largest strictly positive probability.
 * For regression models the result is the first entry of the distribution and
 * there is no detail.
 *
 * <p>If no probability is strictly positive, a warning is logged and the
 * result is left {@code null}. Previously the code went on to index
 * {@code treeModel.labels} with -1 and failed with an
 * {@link ArrayIndexOutOfBoundsException} right after logging the warning.
 *
 * @param selection the sample to predict
 * @return tuple of (prediction, JSON-encoded label probabilities or {@code null});
 *         both are {@code null} when the model has no roots
 * @throws Exception propagated from the helpers this method calls
 */
@Override
protected Tuple2<Object, String> predictResultDetail(SlicedSelectedSample selection) throws Exception {
    Node[] root = treeModel.roots;
    Row inputBuffer = inputBufferThreadLocal.get();
    selection.fillRow(inputBuffer);
    transform(inputBuffer);
    int len = root.length;
    Object result = null;
    Map<String, Double> detail = null;
    if (len > 0) {
        // The accumulator is sized to the first root's distribution.
        LabelCounter labelCounter = new LabelCounter(0, 0, new double[root[0].getCounter().getDistributions().length]);
        for (int i = 0; i < len; ++i) {
            predict(inputBuffer, root[i], labelCounter, 1.0);
        }
        labelCounter.normWithWeight();
        if (!Criteria.isRegression(treeModel.meta.get(TreeUtil.TREE_TYPE))) {
            detail = new HashMap<>();
            double[] probability = labelCounter.getDistributions();
            double max = 0.0;
            int maxIndex = -1;
            for (int i = 0; i < probability.length; ++i) {
                detail.put(String.valueOf(treeModel.labels[i]), probability[i]);
                if (max < probability[i]) {
                    max = probability[i];
                    maxIndex = i;
                }
            }
            if (maxIndex == -1) {
                // No strictly positive probability: report it and keep the result null
                // instead of indexing labels with -1.
                LOG.warn("Can not find the probability: {}", JsonConverter.toJson(probability));
            } else {
                result = treeModel.labels[maxIndex];
            }
        } else {
            result = labelCounter.getDistributions()[0];
        }
    }
    return new Tuple2<>(result, detail == null ? null : JsonConverter.toJson(detail));
}
Also used : Node(com.alibaba.alink.operator.common.tree.Node) Tuple2(org.apache.flink.api.java.tuple.Tuple2) LabelCounter(com.alibaba.alink.operator.common.tree.LabelCounter) Row(org.apache.flink.types.Row)

Example 3 with LabelCounter

use of com.alibaba.alink.operator.common.tree.LabelCounter in project Alink by alibaba.

the class RegObj method bestSplit.

/**
 * Chooses the best split for {@code node} from the histogram slice identified
 * by {@code minusId}, or turns the node into a leaf.
 *
 * <p>The histogram stores three values per bin, read here as sum, square sum
 * and weight. The totals over the bins of the first feature slot become the
 * node's counter. The node becomes a leaf when {@code maxDepth < pair.depth},
 * when {@code minSamplesPerLeaf > weight}, or when no candidate feature
 * produces a gain greater than zero.
 *
 * <p>Candidate features are {@code pair.baggingFeatures} when
 * {@code baggingFeatureCount()} differs from {@code nFeatureCol}, and all
 * feature columns otherwise. In both cases the j-th candidate reads the j-th
 * feature slot of the histogram slice.
 *
 * @param node    node to receive the counter and either a split or the leaf mark
 * @param minusId index of the histogram slice, scaled by {@code lenPerStat()}
 * @param pair    supplies the node depth and the bagged feature indices
 */
@Override
public final void bestSplit(Node node, int minusId, NodeInfoPair pair) {
    final int base = minusId * lenPerStat();

    // Accumulate the totals from the bins of the first feature slot.
    double sum = 0.;
    double squareSum = 0.;
    double weight = 0.;
    for (int bin = 0, offset = base; bin < nBin; ++bin, offset += 3) {
        sum += minusHist[offset];
        squareSum += minusHist[offset + 1];
        weight += minusHist[offset + 2];
    }
    node.setCounter(new LabelCounter(weight, 0, new double[] { sum, squareSum }));

    if (maxDepth < pair.depth || minSamplesPerLeaf > weight) {
        node.makeLeaf();
        return;
    }

    final double mean = sum / weight;
    final double mse = squareSum / weight - mean * mean;

    double bestGain = 0.;
    int bestBin = 0;
    int[] bestCategories = null;
    int bestFeature = -1;

    // With bagging, slot j maps to pair.baggingFeatures[j]; otherwise slot j is feature j.
    final boolean bagged = baggingFeatureCount() != nFeatureCol;
    final int candidateCount = bagged ? baggingFeatureCount() : nFeatureCol;
    for (int slot = 0; slot < candidateCount; ++slot) {
        final int slotStart = base + slot * nBin * 3;
        final int feature = bagged ? pair.baggingFeatures[slot] : slot;
        if (featureMetas[feature].getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
            Tuple2<Integer, Double> numeric = bestSplitNumerical(slotStart, mse, squareSum, sum, weight);
            if (numeric.f1 > bestGain) {
                bestGain = numeric.f1;
                bestBin = numeric.f0;
                bestFeature = feature;
            }
        } else {
            Tuple2<int[], Double> categorical = bestSplitCategorical(slotStart, mse, squareSum, sum, weight, featureMetas[feature].getNumCategorical());
            if (categorical.f1 > bestGain) {
                bestGain = categorical.f1;
                bestCategories = categorical.f0;
                bestFeature = feature;
            }
        }
    }

    if (!(bestGain > 0.)) {
        // No candidate improved on a zero gain.
        node.makeLeaf();
        return;
    }

    node.setFeatureIndex(bestFeature);
    if (featureMetas[bestFeature].getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
        node.setContinuousSplit(bestBin);
    } else {
        node.setCategoricalSplit(bestCategories);
    }
    node.setGain(bestGain);
}
Also used : LabelCounter(com.alibaba.alink.operator.common.tree.LabelCounter)

Example 4 with LabelCounter

use of com.alibaba.alink.operator.common.tree.LabelCounter in project Alink by alibaba.

the class GbdtModelMapper method predictResultDetailTable.

/**
 * Predicts one row with all given tree roots and returns the result with its detail map.
 *
 * <p>The row is first passed through {@code transform}. Every non-null root
 * then adds its contribution to one shared {@link LabelCounter}, which is sized
 * to the first root's distribution (so the first root must not be null). The
 * final tuple is produced by {@code predictResultDetailWithLabelCounter}.
 *
 * @param root tree roots; entries after the first may be {@code null} and are skipped
 * @param row  input row, transformed before prediction
 * @return the prediction tuple, or a tuple of two {@code null}s when {@code root} is empty
 * @throws Exception propagated from the helpers this method calls
 */
private Tuple2<Object, Map<String, Double>> predictResultDetailTable(Node[] root, Row row) throws Exception {
    transform(row);

    final int treeCount = root.length;
    if (treeCount <= 0) {
        return Tuple2.<Object, Map<String, Double>>of(null, null);
    }

    final double[] accumulator = new double[root[0].getCounter().getDistributions().length];
    final LabelCounter labelCounter = new LabelCounter(0, 0, accumulator);
    for (Node tree : root) {
        if (tree != null) {
            predict(row, tree, labelCounter, 1.0);
        }
    }
    return predictResultDetailWithLabelCounter(labelCounter);
}
Also used : LabelCounter(com.alibaba.alink.operator.common.tree.LabelCounter)

Example 5 with LabelCounter

use of com.alibaba.alink.operator.common.tree.LabelCounter in project Alink by alibaba.

the class GbdtModelMapper method predictResultDetailVector.

/**
 * Predicts one vector with all given tree roots and returns the result with its detail map.
 *
 * <p>Every non-null root adds its contribution, through {@code predictVector},
 * to one shared {@link LabelCounter}, which is sized to the first root's
 * distribution (so the first root must not be null). The final tuple is
 * produced by {@code predictResultDetailWithLabelCounter}.
 *
 * @param root   tree roots; entries after the first may be {@code null} and are skipped
 * @param vector input sample
 * @return the prediction tuple, or a tuple of two {@code null}s when {@code root} is empty
 */
private Tuple2<Object, Map<String, Double>> predictResultDetailVector(Node[] root, Vector vector) {
    final int treeCount = root.length;
    if (treeCount <= 0) {
        return Tuple2.<Object, Map<String, Double>>of(null, null);
    }

    final double[] accumulator = new double[root[0].getCounter().getDistributions().length];
    final LabelCounter labelCounter = new LabelCounter(0, 0, accumulator);
    for (Node tree : root) {
        if (tree != null) {
            predictVector(vector, tree, labelCounter, 1.0);
        }
    }
    return predictResultDetailWithLabelCounter(labelCounter);
}
Also used : LabelCounter(com.alibaba.alink.operator.common.tree.LabelCounter)

Aggregations

LabelCounter (com.alibaba.alink.operator.common.tree.LabelCounter): 5, Node (com.alibaba.alink.operator.common.tree.Node): 1, Tuple2 (org.apache.flink.api.java.tuple.Tuple2): 1, Row (org.apache.flink.types.Row): 1