Search in sources:

Example 1 with PerformanceResult

Use of ml.shifu.shifu.container.obj.PerformanceResult in project shifu by ShifuML.

In class EvalModelProcessor, method runDistEval.

/**
 * Runs the distributed version of evaluation and performance review.
 *
 * <p>When no champion score meta columns are configured, or the model is not a regression model, only the
 * current (challenger) model is evaluated through the plain confusion-matrix path. Otherwise the challenger
 * model and every non-blank champion score column are evaluated. The gain-chart and PR/ROC HTML pages and
 * the per-model CSV files are then written to the local eval folder.
 *
 * @param evalConfig
 *            the evaluation instance
 * @throws IOException
 *             if an I/O error occurs, e.g. while deleting old temporary files or writing the evaluation
 *             reports
 */
private void runDistEval(EvalConfig evalConfig) throws IOException {
    ScoreStatus challengerStatus = runDistScore(evalConfig);
    List<String> championColumns = evalConfig.getScoreMetaColumns(modelConfig);
    boolean hasChampions = championColumns != null && !championColumns.isEmpty();
    if (!hasChampions || !modelConfig.isRegression()) {
        // No champion score column to compare against: evaluate only the challenger model.
        runConfusionMatrix(evalConfig, challengerStatus, isGBTNotConvertToProb(evalConfig));
        return;
    }

    // 1. Performance of the challenger (current) model.
    PerformanceResult challengerResult = runConfusionMatrix(evalConfig, challengerStatus,
            pathFinder.getEvalScorePath(evalConfig), pathFinder.getEvalPerformancePath(evalConfig), false, false,
            isGBTNotConvertToProb(evalConfig));
    List<PerformanceResult> performances = new ArrayList<PerformanceResult>();
    performances.add(challengerResult);

    // 2. Performance of every champion model; series names stay index-aligned with `performances`.
    List<String> seriesNames = new ArrayList<String>();
    seriesNames.add(modelConfig.getBasic().getName() + "-" + evalConfig.getName());
    for (String championColumn : championColumns) {
        if (!StringUtils.isBlank(championColumn)) {
            seriesNames.add(championColumn);
            LOG.info("Model score sort for {} in eval {} is started.", championColumn, evalConfig.getName());
            ScoreStatus championStatus = runDistMetaScore(evalConfig, championColumn);
            PerformanceResult championResult = runConfusionMatrix(evalConfig, championStatus,
                    pathFinder.getEvalMetaScorePath(evalConfig, championColumn),
                    pathFinder.getEvalMetaPerformancePath(evalConfig, championColumn), false, false, 0, 1, 2);
            performances.add(championResult);
        }
    }

    synchronized (this) {
        GainChart chart = new GainChart();
        boolean weighted = StringUtils.isNotBlank(evalConfig.getDataSet().getWeightColumnName());
        String evalName = evalConfig.getName();

        // 3. Gain charts and the other performance files are only produced on the local file system.
        String gainChartHtml = pathFinder.getEvalFilePath(evalName, evalName + "_gainchart.html", SourceType.LOCAL);
        LOG.info("Gain chart is generated in {}.", gainChartHtml);
        chart.generateHtml(evalConfig, modelConfig, gainChartHtml, performances, seriesNames);

        String prRocHtml = pathFinder.getEvalFilePath(evalName, evalName + "_prroc.html", SourceType.LOCAL);
        LOG.info("PR & ROC chart is generated in {}.", prRocHtml);
        chart.generateHtml4PrAndRoc(evalConfig, modelConfig, prRocHtml, performances, seriesNames);

        for (int idx = 0; idx < seriesNames.size(); idx++) {
            writeEvalCsvReports(chart, evalConfig, seriesNames.get(idx), performances.get(idx), weighted);
        }
        LOG.info("Performance Evaluation is done for {}.", evalName);
    }
}

/**
 * Writes the per-model CSV reports of one evaluated series into the local eval folder.
 *
 * <p>The unit-wise gain chart, PR, ROC and model-score gain chart files are always written. The weighted
 * gain chart, PR and ROC files are written only when {@code weighted} is true.
 */
private void writeEvalCsvReports(GainChart chart, EvalConfig evalConfig, String seriesName,
        PerformanceResult result, boolean weighted) throws IOException {
    String evalName = evalConfig.getName();

    String unitGainCsv = pathFinder.getEvalFilePath(evalName, seriesName + "_unit_wise_gainchart.csv",
            SourceType.LOCAL);
    LOG.info("Unit-wise gain chart data is generated in {} for eval {} and name {}.", unitGainCsv, evalName,
            seriesName);
    chart.generateCsv(evalConfig, modelConfig, unitGainCsv, result.gains);
    if (weighted) {
        String weightedGainCsv = pathFinder.getEvalFilePath(evalName, seriesName + "_weighted_gainchart.csv",
                SourceType.LOCAL);
        LOG.info("Weighted gain chart data is generated in {} for eval {} and name {}.", weightedGainCsv,
                evalName, seriesName);
        chart.generateCsv(evalConfig, modelConfig, weightedGainCsv, result.weightedGains);
    }

    String unitPrCsv = pathFinder.getEvalFilePath(evalName, seriesName + "_unit_wise_pr.csv", SourceType.LOCAL);
    LOG.info("Unit-wise pr data is generated in {} for eval {} and name {}.", unitPrCsv, evalName, seriesName);
    chart.generateCsv(evalConfig, modelConfig, unitPrCsv, result.pr);
    if (weighted) {
        String weightedPrCsv = pathFinder.getEvalFilePath(evalName, seriesName + "_weighted_pr.csv",
                SourceType.LOCAL);
        LOG.info("Weighted pr data is generated in {} for eval {} and name {}.", weightedPrCsv, evalName,
                seriesName);
        chart.generateCsv(evalConfig, modelConfig, weightedPrCsv, result.weightedPr);
    }

    String unitRocCsv = pathFinder.getEvalFilePath(evalName, seriesName + "_unit_wise_roc.csv", SourceType.LOCAL);
    LOG.info("Unit-wise roc data is generated in {} for eval {} and name {}.", unitRocCsv, evalName, seriesName);
    chart.generateCsv(evalConfig, modelConfig, unitRocCsv, result.roc);
    if (weighted) {
        String weightedRocCsv = pathFinder.getEvalFilePath(evalName, seriesName + "_weighted_roc.csv",
                SourceType.LOCAL);
        LOG.info("Weighted roc data is generated in {} for eval {} and name {}.", weightedRocCsv, evalName,
                seriesName);
        chart.generateCsv(evalConfig, modelConfig, weightedRocCsv, result.weightedRoc);
    }

    String modelScoreCsv = pathFinder.getEvalFilePath(evalName, seriesName + "_modelscore_gainchart.csv",
            SourceType.LOCAL);
    LOG.info("Model score gain chart data is generated in {} for eval {} and name {}.", modelScoreCsv, evalName,
            seriesName);
    chart.generateCsv(evalConfig, modelConfig, modelScoreCsv, result.modelScoreList);
}
Also used: PerformanceResult(ml.shifu.shifu.container.obj.PerformanceResult) ArrayList(java.util.ArrayList) GainChart(ml.shifu.shifu.core.eval.GainChart)

