Search in sources :

Example 6 with ConfusionMatrixObject

Use of ml.shifu.shifu.container.ConfusionMatrixObject in the project shifu by ShifuML.

The method bucketing of the class PerformanceEvaluator.

/**
 * Samples a sequence of confusion-matrix snapshots into six curves (ROC, catch rate and
 * gain, each in an unweighted and a weighted variant) and packs them, together with the
 * areas under the ROC and PR curves, into a {@link PerformanceResult}.
 *
 * <p>The first element of {@code results} always seeds all six curves. Every later
 * element is appended to a curve only when that curve's metric has reached the next
 * multiple of {@code 1.0 / numBucket}; the curve's bin counter is then advanced by one.
 *
 * <p>NOTE(review): a single {@link PerformanceObject} may be appended to several curves,
 * and its {@code binNum} is overwritten on each append, so the value finally stored is
 * the one written by the last curve that accepted it. Confirm whether callers rely on
 * {@code binNum} per curve.
 *
 * @param results   confusion-matrix snapshots, consumed in iteration order
 * @param records   denominator used for the unweighted gain (action rate) curve
 * @param numBucket number of buckets; the bucket width is {@code 1.0 / numBucket}
 * @param isWeight  whether the weighted curves are also logged; forwarded to the AUC logging
 * @return the populated performance result
 */
public PerformanceResult bucketing(Iterable<ConfusionMatrixObject> results, long records, int numBucket, boolean isWeight) {
    final int capacity = numBucket + 1;
    final double bucketWidth = 1.0 / numBucket;

    List<PerformanceObject> rocPoints = new ArrayList<PerformanceObject>(capacity);
    List<PerformanceObject> catchRatePoints = new ArrayList<PerformanceObject>(capacity);
    List<PerformanceObject> gainPoints = new ArrayList<PerformanceObject>(capacity);
    List<PerformanceObject> weightedRocPoints = new ArrayList<PerformanceObject>(capacity);
    List<PerformanceObject> weightedCatchRatePoints = new ArrayList<PerformanceObject>(capacity);
    List<PerformanceObject> weightedGainPoints = new ArrayList<PerformanceObject>(capacity);

    // Index of the next bin each curve is waiting to reach.
    int nextRocBin = 1;
    int nextCatchRateBin = 1;
    int nextGainBin = 1;
    int nextWeightedRocBin = 1;
    int nextWeightedCatchRateBin = 1;
    int nextWeightedGainBin = 1;

    boolean originSeeded = false;
    // One-based position of the current element within results.
    int seen = 0;

    for (ConfusionMatrixObject matrix : results) {
        seen++;
        PerformanceObject point = setPerformanceObject(matrix);

        if (!originSeeded) {
            originSeeded = true;
            // The first row gets fixed values; per the original author's notes the
            // hit rate (precision) and the lift would otherwise be NaN here.
            point.precision = 1.0;
            point.weightedPrecision = 1.0;
            point.liftUnit = 0.0;
            point.weightLiftUnit = 0.0;

            rocPoints.add(point);
            catchRatePoints.add(point);
            gainPoints.add(point);
            weightedRocPoints.add(point);
            weightedCatchRatePoints.add(point);
            weightedGainPoints.add(point);
            continue;
        }

        // Unweighted false positive rate.
        if (point.fpr >= nextRocBin * bucketWidth) {
            point.binNum = nextRocBin;
            nextRocBin++;
            rocPoints.add(point);
        }

        // Unweighted recall (catch rate).
        if (point.recall >= nextCatchRateBin * bucketWidth) {
            point.binNum = nextCatchRateBin;
            nextCatchRateBin++;
            catchRatePoints.add(point);
        }

        // Unweighted action rate: one-based position over the record count
        // (original note: "prevent 99%").
        if ((double) seen / records >= nextGainBin * bucketWidth) {
            point.binNum = nextGainBin;
            nextGainBin++;
            gainPoints.add(point);
        }

        // Weighted false positive rate.
        if (point.weightedFpr >= nextWeightedRocBin * bucketWidth) {
            point.binNum = nextWeightedRocBin;
            nextWeightedRocBin++;
            weightedRocPoints.add(point);
        }

        // Weighted recall (catch rate).
        if (point.weightedRecall >= nextWeightedCatchRateBin * bucketWidth) {
            point.binNum = nextWeightedCatchRateBin;
            nextWeightedCatchRateBin++;
            weightedCatchRatePoints.add(point);
        }

        // Weighted action rate: (weighted TP + weighted FP + 1) over the weighted total.
        if ((matrix.getWeightedTp() + matrix.getWeightedFp() + 1) / matrix.getWeightedTotal() >= nextWeightedGainBin * bucketWidth) {
            point.binNum = nextWeightedGainBin;
            nextWeightedGainBin++;
            weightedGainPoints.add(point);
        }
    }

    // Each unweighted curve is logged first, followed by its weighted twin when requested.
    logResult(rocPoints, "Bucketing False Positive Rate");
    if (isWeight) {
        logResult(weightedRocPoints, "Bucketing Weighted False Positive Rate");
    }
    logResult(catchRatePoints, "Bucketing Catch Rate");
    if (isWeight) {
        logResult(weightedCatchRatePoints, "Bucketing Weighted Catch Rate");
    }
    logResult(gainPoints, "Bucketing Action rate");
    if (isWeight) {
        logResult(weightedGainPoints, "Bucketing Weighted action rate");
    }

    PerformanceResult summary = new PerformanceResult();
    summary.version = Constants.version;
    summary.roc = rocPoints;
    summary.weightedRoc = weightedRocPoints;
    summary.pr = catchRatePoints;
    summary.weightedPr = weightedCatchRatePoints;
    summary.gains = gainPoints;
    summary.weightedGains = weightedGainPoints;

    // Areas under the ROC and PR curves, computed from the lists stored above.
    summary.areaUnderRoc = AreaUnderCurve.ofRoc(summary.roc);
    summary.weightedAreaUnderRoc = AreaUnderCurve.ofWeightedRoc(summary.weightedRoc);
    summary.areaUnderPr = AreaUnderCurve.ofPr(summary.pr);
    summary.weightedAreaUnderPr = AreaUnderCurve.ofWeightedPr(summary.weightedPr);

    logAucResult(summary, isWeight);
    return summary;
}
Also used : PerformanceObject(ml.shifu.shifu.container.PerformanceObject) PerformanceResult(ml.shifu.shifu.container.obj.PerformanceResult) ArrayList(java.util.ArrayList) ConfusionMatrixObject(ml.shifu.shifu.container.ConfusionMatrixObject)

Aggregations

ConfusionMatrixObject (ml.shifu.shifu.container.ConfusionMatrixObject): 6, ArrayList (java.util.ArrayList): 4, ModelResultObject (ml.shifu.shifu.container.ModelResultObject): 2, PerformanceObject (ml.shifu.shifu.container.PerformanceObject): 2, PerformanceResult (ml.shifu.shifu.container.obj.PerformanceResult): 2, ShifuException (ml.shifu.shifu.exception.ShifuException): 2, Splitter (com.google.common.base.Splitter): 1, BufferedReader (java.io.BufferedReader): 1, Scanner (java.util.Scanner): 1, SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType): 1, PathFinder (ml.shifu.shifu.fs.PathFinder): 1