Search in sources :

Example 11 with PerformanceObject

use of ml.shifu.shifu.container.PerformanceObject in project shifu by ShifuML.

the class PerformanceEvaluator method setPerformanceObject.

/**
 * Creates a {@link PerformanceObject} from one confusion-matrix entry: the score and the raw
 * and weighted counts are copied over, and the rate metrics are derived from those counts.
 *
 * <p>No guard against a zero denominator is applied here; callers that can hit that case
 * have to patch the affected fields themselves.
 *
 * @param confMatObject the confusion-matrix entry supplying the score and the plain and
 *        weighted tp/tn/fp/fn counts and totals
 * @return a newly created {@code PerformanceObject} populated from {@code confMatObject}
 */
static PerformanceObject setPerformanceObject(ConfusionMatrixObject confMatObject) {
    PerformanceObject po = new PerformanceObject();
    po.binLowestScore = confMatObject.getScore();

    // ---- unweighted counts ----
    po.tp = confMatObject.getTp();
    po.tn = confMatObject.getTn();
    po.fp = confMatObject.getFp();
    po.fn = confMatObject.getFn();

    // ---- unweighted rates ----
    // action rate = (TP + FP) / total
    po.actionRate = (confMatObject.getTp() + confMatObject.getFp())
            / confMatObject.getTotal();
    // recall = TP / (TP + FN)
    po.recall = confMatObject.getTp()
            / (confMatObject.getTp() + confMatObject.getFn());
    // precision = TP / (TP + FP)
    po.precision = confMatObject.getTp()
            / (confMatObject.getTp() + confMatObject.getFp());
    // false positive rate = FP / (FP + TN)
    po.fpr = confMatObject.getFp()
            / (confMatObject.getFp() + confMatObject.getTn());
    // lift = TP / (number of actions * (number of positives / total))
    po.liftUnit = confMatObject.getTp()
            / ((confMatObject.getTp() + confMatObject.getFp())
                    * (confMatObject.getTp() + confMatObject.getFn())
                    / confMatObject.getTotal());

    // ---- weighted counts ----
    po.weightedTp = confMatObject.getWeightedTp();
    po.weightedTn = confMatObject.getWeightedTn();
    po.weightedFp = confMatObject.getWeightedFp();
    po.weightedFn = confMatObject.getWeightedFn();

    // ---- weighted rates (same formulas, applied to the weighted counts) ----
    po.weightedActionRate = (confMatObject.getWeightedTp() + confMatObject.getWeightedFp())
            / confMatObject.getWeightedTotal();
    po.weightedRecall = confMatObject.getWeightedTp()
            / (confMatObject.getWeightedTp() + confMatObject.getWeightedFn());
    po.weightedPrecision = confMatObject.getWeightedTp()
            / (confMatObject.getWeightedTp() + confMatObject.getWeightedFp());
    po.weightedFpr = confMatObject.getWeightedFp()
            / (confMatObject.getWeightedFp() + confMatObject.getWeightedTn());
    po.weightLiftUnit = confMatObject.getWeightedTp()
            / ((confMatObject.getWeightedTp() + confMatObject.getWeightedFp())
                    * (confMatObject.getWeightedTp() + confMatObject.getWeightedFn())
                    / confMatObject.getWeightedTotal());

    return po;
}
Also used : PerformanceObject(ml.shifu.shifu.container.PerformanceObject)

Example 12 with PerformanceObject

use of ml.shifu.shifu.container.PerformanceObject in project shifu by ShifuML.

the class PerformanceEvaluator method logResult.

/**
 * Writes a titled table of performance metrics to the log at info level: a start line
 * carrying {@code info}, a column-header line, then one line per entry of {@code list}.
 *
 * <p>All rate columns are rendered with the {@code #.####} decimal pattern; the bin's lowest
 * score is passed to the formatter as is.
 *
 * @param list the performance entries to print, in iteration order
 * @param info the title appended to the "Start print: " line
 */
