Use of ml.shifu.shifu.container.PerformanceObject in project shifu by ShifuML.
From class PerformanceEvaluator, method setPerformanceObject.
/**
 * Translates a single confusion-matrix row into a new PerformanceObject.
 *
 * <p>The score and the raw and weighted TP/TN/FP/FN cells are copied as-is. Then action rate,
 * recall, precision, false positive rate and lift are derived from them, each in an unweighted
 * and a weighted flavour. No guard against an empty denominator is applied. If the counts are
 * floating-point (not visible here), such a ratio becomes NaN or Infinity rather than throwing.
 *
 * @param confMatObject one confusion-matrix row (presumably the counts at one score threshold)
 * @return a freshly allocated, fully populated PerformanceObject
 */
static PerformanceObject setPerformanceObject(ConfusionMatrixObject confMatObject) {
    PerformanceObject perf = new PerformanceObject();
    perf.binLowestScore = confMatObject.getScore();

    // Confusion-matrix cells, unweighted then weighted.
    perf.tp = confMatObject.getTp();
    perf.tn = confMatObject.getTn();
    perf.fp = confMatObject.getFp();
    perf.fn = confMatObject.getFn();
    perf.weightedTp = confMatObject.getWeightedTp();
    perf.weightedTn = confMatObject.getWeightedTn();
    perf.weightedFp = confMatObject.getWeightedFp();
    perf.weightedFn = confMatObject.getWeightedFn();

    // Action rate = (TP + FP) / total: the share of the population flagged positive.
    perf.actionRate = (confMatObject.getTp() + confMatObject.getFp())
            / confMatObject.getTotal();
    perf.weightedActionRate = (confMatObject.getWeightedTp() + confMatObject.getWeightedFp())
            / confMatObject.getWeightedTotal();

    // Recall (catch rate) = TP / (TP + FN).
    perf.recall = confMatObject.getTp()
            / (confMatObject.getTp() + confMatObject.getFn());
    perf.weightedRecall = confMatObject.getWeightedTp()
            / (confMatObject.getWeightedTp() + confMatObject.getWeightedFn());

    // Precision (hit rate) = TP / (TP + FP).
    perf.precision = confMatObject.getTp()
            / (confMatObject.getTp() + confMatObject.getFp());
    perf.weightedPrecision = confMatObject.getWeightedTp()
            / (confMatObject.getWeightedTp() + confMatObject.getWeightedFp());

    // False positive rate = FP / (FP + TN).
    perf.fpr = confMatObject.getFp()
            / (confMatObject.getFp() + confMatObject.getTn());
    perf.weightedFpr = confMatObject.getWeightedFp()
            / (confMatObject.getWeightedFp() + confMatObject.getWeightedTn());

    // Lift = TP / (actioned * positives / total), i.e. TP relative to what random selection would catch.
    perf.liftUnit = confMatObject.getTp()
            / ((confMatObject.getTp() + confMatObject.getFp())
                    * (confMatObject.getTp() + confMatObject.getFn())
                    / confMatObject.getTotal());
    perf.weightLiftUnit = confMatObject.getWeightedTp()
            / ((confMatObject.getWeightedTp() + confMatObject.getWeightedFp())
                    * (confMatObject.getWeightedTp() + confMatObject.getWeightedFn())
                    / confMatObject.getWeightedTotal());

    return perf;
}
Use of ml.shifu.shifu.container.PerformanceObject in project shifu by ShifuML.
From class PerformanceEvaluator, method logResult.
/**
 * Writes one bucketed curve to the log as a fixed-width table.
 *
 * <p>The output is a "Start print" banner line, a header row, and one row per point. Every rate
 * column is rounded with the pattern {@code #.####}.
 *
 * <p>NOTE(review): {@code binLowestScore} is printed unrounded. Its {@code %10s} field is also
 * narrower than its 14-character header, so that column does not line up. Confirm whether this is
 * intended.
 *
 * @param list the curve points to print, in order
 * @param info label appended to the banner line
 */
static void logResult(List<PerformanceObject> list, String info) {
    final String rowFormat = "%10s %18s %10s %18s %15s %18s %10s %11s %10s";
    final DecimalFormat fourDigits = new DecimalFormat("#.####");

    log.info("Start print: " + info);
    String header = String.format(rowFormat,
            "ActionRate", "WeightedActionRate",
            "Recall", "WeightedRecall",
            "Precision", "WeightedPrecision",
            "FPR", "WeightedFPR",
            "BinLowestScore");
    log.info(header);

    for (PerformanceObject point : list) {
        String row = String.format(rowFormat,
                fourDigits.format(point.actionRate), fourDigits.format(point.weightedActionRate),
                fourDigits.format(point.recall), fourDigits.format(point.weightedRecall),
                fourDigits.format(point.precision), fourDigits.format(point.weightedPrecision),
                fourDigits.format(point.fpr), fourDigits.format(point.weightedFpr),
                point.binLowestScore);
        log.info(row);
    }
}
Use of ml.shifu.shifu.container.PerformanceObject in project shifu by ShifuML.
From class PerformanceEvaluator, method bucketing.
/**
 * Reduces a sequence of confusion-matrix rows to bucketed curves, logs them, and computes AUCs.
 *
 * <p>Six curves are produced: ROC, precision/recall (catch rate) and gains, each unweighted and
 * weighted. The areas under the ROC and PR curves are also computed.
 *
 * <p>The first row is the starting point of every curve. Each later row is appended to a curve
 * when that curve's metric reaches the curve's next bucket boundary {@code bin / numBucket}. A
 * single row advances a given curve by at most one bucket.
 *
 * <p>Every list receives its own {@link PerformanceObject} instances. Setting {@code binNum} for
 * one curve therefore cannot overwrite the bucket number already recorded in another curve.
 *
 * @param results   confusion-matrix rows; presumably ordered by decreasing score so the metrics
 *                  are non-decreasing (TODO confirm against the producer)
 * @param records   number of records; denominator of the unweighted gain axis
 * @param numBucket number of buckets per curve
 * @param isWeight  whether the weighted curves are also logged (they are always computed and
 *                  returned)
 * @return the six bucketed curves plus the ROC/PR areas under curve
 */
