Use of `ml.shifu.shifu.container.ConfusionMatrixObject` in the project shifu by ShifuML.
The `bucketing` method of the class `PerformanceEvaluator`.
/**
 * Reduces a sequence of confusion-matrix points to bucketed ROC, PR (catch rate) and gains (action rate) curves,
 * in both unweighted and weighted form, and computes their areas under curve.
 * <p>
 * The first element of {@code results} starts every curve. Its precision and weighted precision are forced to 1.0,
 * and its lift units to 0.0; the loop comments note that these values would otherwise be NaN. For each later
 * element, a point is appended to a curve when that curve's x-axis value is at or above
 * {@code bin / numBucket} for the curve's current bin. That curve's bin then advances by one. The x-axis values are:
 * <ul>
 * <li>ROC: {@code fpr} / {@code weightedFpr}</li>
 * <li>PR: {@code recall} / {@code weightedRecall}</li>
 * <li>gains: {@code (index + 1) / records} / {@code (weightedTp + weightedFp + 1) / weightedTotal}</li>
 * </ul>
 * Each appended point is its own {@code PerformanceObject}. This keeps its {@code binNum} equal to the bin of the
 * curve it belongs to, even when one record closes a bin on several curves at once.
 *
 * @param results   confusion-matrix points; presumably ordered by descending score — TODO confirm against caller
 * @param records   number of records, used as the denominator of the unweighted action rate
 * @param numBucket number of buckets per curve
 * @param isWeight  whether the weighted curves are logged (they are always computed and returned)
 * @return the bucketed curves and their areas under curve
 */
public PerformanceResult bucketing(Iterable<ConfusionMatrixObject> results, long records, int numBucket, boolean isWeight) {
    List<PerformanceObject> FPRList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> catchRateList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> gainList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> FPRWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> catchRateWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> gainWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    // Next (1-based) bin each curve is waiting to reach.
    int fpBin = 1, tpBin = 1, gainBin = 1, fpWeightBin = 1, tpWeightBin = 1, gainWeightBin = 1;
    double binCapacity = 1.0 / numBucket;
    boolean isFirst = true;
    int i = 0;
    for (ConfusionMatrixObject object : results) {
        PerformanceObject po = setPerformanceObject(object);
        if (isFirst) {
            // hit rate == NaN
            po.precision = 1.0;
            po.weightedPrecision = 1.0;
            // lift = NaN
            po.liftUnit = 0.0;
            po.weightLiftUnit = 0.0;
            // The starting point carries no bin number, so sharing one instance across the curves is harmless.
            FPRList.add(po);
            catchRateList.add(po);
            gainList.add(po);
            FPRWeightList.add(po);
            catchRateWeightList.add(po);
            gainWeightList.add(po);
            isFirst = false;
        } else {
            // `po` is only used for the threshold checks. Every appended point is a separate instance, so its
            // binNum is not overwritten when the same record also closes a bin on another curve.
            if (po.fpr >= fpBin * binCapacity) {
                FPRList.add(newBinnedPerformanceObject(object, fpBin++));
            }
            if (po.recall >= tpBin * binCapacity) {
                catchRateList.add(newBinnedPerformanceObject(object, tpBin++));
            }
            // prevent 99%
            if ((double) (i + 1) / records >= gainBin * binCapacity) {
                gainList.add(newBinnedPerformanceObject(object, gainBin++));
            }
            if (po.weightedFpr >= fpWeightBin * binCapacity) {
                FPRWeightList.add(newBinnedPerformanceObject(object, fpWeightBin++));
            }
            if (po.weightedRecall >= tpWeightBin * binCapacity) {
                catchRateWeightList.add(newBinnedPerformanceObject(object, tpWeightBin++));
            }
            if ((object.getWeightedTp() + object.getWeightedFp() + 1) / object.getWeightedTotal() >= gainWeightBin * binCapacity) {
                gainWeightList.add(newBinnedPerformanceObject(object, gainWeightBin++));
            }
        }
        i++;
    }
    logResult(FPRList, "Bucketing False Positive Rate");
    if (isWeight) {
        logResult(FPRWeightList, "Bucketing Weighted False Positive Rate");
    }
    logResult(catchRateList, "Bucketing Catch Rate");
    if (isWeight) {
        logResult(catchRateWeightList, "Bucketing Weighted Catch Rate");
    }
    logResult(gainList, "Bucketing Action rate");
    if (isWeight) {
        logResult(gainWeightList, "Bucketing Weighted action rate");
    }
    PerformanceResult result = new PerformanceResult();
    result.version = Constants.version;
    result.pr = catchRateList;
    result.weightedPr = catchRateWeightList;
    result.roc = FPRList;
    result.weightedRoc = FPRWeightList;
    result.gains = gainList;
    result.weightedGains = gainWeightList;
    // Calculate area under curve
    result.areaUnderRoc = AreaUnderCurve.ofRoc(result.roc);
    result.weightedAreaUnderRoc = AreaUnderCurve.ofWeightedRoc(result.weightedRoc);
    result.areaUnderPr = AreaUnderCurve.ofPr(result.pr);
    result.weightedAreaUnderPr = AreaUnderCurve.ofWeightedPr(result.weightedPr);
    logAucResult(result, isWeight);
    return result;
}

/**
 * Builds a fresh {@code PerformanceObject} for {@code object} and tags it with {@code binNum}.
 * <p>
 * A new instance is created for each curve point, so setting {@code binNum} for one curve cannot overwrite the bin
 * number already recorded for another curve. This relies on {@code setPerformanceObject} returning a new instance on
 * every call, which {@code bucketing} already assumes when it keeps the returned objects across loop iterations.
 *
 * @param object confusion-matrix point to convert
 * @param binNum bin number of the curve this point is being appended to
 * @return a new performance object carrying {@code binNum}
 */
private PerformanceObject newBinnedPerformanceObject(ConfusionMatrixObject object, int binNum) {
    PerformanceObject binned = setPerformanceObject(object);
    binned.binNum = binNum;
    return binned;
}
Aggregations