Search in sources :

Example 6 with PerformanceResult

use of ml.shifu.shifu.container.obj.PerformanceResult in project shifu by ShifuML.

the class PerformanceEvaluator method bucketing.

/**
 * Samples a stream of confusion matrices into bucketed ROC, catch-rate (PR) and
 * gain curves, each in a plain and a weighted flavour, and computes the area
 * under the ROC and PR curves.
 *
 * <p>The first element of {@code results} always seeds all six curves. Every later
 * element is appended to a curve only when that curve's metric has reached the
 * next multiple of {@code 1.0 / numBucket}.
 *
 * <p>NOTE(review): one {@link PerformanceObject} instance can be appended to several
 * curves in the same iteration, and {@code binNum} is reassigned on each append, so
 * the value that survives is the one written by the last matching curve.
 *
 * @param results   confusion matrices to walk, in the order they should be sampled
 * @param records   denominator used for the unweighted action (gain) rate
 * @param numBucket number of buckets; the bucket width is {@code 1.0 / numBucket}
 * @param isWeight  whether the weighted curves are logged as well; they are
 *                  computed and returned regardless
 * @return the populated performance result, including curve lists and AUC values
 */
public PerformanceResult bucketing(Iterable<ConfusionMatrixObject> results, long records, int numBucket, boolean isWeight) {
    final int capacity = numBucket + 1;
    List<PerformanceObject> rocPoints = new ArrayList<PerformanceObject>(capacity);
    List<PerformanceObject> prPoints = new ArrayList<PerformanceObject>(capacity);
    List<PerformanceObject> gainPoints = new ArrayList<PerformanceObject>(capacity);
    List<PerformanceObject> weightedRocPoints = new ArrayList<PerformanceObject>(capacity);
    List<PerformanceObject> weightedPrPoints = new ArrayList<PerformanceObject>(capacity);
    List<PerformanceObject> weightedGainPoints = new ArrayList<PerformanceObject>(capacity);

    // Number of the next bucket boundary each curve is still waiting to reach.
    int nextRoc = 1;
    int nextPr = 1;
    int nextGain = 1;
    int nextWeightedRoc = 1;
    int nextWeightedPr = 1;
    int nextWeightedGain = 1;

    final double bucketWidth = 1.0 / numBucket;
    boolean anchorPending = true;
    // 1-based position of the element currently being processed.
    int seen = 0;

    for (ConfusionMatrixObject matrix : results) {
        PerformanceObject point = setPerformanceObject(matrix);
        seen++;

        if (anchorPending) {
            anchorPending = false;
            // The original author notes that hit rate and lift would be NaN for the
            // very first point, so they are pinned to fixed values instead.
            point.precision = 1.0;
            point.weightedPrecision = 1.0;
            point.liftUnit = 0.0;
            point.weightLiftUnit = 0.0;
            // The first point seeds every curve.
            rocPoints.add(point);
            prPoints.add(point);
            gainPoints.add(point);
            weightedRocPoints.add(point);
            weightedPrPoints.add(point);
            weightedGainPoints.add(point);
            continue;
        }

        // The order of the checks below matters: binNum is overwritten by each match.
        if (point.fpr >= nextRoc * bucketWidth) {
            point.binNum = nextRoc;
            nextRoc++;
            rocPoints.add(point);
        }
        if (point.recall >= nextPr * bucketWidth) {
            point.binNum = nextPr;
            nextPr++;
            prPoints.add(point);
        }
        // Action rate is measured on the 1-based position (original note: "prevent 99%").
        if ((double) seen / records >= nextGain * bucketWidth) {
            point.binNum = nextGain;
            nextGain++;
            gainPoints.add(point);
        }
        if (point.weightedFpr >= nextWeightedRoc * bucketWidth) {
            point.binNum = nextWeightedRoc;
            nextWeightedRoc++;
            weightedRocPoints.add(point);
        }
        if (point.weightedRecall >= nextWeightedPr * bucketWidth) {
            point.binNum = nextWeightedPr;
            nextWeightedPr++;
            weightedPrPoints.add(point);
        }
        if ((matrix.getWeightedTp() + matrix.getWeightedFp() + 1) / matrix.getWeightedTotal() >= nextWeightedGain * bucketWidth) {
            point.binNum = nextWeightedGain;
            nextWeightedGain++;
            weightedGainPoints.add(point);
        }
    }

    // Log each curve, with its weighted counterpart right after it when requested.
    logResult(rocPoints, "Bucketing False Positive Rate");
    if (isWeight) {
        logResult(weightedRocPoints, "Bucketing Weighted False Positive Rate");
    }
    logResult(prPoints, "Bucketing Catch Rate");
    if (isWeight) {
        logResult(weightedPrPoints, "Bucketing Weighted Catch Rate");
    }
    logResult(gainPoints, "Bucketing Action rate");
    if (isWeight) {
        logResult(weightedGainPoints, "Bucketing Weighted action rate");
    }

    PerformanceResult summary = new PerformanceResult();
    summary.version = Constants.version;
    summary.pr = prPoints;
    summary.weightedPr = weightedPrPoints;
    summary.roc = rocPoints;
    summary.weightedRoc = weightedRocPoints;
    summary.gains = gainPoints;
    summary.weightedGains = weightedGainPoints;

    // Areas under the curves are derived from the lists just stored on the result.
    summary.areaUnderRoc = AreaUnderCurve.ofRoc(summary.roc);
    summary.weightedAreaUnderRoc = AreaUnderCurve.ofWeightedRoc(summary.weightedRoc);
    summary.areaUnderPr = AreaUnderCurve.ofPr(summary.pr);
    summary.weightedAreaUnderPr = AreaUnderCurve.ofWeightedPr(summary.weightedPr);
    logAucResult(summary, isWeight);
    return summary;
}
Also used : PerformanceObject(ml.shifu.shifu.container.PerformanceObject) PerformanceResult(ml.shifu.shifu.container.obj.PerformanceResult) ArrayList(java.util.ArrayList) ConfusionMatrixObject(ml.shifu.shifu.container.ConfusionMatrixObject)

