Usage of com.alibaba.alink.operator.common.tree.LabelCounter in project Alink by Alibaba.
In class IForest, method fitNode:
/**
 * Fits one isolation-tree node: stores the partition's sample count in the node counter, then either
 * turns the node into a leaf or gives it a random continuous split.
 *
 * <p>The node becomes a leaf when {@code leafNodeCount + 2 >= maxLeaves}, when {@code maxDepth} has been
 * reached, when the partition holds at most {@code minSamplesPerLeaf} samples, or when no feature admits
 * an acceptable random split. Otherwise features are tried in a freshly shuffled order. For each feature
 * a split value is drawn uniformly between the partition's min and max, and the first split whose
 * children satisfy both {@code minSamplesPerLeaf} and {@code minSampleRatioPerChild} is kept.
 *
 * @param item          queue item carrying the node, its depth and the [start, end) range into
 *                      {@code sampleIndices}
 * @param leafNodeCount number of leaves produced so far
 */
public void fitNode(QueueItem item, int leafNodeCount) {
    int totalSamples = item.partition.end - item.partition.start;
    item.node.setCounter(new LabelCounter(totalSamples, totalSamples, null));

    // Stop-growth conditions: leaf budget, depth limit, too few samples.
    if (leafNodeCount + 2 >= maxLeaves
        || item.depth >= maxDepth
        || totalSamples <= minSamplesPerLeaf) {
        item.node.makeLeaf();
        return;
    }

    Collections.shuffle(shuffleBuffer);
    for (Integer featureIndex : shuffleBuffer) {
        double[] featureData = data[featureIndex];

        double min = Double.MAX_VALUE;
        double max = -Double.MAX_VALUE;
        for (int i = item.partition.start; i < item.partition.end; i++) {
            double value = featureData[sampleIndices[i]];
            if (min > value) {
                min = value;
            }
            if (max < value) {
                max = value;
            }
        }
        // A constant feature cannot separate the samples.
        if (min == max) {
            continue;
        }

        // Math.random() is in [0, 1); reject an exact 0 so the draw lies in (0, 1).
        double random;
        do {
            random = Math.random();
        } while (0.0 == random);
        double splitValue = min + (max - min) * random;
        // Floating-point rounding can still land on a boundary; require a strictly interior value.
        if (splitValue == min || splitValue == max) {
            continue;
        }

        int left = 0;
        for (int i = item.partition.start; i < item.partition.end; i++) {
            if (featureData[sampleIndices[i]] <= splitValue) {
                left++;
            }
        }
        int right = totalSamples - left;
        if (left < minSamplesPerLeaf
            || right < minSamplesPerLeaf
            || ((double) left / (double) totalSamples) < minSampleRatioPerChild
            || ((double) right / (double) totalSamples) < minSampleRatioPerChild) {
            continue;
        }

        item.node.setFeatureIndex(featureIndex);
        item.node.setContinuousSplit(splitValue);
        return;
    }

    // No feature produced an acceptable split: mark the node as a leaf explicitly, exactly like the
    // early-exit paths above, instead of leaving it neither split nor leaf.
    item.node.makeLeaf();
}
Usage of com.alibaba.alink.operator.common.tree.LabelCounter in project Alink by Alibaba.
In class RandomForestModelMapper, method predictResultDetail:
/**
 * Predicts one sample with the forest and returns the prediction together with its detail.
 *
 * <p>Every tree in {@code treeModel.roots} is evaluated into one shared {@link LabelCounter}. The counter
 * is sized from the first tree's distribution and then normalized with {@code normWithWeight()}.
 * <ul>
 *   <li>Regression: the prediction is {@code distributions[0]} and the detail is {@code null}.</li>
 *   <li>Classification: the detail is a JSON map from each label to its probability. The prediction is
 *       the first label with the largest strictly positive probability, or {@code null} (with a
 *       warning) when no probability is positive.</li>
 * </ul>
 * With no trees at all, both elements are {@code null}.
 *
 * @param selection the input sample; it is copied into a thread-local row buffer and transformed
 * @return (prediction, JSON detail or {@code null})
 * @throws Exception propagated from filling, transforming or predicting the row
 */