Example 2 with PerformanceResult

Use of ml.shifu.shifu.container.obj.PerformanceResult in project shifu by ShifuML.

In class GainChart, method generateHtml.

/**
 * Writes the gain-chart HTML report that compares one or more models.
 *
 * <p>For {@code n = results.size()} models, seven groups of {@code n} JavaScript arrays are emitted
 * ({@code data_0} .. {@code data_(7n-1)}). Chart {@code containerK} plots group {@code K}, with one series
 * per model, each labelled with the matching entry of {@code names}:
 * <ol start="0">
 * <li>weighted recall vs. weighted operation point (from {@code weightedGains})</li>
 * <li>unit-wise recall vs. weighted operation point (from {@code weightedGains})</li>
 * <li>weighted recall vs. unit-wise operation point (from {@code gains})</li>
 * <li>unit-wise recall vs. unit-wise operation point (from {@code gains})</li>
 * <li>weighted recall vs. model score (from {@code modelScoreList})</li>
 * <li>unit-wise recall vs. model score (from {@code modelScoreList})</li>
 * <li>score distribution (from {@code modelScoreList})</li>
 * </ol>
 * Series names are escaped before being embedded in single-quoted JavaScript strings. Names containing
 * quotes, backslashes, line breaks or angle brackets therefore cannot break the generated page.
 *
 * @param evalConfig
 *            the evaluation config (not read by this method)
 * @param modelConfig
 *            the model config; its basic name is passed into every chart template
 * @param fileName
 *            local path of the HTML file to write
 * @param results
 *            performance results, one per model series
 * @param names
 *            series display names, index-aligned with {@code results}
 * @throws IOException
 *             if the file cannot be opened or written
 * @throws IllegalArgumentException
 *             if {@code names} has fewer entries than {@code results}
 */