Example 7 with PerformanceResult

use of ml.shifu.shifu.container.obj.PerformanceResult in project shifu by ShifuML.

the class PerformanceEvaluator method review.

/**
 * Buckets the given confusion matrices into a {@link PerformanceResult} and writes
 * that result, via {@code JSONUtils.writeValue}, to the evaluation performance path
 * resolved for the current {@code evalConfig}.
 *
 * <p>The writer is always closed, whether or not the write succeeds, and any
 * {@link IOException} raised while opening, writing or closing is propagated to
 * the caller instead of being silently dropped.
 *
 * @param matrixList confusion matrices to bucket, passed through to {@code bucketing}
 * @param records    record count passed through to {@code bucketing}
 * @throws IOException if the writer cannot be obtained, written to, or closed
 */
public void review(Iterable<ConfusionMatrixObject> matrixList, long records) throws IOException {
    PathFinder pathFinder = new PathFinder(modelConfig);
    // bucketing; weighted curves are logged only when a weight column is configured
    PerformanceResult result = bucketing(matrixList, records, evalConfig.getPerformanceBucketNum(), evalConfig.getDataSet().getWeightColumnName() != null);
    Writer writer = null;
    try {
        writer = ShifuFileUtils.getWriter(pathFinder.getEvalPerformancePath(evalConfig, evalConfig.getDataSet().getSource()), evalConfig.getDataSet().getSource());
        JSONUtils.writeValue(writer, result);
    } finally {
        // Close on both success and failure; closing an already-closed Writer is a
        // no-op per the Closeable contract, so this is safe even if the JSON
        // serializer closed the writer itself.
        if (writer != null) {
            writer.close();
        }
    }
}
Also used : PerformanceResult(ml.shifu.shifu.container.obj.PerformanceResult) PathFinder(ml.shifu.shifu.fs.PathFinder) IOException(java.io.IOException) Writer(java.io.Writer)

Aggregations

PerformanceResult (ml.shifu.shifu.container.obj.PerformanceResult)7 PerformanceObject (ml.shifu.shifu.container.PerformanceObject)4 ArrayList (java.util.ArrayList)3 BufferedWriter (java.io.BufferedWriter)2 ConfusionMatrixObject (ml.shifu.shifu.container.ConfusionMatrixObject)2 Splitter (com.google.common.base.Splitter)1 IOException (java.io.IOException)1 Writer (java.io.Writer)1 Scanner (java.util.Scanner)1 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)1 GainChart (ml.shifu.shifu.core.eval.GainChart)1 ShifuException (ml.shifu.shifu.exception.ShifuException)1 PathFinder (ml.shifu.shifu.fs.PathFinder)1