@Override
protected Tuple2<Object, String> predictResultDetail(SlicedSelectedSample selection) throws Exception {
    Node[] root = treeModel.roots;
    Row inputBuffer = inputBufferThreadLocal.get();
    selection.fillRow(inputBuffer);
    transform(inputBuffer);

    if (root.length == 0) {
        return new Tuple2<>(null, null);
    }

    LabelCounter labelCounter = new LabelCounter(0, 0, new double[root[0].getCounter().getDistributions().length]);
    for (Node tree : root) {
        predict(inputBuffer, tree, labelCounter, 1.0);
    }
    labelCounter.normWithWeight();

    if (Criteria.isRegression(treeModel.meta.get(TreeUtil.TREE_TYPE))) {
        return new Tuple2<>(labelCounter.getDistributions()[0], null);
    }

    Map<String, Double> detail = new HashMap<>();
    double[] probability = labelCounter.getDistributions();
    double max = 0.0;
    int maxIndex = -1;
    for (int i = 0; i < probability.length; ++i) {
        detail.put(String.valueOf(treeModel.labels[i]), probability[i]);
        if (max < probability[i]) {
            max = probability[i];
            maxIndex = i;
        }
    }

    Object result = null;
    if (maxIndex == -1) {
        // No strictly positive probability (all zero, NaN, or empty): there is no winning label.
        // Leave the prediction null rather than indexing labels[-1].
        LOG.warn("Can not find the probability: {}", JsonConverter.toJson(probability));
    } else {
        result = treeModel.labels[maxIndex];
    }
    return new Tuple2<>(result, JsonConverter.toJson(detail));
}
Usage of com.alibaba.alink.operator.common.tree.LabelCounter in project Alink by Alibaba.
In class RegObj, method bestSplit:
/**
 * Picks the best regression split for {@code node} from the histogram block selected by {@code minusId}.
 *
 * <p>Each histogram bin occupies three consecutive slots of {@code minusHist}: sum, square sum and
 * weight. The node totals are accumulated over the bins of the first feature slot and stored as the
 * node's {@link LabelCounter}. The node becomes a leaf when {@code maxDepth < pair.depth}, when the
 * total weight is below {@code minSamplesPerLeaf}, or when no candidate feature yields a positive gain.
 *
 * <p>Candidate features are the bagged subset {@code pair.baggingFeatures} when
 * {@code baggingFeatureCount() != nFeatureCol}, and every feature otherwise. For each candidate,
 * continuous features are scored with {@code bestSplitNumerical} and the rest with
 * {@code bestSplitCategorical}. The first feature reaching the highest gain wins.
 *
 * @param node    node receiving the counter and either a split or leaf status
 * @param minusId index of this node's statistics block inside {@code minusHist}
 * @param pair    node context: depth and bagged feature ids
 */