public void generateHtml(EvalConfig evalConfig, ModelConfig modelConfig, String fileName, List<PerformanceResult> results, List<String> names) throws IOException {
    final int chartCount = 7;
    int seriesCount = results.size();
    // Fail before creating a half-written file instead of hitting an IndexOutOfBoundsException midway.
    if (seriesCount > 0 && (names == null || names.size() < seriesCount)) {
        throw new IllegalArgumentException("Expected at least " + seriesCount + " series names but got "
                + (names == null ? "null" : String.valueOf(names.size())) + ".");
    }

    BufferedWriter writer = null;
    try {
        writer = ShifuFileUtils.getWriter(fileName, SourceType.LOCAL);
        writer.write(GainChartTemplate.HIGHCHART_BASE_BEGIN);
        writer.write(String.format(GainChartTemplate.HIGHCHART_BUTTON_PANEL_TEMPLATE_1, "Weighted Operation Point", "lst0", "Weighted Recall", "lst1", "Unit-wise Recall"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_BUTTON_PANEL_TEMPLATE_2, "Unit-wise Operation Point", "lst2", "Weighted Recall", "lst3", "Unit-wise Recall"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_BUTTON_PANEL_TEMPLATE_3, "Model Score", "lst4", "Weighted Recall", "lst5", "Unit-wise Recall"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_BUTTON_PANEL_TEMPLATE_4, "Score Distribution", "lst6", "Score Count"));
        writer.write("      </div>\n");
        writer.write("      <div class=\"col-sm-9 col-sm-offset-3 col-md-10 col-md-offset-2 main\">\n");
        for (int c = 0; c < chartCount; c++) {
            writer.write(String.format(GainChartTemplate.HIGHCHART_DIV, "container" + c));
        }
        writer.write("<script>\n");
        writer.write("\n");

        // Group 0: weighted recall over the weighted operation point.
        for (int j = 0; j < seriesCount; j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + j + " = [\n");
            for (int i = 0; i < result.weightedGains.size(); i++) {
                PerformanceObject po = result.weightedGains.get(i);
                writer.write(String.format(GainChartTemplate.DATA_FORMAT, GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.weightedPrecision * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.weightedGains.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        // Group 1: unit-wise recall over the weighted operation point.
        for (int j = 0; j < seriesCount; j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (seriesCount + j) + " = [\n");
            for (int i = 0; i < result.weightedGains.size(); i++) {
                PerformanceObject po = result.weightedGains.get(i);
                writer.write(String.format(GainChartTemplate.DATA_FORMAT, GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.precision * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.weightedGains.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        // Group 2: weighted recall over the unit-wise operation point.
        for (int j = 0; j < seriesCount; j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (2 * seriesCount + j) + " = [\n");
            for (int i = 0; i < result.gains.size(); i++) {
                PerformanceObject po = result.gains.get(i);
                writer.write(String.format(GainChartTemplate.DATA_FORMAT, GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.weightedPrecision * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.gains.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        // Group 3: unit-wise recall over the unit-wise operation point.
        for (int j = 0; j < seriesCount; j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (3 * seriesCount + j) + " = [\n");
            for (int i = 0; i < result.gains.size(); i++) {
                PerformanceObject po = result.gains.get(i);
                writer.write(String.format(GainChartTemplate.DATA_FORMAT, GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.precision * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.gains.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        // Group 4: weighted recall over the model score.
        for (int j = 0; j < seriesCount; j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (4 * seriesCount + j) + " = [\n");
            for (int i = 0; i < result.modelScoreList.size(); i++) {
                PerformanceObject po = result.modelScoreList.get(i);
                writer.write(String.format(GainChartTemplate.DATA_FORMAT, GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.binLowestScore), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.weightedPrecision * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.modelScoreList.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        // Group 5: unit-wise recall over the model score.
        for (int j = 0; j < seriesCount; j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (5 * seriesCount + j) + " = [\n");
            for (int i = 0; i < result.modelScoreList.size(); i++) {
                PerformanceObject po = result.modelScoreList.get(i);
                writer.write(String.format(GainChartTemplate.DATA_FORMAT, GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.binLowestScore), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.precision * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.modelScoreList.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        // Group 6: score distribution (count per score bin).
        for (int j = 0; j < seriesCount; j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (6 * seriesCount + j) + " = [\n");
            for (int i = 0; i < result.modelScoreList.size(); i++) {
                PerformanceObject po = result.modelScoreList.get(i);
                writer.write(String.format(GainChartTemplate.SCORE_DATA_FORMAT, GainChartTemplate.DF.format(po.scoreCount), GainChartTemplate.DF.format(po.binLowestScore), GainChartTemplate.DF.format(po.scoreCount), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.modelScoreList.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }

        writer.write("$(function () {\n");
        String modelName = modelConfig.getBasic().getName();
        // Each chart consumes the next `seriesCount` data arrays, in group order.
        int currIndex = 0;
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX, "container0", "Weighted Recall", modelName, "Weighted  Operation Point", "%", "false"));
        currIndex = writeChartSeries(writer, names, seriesCount, currIndex);
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX, "container1", "Unit-wise Recall", modelName, "Weighted  Operation Point", "%", "false"));
        currIndex = writeChartSeries(writer, names, seriesCount, currIndex);
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX, "container2", "Weighted Recall", modelName, "Unit-wise  Operation Point", "%", "false"));
        currIndex = writeChartSeries(writer, names, seriesCount, currIndex);
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX, "container3", "Unit-wise Recall", modelName, "Unit-wise  Operation Point", "%", "false"));
        currIndex = writeChartSeries(writer, names, seriesCount, currIndex);
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX, "container4", "Weighted Recall", modelName, "Model Score", "", "true"));
        currIndex = writeChartSeries(writer, names, seriesCount, currIndex);
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX, "container5", "Unit-wise Recall", modelName, "Model Score", "", "true"));
        currIndex = writeChartSeries(writer, names, seriesCount, currIndex);
        writer.write(String.format(GainChartTemplate.SCORE_HIGHCHART_CHART_PREFIX, "container6", "Score Distribution", modelName, "Model Score", "", "false"));
        writeChartSeries(writer, names, seriesCount, currIndex);
        writer.write("});\n");
        writer.write("\n");

        writer.write("$(document).ready(function() {\n");
        for (int c = 0; c < chartCount; c++) {
            writer.write(String.format(GainChartTemplate.HIGHCHART_LIST_TOGGLE_TEMPLATE, "lst" + c, "container" + c, "lst" + c));
            writer.write("\n");
        }
        writer.write("  var ics = ['#container1', '#container2', '#container4', '#container5', '#container6'];\n");
        writer.write("  var icl = ics.length;\n");
        writer.write("  for (var i = 0; i < icl; i++) {\n");
        writer.write("      $(ics[i]).toggleClass('show');\n");
        writer.write("      $(ics[i]).toggleClass('hidden');\n");
        writer.write("      $(ics[i]).toggleClass('ls_chosen');\n");
        writer.write("  };\n");
        writer.write("\n");
        writer.write("});\n");
        writer.write("\n");
        writer.write("</script>\n");
        writer.write(GainChartTemplate.HIGHCHART_BASE_END);
    } finally {
        if (writer != null) {
            writer.close();
        }
    }
}

/**
 * Writes the {@code series: [...]} array of one chart and closes the chart definition.
 *
 * <p>The array has {@code count} entries. Entry {@code i} references the JavaScript variable
 * {@code data_(firstDataIndex + i)} and is labelled with {@code names.get(i)}, escaped for a single-quoted
 * JavaScript string.
 *
 * @return the next unused data index ({@code firstDataIndex + count})
 */
private static int writeChartSeries(BufferedWriter writer, List<String> names, int count, int firstDataIndex)
        throws IOException {
    int dataIndex = firstDataIndex;
    writer.write("series: [");
    for (int i = 0; i < count; i++) {
        writer.write("{");
        writer.write("  data: data_" + (dataIndex++) + ",");
        writer.write("  name: '" + escapeJsSingleQuoted(names.get(i)) + "',");
        writer.write("  turboThreshold:0");
        writer.write("}");
        if (i != count - 1) {
            writer.write(",");
        }
    }
    writer.write("]");
    writer.write("});");
    writer.write("\n");
    return dataIndex;
}

/**
 * Escapes {@code value} for use inside a single-quoted JavaScript string that lives in an HTML
 * {@code <script>} block. The browser still displays exactly the original text.
 *
 * <p>A {@code null} value becomes the text {@code null}, matching plain string concatenation.
 */
private static String escapeJsSingleQuoted(String value) {
    String raw = (value == null) ? "null" : value;
    StringBuilder sb = new StringBuilder(raw.length() + 8);
    for (int i = 0; i < raw.length(); i++) {
        char ch = raw.charAt(i);
        switch (ch) {
            case '\\':
                sb.append("\\\\");
                break;
            case '\'':
                sb.append("\\'");
                break;
            case '"':
                sb.append("\\\"");
                break;
            case '\n':
                sb.append("\\n");
                break;
            case '\r':
                sb.append("\\r");
                break;
            case '<':
                // Prevents "</script>" or "<!--" inside a name from terminating the script block.
                sb.append("\\u003C");
                break;
            case '\u2028':
                // Line/paragraph separators are line terminators in older JavaScript string literals.
                sb.append("\\u2028");
                break;
            case '\u2029':
                sb.append("\\u2029");
                break;
            default:
                sb.append(ch);
        }
    }
    return sb.toString();
}
Also used: PerformanceResult(ml.shifu.shifu.container.obj.PerformanceResult) PerformanceObject(ml.shifu.shifu.container.PerformanceObject) BufferedWriter(java.io.BufferedWriter)

Example 3 with PerformanceResult

Use of ml.shifu.shifu.container.obj.PerformanceResult in project shifu by ShifuML.

In class GainChart, method generateHtml4PrAndRoc.

public void generateHtml4PrAndRoc(EvalConfig evalConfig, ModelConfig modelConfig, String fileName, List<PerformanceResult> results, List<String> names) throws IOException {
    BufferedWriter writer = null;
    try {
        writer = ShifuFileUtils.getWriter(fileName, SourceType.LOCAL);
        writer.write(GainChartTemplate.HIGHCHART_BASE_BEGIN);
        writer.write(String.format(GainChartTemplate.HIGHCHART_BUTTON_PANEL_TEMPLATE_1, "Weighted PR Curve", "lst0", "Weighted Precision", "lst1", "Unit-wise Precision"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_BUTTON_PANEL_TEMPLATE_2, "Unit-wise PR Curve", "lst2", "Weighted Precision", "lst3", "Unit-wise Precision"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_BUTTON_PANEL_TEMPLATE_1, "Weighted ROC Curve", "lst4", "Weighted Recall", "lst5", "Unit-wise Recall"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_BUTTON_PANEL_TEMPLATE_2, "Unit-wise ROC Curve", "lst6", "Weighted Recall", "lst7", "Unit-wise Recall"));
        writer.write("      </div>\n");
        writer.write("      <div class=\"col-sm-9 col-sm-offset-3 col-md-10 col-md-offset-2 main\">\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_DIV, "container0"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_DIV, "container1"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_DIV, "container2"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_DIV, "container3"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_DIV, "container4"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_DIV, "container5"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_DIV, "container6"));
        writer.write(String.format(GainChartTemplate.HIGHCHART_DIV, "container7"));
        writer.write("<script>\n");
        writer.write("\n");
        for (int j = 0; j < results.size(); j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + j + " = [\n");
            for (int i = 0; i < result.weightedPr.size(); i++) {
                PerformanceObject po = result.weightedPr.get(i);
                writer.write(String.format(GainChartTemplate.PRROC_DATA_FORMAT, GainChartTemplate.DF.format(po.weightedPrecision * 100), GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.weightedPrecision * 100), GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.weightedFpr * 100), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.weightedPr.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        for (int j = 0; j < results.size(); j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (results.size() + j) + " = [\n");
            for (int i = 0; i < result.weightedPr.size(); i++) {
                PerformanceObject po = result.weightedPr.get(i);
                writer.write(String.format(GainChartTemplate.PRROC_DATA_FORMAT, GainChartTemplate.DF.format(po.precision * 100), GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.precision * 100), GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.weightedFpr * 100), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.weightedPr.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        for (int j = 0; j < results.size(); j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (2 * results.size() + j) + " = [\n");
            for (int i = 0; i < result.pr.size(); i++) {
                PerformanceObject po = result.pr.get(i);
                writer.write(String.format(GainChartTemplate.PRROC_DATA_FORMAT, GainChartTemplate.DF.format(po.weightedPrecision * 100), GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.weightedPrecision * 100), GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.fpr * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.pr.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        for (int j = 0; j < results.size(); j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (3 * results.size() + j) + " = [\n");
            for (int i = 0; i < result.pr.size(); i++) {
                PerformanceObject po = result.pr.get(i);
                writer.write(String.format(GainChartTemplate.PRROC_DATA_FORMAT, GainChartTemplate.DF.format(po.precision * 100), GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.precision * 100), GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.fpr * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.pr.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        for (int j = 0; j < results.size(); j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (4 * results.size() + j) + " = [\n");
            for (int i = 0; i < result.weightedRoc.size(); i++) {
                PerformanceObject po = result.weightedRoc.get(i);
                writer.write(String.format(GainChartTemplate.PRROC_DATA_FORMAT, GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.weightedFpr * 100), GainChartTemplate.DF.format(po.weightedPrecision * 100), GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.weightedFpr * 100), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.weightedRoc.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        for (int j = 0; j < results.size(); j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (5 * results.size() + j) + " = [\n");
            for (int i = 0; i < result.weightedRoc.size(); i++) {
                PerformanceObject po = result.weightedRoc.get(i);
                writer.write(String.format(GainChartTemplate.PRROC_DATA_FORMAT, GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.weightedFpr * 100), GainChartTemplate.DF.format(po.weightedPrecision * 100), GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.weightedFpr * 100), GainChartTemplate.DF.format(po.weightedActionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.weightedRoc.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        for (int j = 0; j < results.size(); j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (6 * results.size() + j) + " = [\n");
            for (int i = 0; i < result.roc.size(); i++) {
                PerformanceObject po = result.roc.get(i);
                writer.write(String.format(GainChartTemplate.PRROC_DATA_FORMAT, GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.fpr * 100), GainChartTemplate.DF.format(po.precision * 100), GainChartTemplate.DF.format(po.weightedRecall * 100), GainChartTemplate.DF.format(po.fpr * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.roc.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        for (int j = 0; j < results.size(); j++) {
            PerformanceResult result = results.get(j);
            writer.write("  var data_" + (7 * results.size() + j) + " = [\n");
            for (int i = 0; i < result.roc.size(); i++) {
                PerformanceObject po = result.roc.get(i);
                writer.write(String.format(GainChartTemplate.PRROC_DATA_FORMAT, GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.fpr * 100), GainChartTemplate.DF.format(po.precision * 100), GainChartTemplate.DF.format(po.recall * 100), GainChartTemplate.DF.format(po.fpr * 100), GainChartTemplate.DF.format(po.actionRate * 100), GainChartTemplate.DF.format(po.binLowestScore)));
                if (i != result.roc.size() - 1) {
                    writer.write(",");
                }
            }
            writer.write("  ];\n");
            writer.write("\n");
        }
        writer.write("$(function () {\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX3, "container0", "Weighted Recall - Weighted Precision (PR Curve)", modelConfig.getBasic().getName(), "Weighte Precision", "Weighted Recall", "%", "false"));
        int currIndex = 0;
        writer.write("series: [");
        for (int i = 0; i < results.size(); i++) {
            writer.write("{");
            writer.write("  data: data_" + (currIndex++) + ",");
            writer.write("  name: '" + names.get(i) + "',");
            writer.write("  turboThreshold:0");
            writer.write("}");
            if (i != results.size() - 1) {
                writer.write(",");
            }
        }
        writer.write("]");
        writer.write("});");
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX3, "container1", "Weighted Recall - Unit-wise Precision (PR Curve)", modelConfig.getBasic().getName(), "Unit-wise Precision", "Weighted Recall", "%", "false"));
        writer.write("series: [");
        for (int i = 0; i < results.size(); i++) {
            writer.write("{");
            writer.write("  data: data_" + (currIndex++) + ",");
            writer.write("  name: '" + names.get(i) + "',");
            writer.write("  turboThreshold:0");
            writer.write("}");
            if (i != results.size() - 1) {
                writer.write(",");
            }
        }
        writer.write("]");
        writer.write("});");
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX3, "container2", "Unit-wise Recall - Weighted Precision (PR Curve)", modelConfig.getBasic().getName(), "Weighted Precision", "Unit-wise Recall", "%", "false"));
        writer.write("series: [");
        for (int i = 0; i < results.size(); i++) {
            writer.write("{");
            writer.write("  data: data_" + (currIndex++) + ",");
            writer.write("  name: '" + names.get(i) + "',");
            writer.write("  turboThreshold:0");
            writer.write("}");
            if (i != results.size() - 1) {
                writer.write(",");
            }
        }
        writer.write("]");
        writer.write("});");
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX3, "container3", "Unit-wise Recall - Unit-wise Precision (PR Curve)", modelConfig.getBasic().getName(), "Unit-wise Precision", "Unit-wise Recall", "%", "false"));
        writer.write("series: [");
        for (int i = 0; i < results.size(); i++) {
            writer.write("{");
            writer.write("  data: data_" + (currIndex++) + ",");
            writer.write("  name: '" + names.get(i) + "',");
            writer.write("  turboThreshold:0");
            writer.write("}");
            if (i != results.size() - 1) {
                writer.write(",");
            }
        }
        writer.write("]");
        writer.write("});");
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX3, "container4", "Weighted FPR - Weighted Recall (ROC Curve)", modelConfig.getBasic().getName(), "Weighted Recall", "Weighted FPR", "%", "false"));
        writer.write("series: [");
        for (int i = 0; i < results.size(); i++) {
            writer.write("{");
            writer.write("  data: data_" + (currIndex++) + ",");
            writer.write("  name: '" + names.get(i) + "',");
            writer.write("  turboThreshold:0");
            writer.write("}");
            if (i != results.size() - 1) {
                writer.write(",");
            }
        }
        writer.write("]");
        writer.write("});");
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX3, "container5", "Weighted FPR - Unit-wise Recall (ROC Curve)", modelConfig.getBasic().getName(), "Unit-wise Recall", "Weighted FPR", "%", "false"));
        writer.write("series: [");
        for (int i = 0; i < results.size(); i++) {
            writer.write("{");
            writer.write("  data: data_" + (currIndex++) + ",");
            writer.write("  name: '" + names.get(i) + "',");
            writer.write("  turboThreshold:0");
            writer.write("}");
            if (i != results.size() - 1) {
                writer.write(",");
            }
        }
        writer.write("]");
        writer.write("});");
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX3, "container6", "Unit-wise FPR - Weighted Recall (ROC Curve)", modelConfig.getBasic().getName(), "Weighted Recall", "Unit-wise FPR", "%", "false"));
        writer.write("series: [");
        for (int i = 0; i < results.size(); i++) {
            writer.write("{");
            writer.write("  data: data_" + (currIndex++) + ",");
            writer.write("  name: '" + names.get(i) + "',");
            writer.write("  turboThreshold:0");
            writer.write("}");
            if (i != results.size() - 1) {
                writer.write(",");
            }
        }
        writer.write("]");
        writer.write("});");
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_CHART_TEMPLATE_PREFIX3, "container7", "Unit-wise FPR - Unit-wise Recall (ROC Curve)", modelConfig.getBasic().getName(), "Unit-wise Recall", "Unit-wise FPR", "%", "false"));
        writer.write("series: [");
        for (int i = 0; i < results.size(); i++) {
            writer.write("{");
            writer.write("  data: data_" + (currIndex++) + ",");
            writer.write("  name: '" + names.get(i) + "',");
            writer.write("  turboThreshold:0");
            writer.write("}");
            if (i != results.size() - 1) {
                writer.write(",");
            }
        }
        writer.write("]");
        writer.write("});");
        writer.write("\n");
        writer.write("});\n");
        writer.write("\n");
        writer.write("$(document).ready(function() {\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_LIST_TOGGLE_TEMPLATE, "lst0", "container0", "lst0"));
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_LIST_TOGGLE_TEMPLATE, "lst1", "container1", "lst1"));
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_LIST_TOGGLE_TEMPLATE, "lst2", "container2", "lst2"));
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_LIST_TOGGLE_TEMPLATE, "lst3", "container3", "lst3"));
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_LIST_TOGGLE_TEMPLATE, "lst4", "container4", "lst4"));
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_LIST_TOGGLE_TEMPLATE, "lst5", "container5", "lst5"));
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_LIST_TOGGLE_TEMPLATE, "lst6", "container6", "lst6"));
        writer.write("\n");
        writer.write(String.format(GainChartTemplate.HIGHCHART_LIST_TOGGLE_TEMPLATE, "lst7", "container7", "lst7"));
        writer.write("\n");
        writer.write("\n");
        writer.write("  var ics = ['#container1','#container2', '#container5','#container6'];\n");
        writer.write("  var icl = ics.length;\n");
        writer.write("  for (var i = 0; i < icl; i++) {\n");
        writer.write("      $(ics[i]).toggleClass('show');\n");
        writer.write("      $(ics[i]).toggleClass('hidden');\n");
        writer.write("      $(ics[i]).toggleClass('ls_chosen');\n");
        writer.write("  };\n");
        writer.write("\n");
        writer.write("});\n");
        writer.write("\n");
        writer.write("</script>\n");
        writer.write(GainChartTemplate.HIGHCHART_BASE_END);
    } finally {
        if (writer != null) {
            writer.close();
        }
    }
}
Also used : PerformanceResult(ml.shifu.shifu.container.obj.PerformanceResult) PerformanceObject(ml.shifu.shifu.container.PerformanceObject) BufferedWriter(java.io.BufferedWriter)