static void logResult(List<PerformanceObject> list, String info) {
    final String columnLayout = "%10s %18s %10s %18s %15s %18s %10s %11s %10s";
    final DecimalFormat rateFormat = new DecimalFormat("#.####");

    log.info("Start print: " + info);

    String header = String.format(columnLayout,
            "ActionRate", "WeightedActionRate",
            "Recall", "WeightedRecall",
            "Precision", "WeightedPrecision",
            "FPR", "WeightedFPR",
            "BinLowestScore");
    log.info(header);

    for (PerformanceObject entry : list) {
        String row = String.format(columnLayout,
                rateFormat.format(entry.actionRate), rateFormat.format(entry.weightedActionRate),
                rateFormat.format(entry.recall), rateFormat.format(entry.weightedRecall),
                rateFormat.format(entry.precision), rateFormat.format(entry.weightedPrecision),
                rateFormat.format(entry.fpr), rateFormat.format(entry.weightedFpr),
                entry.binLowestScore);
        log.info(row);
    }
}
Also used : PerformanceObject(ml.shifu.shifu.container.PerformanceObject) DecimalFormat(java.text.DecimalFormat)

Example 13 with PerformanceObject

use of ml.shifu.shifu.container.PerformanceObject in project shifu by ShifuML.

the class PerformanceEvaluator method bucketing.

/**
 * Samples the stream of confusion-matrix entries into six bucketed curves (ROC, catch rate
 * and gain, each in a plain and a weighted variant), logs them, and returns them together
 * with the areas under the ROC and PR curves.
 *
 * <p>The first entry is always added to every curve. Each later entry is added to a curve
 * when that curve's metric has reached the next multiple of {@code 1.0 / numBucket}; at most
 * one bin is advanced per entry and per curve.
 *
 * <p>Every entry added after the first one is an independent {@link PerformanceObject}
 * snapshot, so the {@code binNum} stored in one curve is not overwritten when the same
 * record also opens a bin in another curve.
 *
 * @param results confusion-matrix entries, consumed once in iteration order
 * @param records number of records, used as the denominator of the unweighted gain curve
 * @param numBucket number of bins per curve; the bin width is {@code 1.0 / numBucket}
 * @param isWeight whether the weighted curves are logged as well (they are always computed
 *        and returned)
 * @return the populated {@link PerformanceResult}
 */