@Override
public final void bestSplit(Node node, int minusId, NodeInfoPair pair) {
    final int histOffset = minusId * lenPerStat();

    // Node totals taken from the first feature slot's bins.
    double totalSum = 0.;
    double totalSquareSum = 0.;
    double totalWeight = 0.;
    for (int bin = 0, pos = histOffset; bin < nBin; ++bin, pos += 3) {
        totalSum += minusHist[pos];
        totalSquareSum += minusHist[pos + 1];
        totalWeight += minusHist[pos + 2];
    }
    node.setCounter(new LabelCounter(totalWeight, 0, new double[] { totalSum, totalSquareSum }));

    if (maxDepth < pair.depth || minSamplesPerLeaf > totalWeight) {
        node.makeLeaf();
        return;
    }

    double mean = totalSum / totalWeight;
    double mse = totalSquareSum / totalWeight - mean * mean;

    // With bagging, histogram slot j holds feature pair.baggingFeatures[j]; otherwise slot j is feature j.
    // In the non-bagged case baggingFeatureCount() equals nFeatureCol, so it bounds the loop either way.
    final int candidateCount = baggingFeatureCount();
    final boolean bagged = candidateCount != nFeatureCol;

    double bestGain = 0.;
    int bestFeature = -1;
    int bestNumericalSplit = 0;
    int[] bestCategoricalSplit = null;

    for (int slot = 0; slot < candidateCount; ++slot) {
        int feature = bagged ? pair.baggingFeatures[slot] : slot;
        int featureOffset = histOffset + slot * nBin * 3;

        if (featureMetas[feature].getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
            Tuple2<Integer, Double> candidate =
                bestSplitNumerical(featureOffset, mse, totalSquareSum, totalSum, totalWeight);
            if (candidate.f1 > bestGain) {
                bestGain = candidate.f1;
                bestNumericalSplit = candidate.f0;
                bestFeature = feature;
            }
        } else {
            Tuple2<int[], Double> candidate = bestSplitCategorical(
                featureOffset, mse, totalSquareSum, totalSum, totalWeight,
                featureMetas[feature].getNumCategorical());
            if (candidate.f1 > bestGain) {
                bestGain = candidate.f1;
                bestCategoricalSplit = candidate.f0;
                bestFeature = feature;
            }
        }
    }

    // bestGain is only ever raised by a strictly larger value, so it is never NaN here.
    if (bestGain <= 0.) {
        node.makeLeaf();
        return;
    }

    node.setFeatureIndex(bestFeature);
    if (featureMetas[bestFeature].getType().equals(FeatureMeta.FeatureType.CONTINUOUS)) {
        node.setContinuousSplit(bestNumericalSplit);
    } else {
        node.setCategoricalSplit(bestCategoricalSplit);
    }
    node.setGain(bestGain);
}
Usage of com.alibaba.alink.operator.common.tree.LabelCounter in project Alink by Alibaba.
In class GbdtModelMapper, method predictResultDetailTable:
/**
 * Predicts a table row with every non-null tree in {@code root} and converts the accumulated counter
 * into (prediction, detail) via {@code predictResultDetailWithLabelCounter}.
 *
 * <p>The row is transformed in place first. The accumulator is sized from the distribution length of
 * the first non-null tree. Null trees are skipped at every position, including the first; the original
 * code dereferenced {@code root[0]} unconditionally. When there is no non-null tree, both elements of
 * the result are {@code null}.
 *
 * @param root trees of the ensemble; entries may be {@code null}
 * @param row  input row, transformed in place
 * @return (prediction, detail map), or (null, null) when no tree is available
 * @throws Exception propagated from transforming or predicting the row
 */
private Tuple2<Object, Map<String, Double>> predictResultDetailTable(Node[] root, Row row) throws Exception {
    transform(row);

    LabelCounter labelCounter = null;
    for (Node tree : root) {
        if (tree == null) {
            continue;
        }
        if (labelCounter == null) {
            labelCounter = new LabelCounter(0, 0, new double[tree.getCounter().getDistributions().length]);
        }
        predict(row, tree, labelCounter, 1.0);
    }

    if (labelCounter == null) {
        return Tuple2.of(null, null);
    }
    return predictResultDetailWithLabelCounter(labelCounter);
}
Usage of com.alibaba.alink.operator.common.tree.LabelCounter in project Alink by Alibaba.
In class GbdtModelMapper, method predictResultDetailVector:
/**
 * Predicts a vector with every non-null tree in {@code root} and converts the accumulated counter into
 * (prediction, detail) via {@code predictResultDetailWithLabelCounter}.
 *
 * <p>The accumulator is sized from the distribution length of the first non-null tree. Null trees are
 * skipped at every position, including the first; the original code dereferenced {@code root[0]}
 * unconditionally. When there is no non-null tree, both elements of the result are {@code null}.
 *
 * @param root   trees of the ensemble; entries may be {@code null}
 * @param vector input features
 * @return (prediction, detail map), or (null, null) when no tree is available
 */
private Tuple2<Object, Map<String, Double>> predictResultDetailVector(Node[] root, Vector vector) {
    LabelCounter labelCounter = null;
    for (Node tree : root) {
        if (tree == null) {
            continue;
        }
        if (labelCounter == null) {
            labelCounter = new LabelCounter(0, 0, new double[tree.getCounter().getDistributions().length]);
        }
        predictVector(vector, tree, labelCounter, 1.0);
    }

    if (labelCounter == null) {
        return Tuple2.of(null, null);
    }
    return predictResultDetailWithLabelCounter(labelCounter);
}
Aggregations