Example 4 with PerformanceResult

use of ml.shifu.shifu.container.obj.PerformanceResult in project shifu by ShifuML.

the class ConfusionMatrix method bufferedComputeConfusionMatrixAndPerformance.

/**
 * Streams the score file(s) under {@code scoreDataPath}, accumulates a running confusion matrix record by record and
 * samples it into bucketed performance lists (FPR, catch rate, action rate, model score, plus weighted FPR, catch
 * rate and action rate). The assembled {@link PerformanceResult} is written to {@code evalPerformancePath}, and is
 * optionally logged and charted.
 *
 * <p>NOTE(review): every record is moved from FN/TN into TP/FP as it is read, i.e. a threshold sweep; this presumably
 * relies on the score data being sorted by descending score upstream — confirm.
 *
 * @param pigPosTags
 *            count of positive-tag records; seeds the initial confusion matrix and is part of the action-rate
 *            denominator
 * @param pigNegTags
 *            count of negative-tag records; seeds the initial confusion matrix and is part of the action-rate
 *            denominator
 * @param pigPosWeightTags
 *            weighted positive count used to seed the initial confusion matrix
 * @param pigNegWeightTags
 *            weighted negative count used to seed the initial confusion matrix
 * @param records
 *            not read by this method
 * @param maxPScore
 *            raw max score; used as the score range when {@code isUseMaxMinScore} is set (and GBT conversion is
 *            off), and as the divisor of the GBT rescaling strategies
 * @param minPScore
 *            raw min score; counterpart of {@code maxPScore}
 * @param scoreDataPath
 *            path of the score data to scan
 * @param evalPerformancePath
 *            where the performance result is written
 * @param isPrint
 *            whether bucket tables and AUC values are logged
 * @param isGenerateChart
 *            whether chart/json performance files are generated
 * @param targetColumnIndex
 *            column holding the tag
 * @param scoreColumnIndex
 *            column holding the score
 * @param weightColumnIndex
 *            column holding the weight; only values {@code > 0} enable weight parsing
 * @param isUseMaxMinScore
 *            whether to use [{@code minPScore}, {@code maxPScore}] as the score range
 * @return the bucketed performance result
 * @throws IOException
 *             when reading the score data fails
 * @throws ShifuException
 *             when no line at all was read from the score data
 */