public PerformanceResult bucketing(Iterable<ConfusionMatrixObject> results, long records, int numBucket, boolean isWeight) {
    List<PerformanceObject> FPRList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> catchRateList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> gainList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> FPRWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> catchRateWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> gainWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    // Next bin number to be opened on each curve.
    int fpBin = 1, tpBin = 1, gainBin = 1, fpWeightBin = 1, tpWeightBin = 1, gainWeightBin = 1;
    double binCapacity = 1.0 / numBucket;
    boolean isFirst = true;
    // Count of entries already processed; long because it is compared against the long
    // 'records' and an int counter would overflow on very large inputs.
    long i = 0;
    for (ConfusionMatrixObject object : results) {
        // Used for the threshold checks below (and as the shared first point of all curves).
        PerformanceObject po = setPerformanceObject(object);
        if (isFirst) {
            // hit rate == NaN
            po.precision = 1.0;
            po.weightedPrecision = 1.0;
            // lift = NaN
            po.liftUnit = 0.0;
            po.weightLiftUnit = 0.0;
            // binNum is never assigned for the first point, so one shared instance is safe.
            FPRList.add(po);
            catchRateList.add(po);
            gainList.add(po);
            FPRWeightList.add(po);
            catchRateWeightList.add(po);
            gainWeightList.add(po);
            isFirst = false;
        } else {
            if (po.fpr >= fpBin * binCapacity) {
                addBinnedCopy(FPRList, object, fpBin++);
            }
            if (po.recall >= tpBin * binCapacity) {
                addBinnedCopy(catchRateList, object, tpBin++);
            }
            // (i + 1) is the number of entries seen including this one, so the last entry
            // reaches 100% instead of stopping just short of it.
            if ((double) (i + 1) / records >= gainBin * binCapacity) {
                addBinnedCopy(gainList, object, gainBin++);
            }
            if (po.weightedFpr >= fpWeightBin * binCapacity) {
                addBinnedCopy(FPRWeightList, object, fpWeightBin++);
            }
            if (po.weightedRecall >= tpWeightBin * binCapacity) {
                addBinnedCopy(catchRateWeightList, object, tpWeightBin++);
            }
            // NOTE(review): the "+ 1" is added to a weighted sum here; kept unchanged from the
            // original - confirm it is intended for weights other than 1.
            if ((object.getWeightedTp() + object.getWeightedFp() + 1) / object.getWeightedTotal() >= gainWeightBin * binCapacity) {
                addBinnedCopy(gainWeightList, object, gainWeightBin++);
            }
        }
        i++;
    }
    logResult(FPRList, "Bucketing False Positive Rate");
    if (isWeight) {
        logResult(FPRWeightList, "Bucketing Weighted False Positive Rate");
    }
    logResult(catchRateList, "Bucketing Catch Rate");
    if (isWeight) {
        logResult(catchRateWeightList, "Bucketing Weighted Catch Rate");
    }
    logResult(gainList, "Bucketing Action rate");
    if (isWeight) {
        logResult(gainWeightList, "Bucketing Weighted action rate");
    }
    PerformanceResult result = new PerformanceResult();
    result.version = Constants.version;
    result.pr = catchRateList;
    result.weightedPr = catchRateWeightList;
    result.roc = FPRList;
    result.weightedRoc = FPRWeightList;
    result.gains = gainList;
    result.weightedGains = gainWeightList;
    // Calculate area under curve
    result.areaUnderRoc = AreaUnderCurve.ofRoc(result.roc);
    result.weightedAreaUnderRoc = AreaUnderCurve.ofWeightedRoc(result.weightedRoc);
    result.areaUnderPr = AreaUnderCurve.ofPr(result.pr);
    result.weightedAreaUnderPr = AreaUnderCurve.ofWeightedPr(result.weightedPr);
    logAucResult(result, isWeight);
    return result;
}

/**
 * Appends a fresh snapshot of {@code source}, tagged with {@code binNum}, to {@code target}.
 * A new object is built per call so that the bin number of one curve is never shared with,
 * or overwritten by, another curve.
 */
private static void addBinnedCopy(List<PerformanceObject> target, ConfusionMatrixObject source, int binNum) {
    PerformanceObject entry = setPerformanceObject(source);
    entry.binNum = binNum;
    target.add(entry);
}
Also used : PerformanceObject(ml.shifu.shifu.container.PerformanceObject) PerformanceResult(ml.shifu.shifu.container.obj.PerformanceResult) ArrayList(java.util.ArrayList) ConfusionMatrixObject(ml.shifu.shifu.container.ConfusionMatrixObject)

Aggregations

PerformanceObject (ml.shifu.shifu.container.PerformanceObject)13 BufferedWriter (java.io.BufferedWriter)5 PerformanceResult (ml.shifu.shifu.container.obj.PerformanceResult)4 ArrayList (java.util.ArrayList)2 ConfusionMatrixObject (ml.shifu.shifu.container.ConfusionMatrixObject)2 Splitter (com.google.common.base.Splitter)1 DecimalFormat (java.text.DecimalFormat)1 Scanner (java.util.Scanner)1 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)1 PerformanceExtractor (ml.shifu.shifu.core.eval.PerformanceExtractor)1 ShifuException (ml.shifu.shifu.exception.ShifuException)1 BeforeClass (org.testng.annotations.BeforeClass)1 Test (org.testng.annotations.Test)1