public PerformanceResult bucketing(Iterable<ConfusionMatrixObject> results, long records, int numBucket, boolean isWeight) {
    List<PerformanceObject> FPRList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> catchRateList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> gainList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> FPRWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> catchRateWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> gainWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    // Next (1-based) bucket each curve is waiting to reach.
    int fpBin = 1, tpBin = 1, gainBin = 1, fpWeightBin = 1, tpWeightBin = 1, gainWeightBin = 1;
    boolean isFirst = true;
    int i = 0;
    for (ConfusionMatrixObject object : results) {
        if (isFirst) {
            // Every curve starts from the first row, each with its own copy.
            FPRList.add(newFirstBucketPoint(object));
            catchRateList.add(newFirstBucketPoint(object));
            gainList.add(newFirstBucketPoint(object));
            FPRWeightList.add(newFirstBucketPoint(object));
            catchRateWeightList.add(newFirstBucketPoint(object));
            gainWeightList.add(newFirstBucketPoint(object));
            isFirst = false;
        } else {
            // Metrics of this row are only read here; points stored in the lists are fresh copies.
            PerformanceObject po = setPerformanceObject(object);
            if (reachesBucketBoundary(po.fpr, fpBin, numBucket)) {
                FPRList.add(newBucketPoint(object, fpBin++));
            }
            if (reachesBucketBoundary(po.recall, tpBin, numBucket)) {
                catchRateList.add(newBucketPoint(object, tpBin++));
            }
            // prevent 99%: (i + 1) counts the current row as already consumed.
            if (reachesBucketBoundary((double) (i + 1) / records, gainBin, numBucket)) {
                gainList.add(newBucketPoint(object, gainBin++));
            }
            if (reachesBucketBoundary(po.weightedFpr, fpWeightBin, numBucket)) {
                FPRWeightList.add(newBucketPoint(object, fpWeightBin++));
            }
            if (reachesBucketBoundary(po.weightedRecall, tpWeightBin, numBucket)) {
                catchRateWeightList.add(newBucketPoint(object, tpWeightBin++));
            }
            // NOTE(review): "+ 1" adds one raw weight unit, mirroring (i + 1) above without being
            // scaled by record weights; confirm this is intended.
            if (reachesBucketBoundary((object.getWeightedTp() + object.getWeightedFp() + 1) / object.getWeightedTotal(),
                    gainWeightBin, numBucket)) {
                gainWeightList.add(newBucketPoint(object, gainWeightBin++));
            }
        }
        i++;
    }
    logResult(FPRList, "Bucketing False Positive Rate");
    if (isWeight) {
        logResult(FPRWeightList, "Bucketing Weighted False Positive Rate");
    }
    logResult(catchRateList, "Bucketing Catch Rate");
    if (isWeight) {
        logResult(catchRateWeightList, "Bucketing Weighted Catch Rate");
    }
    logResult(gainList, "Bucketing Action rate");
    if (isWeight) {
        logResult(gainWeightList, "Bucketing Weighted action rate");
    }
    PerformanceResult result = new PerformanceResult();
    result.version = Constants.version;
    result.pr = catchRateList;
    result.weightedPr = catchRateWeightList;
    result.roc = FPRList;
    result.weightedRoc = FPRWeightList;
    result.gains = gainList;
    result.weightedGains = gainWeightList;
    // Calculate area under curve
    result.areaUnderRoc = AreaUnderCurve.ofRoc(result.roc);
    result.weightedAreaUnderRoc = AreaUnderCurve.ofWeightedRoc(result.weightedRoc);
    result.areaUnderPr = AreaUnderCurve.ofPr(result.pr);
    result.weightedAreaUnderPr = AreaUnderCurve.ofWeightedPr(result.weightedPr);
    logAucResult(result, isWeight);
    return result;
}

/**
 * Builds the starting point of a curve from the first confusion-matrix row.
 *
 * <p>Precision ("hit rate") and lift are NaN for that row, presumably because TP + FP is still 0
 * there. They are therefore pinned to 1.0 and 0.0.
 */
private static PerformanceObject newFirstBucketPoint(ConfusionMatrixObject object) {
    PerformanceObject point = setPerformanceObject(object);
    // hit rate == NaN
    point.precision = 1.0;
    point.weightedPrecision = 1.0;
    // lift = NaN
    point.liftUnit = 0.0;
    point.weightLiftUnit = 0.0;
    return point;
}

/** Builds a fresh curve point for {@code object}, tagged with bucket number {@code binNum}. */
private static PerformanceObject newBucketPoint(ConfusionMatrixObject object, int binNum) {
    PerformanceObject point = setPerformanceObject(object);
    point.binNum = binNum;
    return point;
}

/**
 * Returns whether {@code value} has reached the boundary of bucket {@code bin} of {@code numBucket}.
 *
 * <p>The boundary is a single division {@code bin / numBucket}. The alternative
 * {@code bin * (1.0 / numBucket)} compounds two roundings: for example
 * {@code 3 * 0.1 == 0.30000000000000004}, so a ratio exactly equal to {@code 3.0 / 10} would miss
 * the boundary and the bucket would be filled one row late.
 */
private static boolean reachesBucketBoundary(double value, int bin, int numBucket) {
    return value >= (double) bin / numBucket;
}
Aggregations