public PerformanceResult bufferedComputeConfusionMatrixAndPerformance(long pigPosTags, long pigNegTags, double pigPosWeightTags, double pigNegWeightTags, long records, double maxPScore, double minPScore, String scoreDataPath, String evalPerformancePath, boolean isPrint, boolean isGenerateChart, int targetColumnIndex, int scoreColumnIndex, int weightColumnIndex, boolean isUseMaxMinScore) throws IOException {
    // 1. compute maxScore and minScore in case some scores are not in [0, 1]. When GBT scores are converted to
    // [0, 1] the default range is kept; otherwise the observed raw range is used only if requested.
    double maxScore = 1d * scoreScale, minScore = 0d;
    if (!isGBTNeedConvertScore() && isUseMaxMinScore) {
        // TODO some cases maxPScore is already scaled, how to fix that issue
        maxScore = maxPScore;
        minScore = minPScore;
    }
    LOG.info("{} Transformed (scale included) max score is {}, transformed min score is {}", evalConfig.getGbtScoreConvertStrategy(), maxScore, minScore);
    SourceType sourceType = evalConfig.getDataSet().getSource();
    List<Scanner> scanners = ShifuFileUtils.getDataScanners(scoreDataPath, sourceType);
    LOG.info("Number of score files is {} in eval {}.", scanners.size(), evalConfig.getName());
    int numBucket = evalConfig.getPerformanceBucketNum();
    boolean hasWeight = StringUtils.isNotBlank(evalConfig.getDataSet().getWeightColumnName());
    // NOTE(review): isDir is derived from the eval score path rather than scoreDataPath; confirm this is intended
    // when the method is called for meta (champion) score paths.
    boolean isDir = ShifuFileUtils.isDir(pathFinder.getEvalScorePath(evalConfig, sourceType), sourceType);
    List<PerformanceObject> FPRList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> catchRateList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> gainList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> modelScoreList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> FPRWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> catchRateWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    List<PerformanceObject> gainWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
    double binScore = (maxScore - minScore) * 1d / numBucket, binCapacity = 1.0 / numBucket, scoreBinCount = 0, scoreBinWeigthedCount = 0;
    int fpBin = 1, tpBin = 1, gainBin = 1, fpWeightBin = 1, tpWeightBin = 1, gainWeightBin = 1, modelScoreBin = 1;
    long index = 0, cnt = 0, invalidTargetCnt = 0, invalidWgtCnt = 0;
    ConfusionMatrixObject prevCmo = buildInitalCmo(pigPosTags, pigNegTags, pigPosWeightTags, pigNegWeightTags, maxScore);
    // every list starts with the same initial point
    PerformanceObject po = buildFirstPO(prevCmo);
    FPRList.add(po);
    catchRateList.add(po);
    gainList.add(po);
    FPRWeightList.add(po);
    catchRateWeightList.add(po);
    gainWeightList.add(po);
    modelScoreList.add(po);
    boolean isGBTScoreHalfCutoffStreategy = isGBTScoreHalfCutoffStreategy();
    boolean isGBTScoreMaxMinScaleStreategy = isGBTScoreMaxMinScaleStreategy();
    Splitter splitter = Splitter.on(delimiter).trimResults();
    try {
        for (Scanner scanner : scanners) {
            while (scanner.hasNext()) {
                if ((++cnt) % 100000L == 0L) {
                    LOG.info("Loaded {} records.", cnt);
                }
                if ((!isDir) && cnt == 1) {
                    // if the evaluation score file is the local file, skip the first line since we add a header
                    // there; the line must actually be consumed, otherwise it is parsed as a record next iteration
                    scanner.nextLine();
                    continue;
                }
                // score is separated by default delimiter in our pig output format
                String[] raw = Lists.newArrayList(splitter.split(scanner.nextLine())).toArray(new String[0]);
                // tag check
                String tag = raw[targetColumnIndex];
                if (StringUtils.isBlank(tag) || (!posTags.contains(tag) && !negTags.contains(tag))) {
                    invalidTargetCnt += 1;
                    continue;
                }
                double weight = 1d;
                // if has weight
                if (weightColumnIndex > 0) {
                    try {
                        weight = Double.parseDouble(raw[weightColumnIndex]);
                    } catch (NumberFormatException e) {
                        // unparsable weight: count it and keep the default weight of 1
                        invalidWgtCnt += 1;
                    }
                    if (weight < 0d) {
                        invalidWgtCnt += 1;
                        weight = 1d;
                    }
                }
                double score = 0.0;
                try {
                    score = Double.parseDouble(raw[scoreColumnIndex]);
                } catch (NumberFormatException e) {
                    // user set the score column wrong ? log only a ~5% sample to avoid flooding the log
                    if (Math.random() < 0.05) {
                        LOG.warn("The score column - {} is not number. Is score column set correctly?", raw[scoreColumnIndex]);
                    }
                    continue;
                }
                scoreBinCount += 1;
                scoreBinWeigthedCount += weight;
                ConfusionMatrixObject cmo = new ConfusionMatrixObject(prevCmo);
                if (posTags.contains(tag)) {
                    // Positive Instance
                    cmo.setTp(cmo.getTp() + 1);
                    cmo.setFn(cmo.getFn() - 1);
                    cmo.setWeightedTp(cmo.getWeightedTp() + weight * 1.0);
                    cmo.setWeightedFn(cmo.getWeightedFn() - weight * 1.0);
                } else {
                    // Negative Instance
                    cmo.setFp(cmo.getFp() + 1);
                    cmo.setTn(cmo.getTn() - 1);
                    cmo.setWeightedFp(cmo.getWeightedFp() + weight * 1.0);
                    cmo.setWeightedTn(cmo.getWeightedTn() - weight * 1.0);
                }
                if (isGBTScoreHalfCutoffStreategy) {
                    // clamp negatives to 0, then rescale by the raw max score
                    if (score < 0d) {
                        score = 0d;
                    }
                    score = ((score - 0) * scoreScale) / (maxPScore - 0);
                } else if (isGBTScoreMaxMinScaleStreategy) {
                    // use max min scaler to make score in [0, 1], don't foget to time scoreScale
                    score = ((score - minPScore) * scoreScale) / (maxPScore - minPScore);
                }
                // otherwise the current score is used as-is
                cmo.setScore(Double.parseDouble(SCORE_FORMAT.format(score)));
                po = PerformanceEvaluator.setPerformanceObject(cmo);
                if (po.fpr >= fpBin * binCapacity) {
                    po.binNum = fpBin++;
                    FPRList.add(po);
                }
                if (po.recall >= tpBin * binCapacity) {
                    po.binNum = tpBin++;
                    catchRateList.add(po);
                }
                // prevent 99%
                double validRecordCnt = (double) (index + 1);
                if (validRecordCnt / (pigPosTags + pigNegTags) >= gainBin * binCapacity) {
                    po.binNum = gainBin++;
                    gainList.add(po);
                }
                if (po.weightedFpr >= fpWeightBin * binCapacity) {
                    po.binNum = fpWeightBin++;
                    FPRWeightList.add(po);
                }
                if (po.weightedRecall >= tpWeightBin * binCapacity) {
                    po.binNum = tpWeightBin++;
                    catchRateWeightList.add(po);
                }
                if ((cmo.getWeightedTp() + cmo.getWeightedFp()) / cmo.getWeightedTotal() >= gainWeightBin * binCapacity) {
                    po.binNum = gainWeightBin++;
                    gainWeightList.add(po);
                }
                if ((maxScore - (modelScoreBin * binScore)) >= score) {
                    po.binNum = modelScoreBin++;
                    po.scoreCount = scoreBinCount;
                    po.scoreWgtCount = scoreBinWeigthedCount;
                    // reset to 0 for next bin score cnt stats
                    scoreBinCount = scoreBinWeigthedCount = 0;
                    modelScoreList.add(po);
                }
                index += 1;
                prevCmo = cmo;
            }
            // release each file as soon as it is fully consumed
            scanner.close();
        }
    } finally {
        // make sure no scanner leaks if parsing fails part-way; Scanner.close() on a closed scanner has no effect
        for (Scanner scanner : scanners) {
            scanner.close();
        }
    }
    LOG.info("Totally loading {} records with invalid target records {} and invalid weight records {} in eval {}.", cnt, invalidTargetCnt, invalidWgtCnt, evalConfig.getName());
    // fail before writing a meaningless performance file / chart when nothing was read
    if (cnt == 0) {
        LOG.error("No score read, the EvalScore did not genernate or is null file");
        throw new ShifuException(ShifuErrorCode.ERROR_EVALSCORE);
    }
    PerformanceResult result = buildPerfResult(FPRList, catchRateList, gainList, modelScoreList, FPRWeightList, catchRateWeightList, gainWeightList);
    synchronized (this.lock) {
        if (isPrint) {
            PerformanceEvaluator.logResult(FPRList, "Bucketing False Positive Rate");
            if (hasWeight) {
                PerformanceEvaluator.logResult(FPRWeightList, "Bucketing Weighted False Positive Rate");
            }
            PerformanceEvaluator.logResult(catchRateList, "Bucketing Catch Rate");
            if (hasWeight) {
                PerformanceEvaluator.logResult(catchRateWeightList, "Bucketing Weighted Catch Rate");
            }
            PerformanceEvaluator.logResult(gainList, "Bucketing Action Rate");
            if (hasWeight) {
                PerformanceEvaluator.logResult(gainWeightList, "Bucketing Weighted Action Rate");
            }
            PerformanceEvaluator.logAucResult(result, hasWeight);
        }
        writePerResult2File(evalPerformancePath, result);
        if (isGenerateChart) {
            generateChartAndJsonPerfFiles(hasWeight, result);
        }
    }
    return result;
}
Also used : Scanner(java.util.Scanner) Splitter(com.google.common.base.Splitter) PerformanceObject(ml.shifu.shifu.container.PerformanceObject) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) ArrayList(java.util.ArrayList) ConfusionMatrixObject(ml.shifu.shifu.container.ConfusionMatrixObject) PerformanceResult(ml.shifu.shifu.container.obj.PerformanceResult) ShifuException(ml.shifu.shifu.exception.ShifuException)

Example 5 with PerformanceResult

use of ml.shifu.shifu.container.obj.PerformanceResult in project shifu by ShifuML.

the class ConfusionMatrix method buildPerfResult.

/**
 * Assembles a {@link PerformanceResult} from the bucketed performance lists and computes the four area-under-curve
 * values (ROC, weighted ROC, PR, weighted PR) from them. The lists are stored by reference, not copied.
 *
 * @param FPRList
 *            bucketed list stored as the ROC curve
 * @param catchRateList
 *            bucketed list stored as the PR curve
 * @param gainList
 *            bucketed list stored as the gain curve
 * @param modelScoreList
 *            bucketed model-score list
 * @param FPRWeightList
 *            bucketed list stored as the weighted ROC curve
 * @param catchRateWeightList
 *            bucketed list stored as the weighted PR curve
 * @param gainWeightList
 *            bucketed list stored as the weighted gain curve
 * @return the populated result, stamped with {@code Constants.version}
 */
private PerformanceResult buildPerfResult(List<PerformanceObject> FPRList, List<PerformanceObject> catchRateList, List<PerformanceObject> gainList, List<PerformanceObject> modelScoreList, List<PerformanceObject> FPRWeightList, List<PerformanceObject> catchRateWeightList, List<PerformanceObject> gainWeightList) {
    final PerformanceResult perf = new PerformanceResult();
    perf.version = Constants.version;

    // unweighted curves
    perf.roc = FPRList;
    perf.pr = catchRateList;
    perf.gains = gainList;
    perf.modelScoreList = modelScoreList;

    // weighted curves
    perf.weightedRoc = FPRWeightList;
    perf.weightedPr = catchRateWeightList;
    perf.weightedGains = gainWeightList;

    // areas under the curves, derived from the lists just attached
    perf.areaUnderRoc = AreaUnderCurve.ofRoc(FPRList);
    perf.weightedAreaUnderRoc = AreaUnderCurve.ofWeightedRoc(FPRWeightList);
    perf.areaUnderPr = AreaUnderCurve.ofPr(catchRateList);
    perf.weightedAreaUnderPr = AreaUnderCurve.ofWeightedPr(catchRateWeightList);
    return perf;
}
Also used : PerformanceResult(ml.shifu.shifu.container.obj.PerformanceResult)

Aggregations

PerformanceResult (ml.shifu.shifu.container.obj.PerformanceResult)7 PerformanceObject (ml.shifu.shifu.container.PerformanceObject)4 ArrayList (java.util.ArrayList)3 BufferedWriter (java.io.BufferedWriter)2 ConfusionMatrixObject (ml.shifu.shifu.container.ConfusionMatrixObject)2 Splitter (com.google.common.base.Splitter)1 IOException (java.io.IOException)1 Writer (java.io.Writer)1 Scanner (java.util.Scanner)1 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)1 GainChart (ml.shifu.shifu.core.eval.GainChart)1 ShifuException (ml.shifu.shifu.exception.ShifuException)1 PathFinder (ml.shifu.shifu.fs.PathFinder)1