Search in sources:

Example 1 with GainChart

Use of ml.shifu.shifu.core.eval.GainChart in project shifu by ShifuML.

In class ConfusionMatrix, method generateChartAndJsonPerfFiles.

/**
 * Generates all local performance artifacts for one evaluation result.
 *
 * <p>The two HTML pages (gain chart, PR &amp; ROC) are written first. The CSV exports follow in this order:
 * unit-wise gain chart, weighted gain chart, unit-wise PR, weighted PR, unit-wise ROC, weighted ROC, and
 * model-score gain chart. The weighted files are produced only when {@code hasWeight} is true. Every output
 * path is resolved through {@code pathFinder} as a {@code SourceType.LOCAL} file named after the evaluation set.
 *
 * @param hasWeight
 *            whether the weighted gain-chart / PR / ROC CSV files should be written as well
 * @param result
 *            performance data rendered into the HTML pages and CSV files
 * @throws IOException
 *             if any of the path lookups or {@link GainChart} generation calls made here throws it
 */
private void generateChartAndJsonPerfFiles(boolean hasWeight, PerformanceResult result) throws IOException {
    final String evalName = evalConfig.getName();
    final GainChart chart = new GainChart();

    // --- HTML reports ---
    final String gainChartHtml = pathFinder.getEvalFilePath(evalName, evalName + "_gainchart.html",
            SourceType.LOCAL);
    LOG.info("Gain chart is generated in {}.", gainChartHtml);
    chart.generateHtml(evalConfig, modelConfig, gainChartHtml, result);

    final String prRocHtml = pathFinder.getEvalFilePath(evalName, evalName + "_prroc.html", SourceType.LOCAL);
    LOG.info("PR&ROC chart is generated in {}.", prRocHtml);
    chart.generateHtml4PrAndRoc(evalConfig, modelConfig, prRocHtml, result);

    // --- Gain chart CSVs ---
    final String gainCsv = pathFinder.getEvalFilePath(evalName, evalName + "_unit_wise_gainchart.csv",
            SourceType.LOCAL);
    LOG.info("Unit-wise gain chart data is generated in {}.", gainCsv);
    chart.generateCsv(evalConfig, modelConfig, gainCsv, result.gains);
    if (hasWeight) {
        final String weightedGainCsv = pathFinder.getEvalFilePath(evalName, evalName + "_weighted_gainchart.csv",
                SourceType.LOCAL);
        LOG.info("Weighted gain chart data is generated in {}.", weightedGainCsv);
        chart.generateCsv(evalConfig, modelConfig, weightedGainCsv, result.weightedGains);
    }

    // --- Precision/recall CSVs ---
    final String prCsv = pathFinder.getEvalFilePath(evalName, evalName + "_unit_wise_pr.csv", SourceType.LOCAL);
    LOG.info("Unit-wise pr data is generated in {}.", prCsv);
    chart.generateCsv(evalConfig, modelConfig, prCsv, result.pr);
    if (hasWeight) {
        final String weightedPrCsv = pathFinder.getEvalFilePath(evalName, evalName + "_weighted_pr.csv",
                SourceType.LOCAL);
        LOG.info("Weighted pr data is generated in {}.", weightedPrCsv);
        chart.generateCsv(evalConfig, modelConfig, weightedPrCsv, result.weightedPr);
    }

    // --- ROC CSVs ---
    final String rocCsv = pathFinder.getEvalFilePath(evalName, evalName + "_unit_wise_roc.csv", SourceType.LOCAL);
    LOG.info("Unit-wise roc data is generated in {}.", rocCsv);
    chart.generateCsv(evalConfig, modelConfig, rocCsv, result.roc);
    if (hasWeight) {
        final String weightedRocCsv = pathFinder.getEvalFilePath(evalName, evalName + "_weighted_roc.csv",
                SourceType.LOCAL);
        LOG.info("Weighted roc data is generated in {}.", weightedRocCsv);
        chart.generateCsv(evalConfig, modelConfig, weightedRocCsv, result.weightedRoc);
    }

    // --- Model score CSV ---
    final String modelScoreCsv = pathFinder.getEvalFilePath(evalName, evalName + "_modelscore_gainchart.csv",
            SourceType.LOCAL);
    LOG.info("Model score gain chart data is generated in {}.", modelScoreCsv);
    chart.generateCsv(evalConfig, modelConfig, modelScoreCsv, result.modelScoreList);
}
Also used: GainChart (ml.shifu.shifu.core.eval.GainChart)

Example 2 with GainChart

Use of ml.shifu.shifu.core.eval.GainChart in project shifu by ShifuML.

In class EvalModelProcessor, method runDistEval.

/**
 * Run distributed version of evaluation and performance review.
 *
 * <p>The method falls back to evaluating only the challenge model, via {@code runConfusionMatrix}, and
 * returns when either:
 * <ul>
 * <li>no champion (meta) score column is configured, or</li>
 * <li>{@code modelConfig.isRegression()} is false.</li>
 * </ul>
 * Otherwise it evaluates the challenge model plus one champion model per non-blank meta score column. It then
 * writes the combined gain-chart and PR/ROC HTML pages and a set of CSV files per model to local paths.
 *
 * @param evalConfig
 *            the evaluation instance
 * @throws IOException
 *             propagated from the distributed scoring, performance computation or local chart/CSV generation
 *             steps invoked here
 */
private void runDistEval(EvalConfig evalConfig) throws IOException {
    ScoreStatus ss = runDistScore(evalConfig);
    List<String> scoreMetaColumns = evalConfig.getScoreMetaColumns(modelConfig);
    if (scoreMetaColumns == null || scoreMetaColumns.isEmpty() || !modelConfig.isRegression()) {
        // No champion score column configured (or not a regression model): fall back to the previous
        // evaluation flow, which covers the challenge model only.
        runConfusionMatrix(evalConfig, ss, isGBTNotConvertToProb(evalConfig));
        return;
    }
    // 1. Get challenge model performance
    PerformanceResult challengeModelPerformance = runConfusionMatrix(evalConfig, ss,
            pathFinder.getEvalScorePath(evalConfig), pathFinder.getEvalPerformancePath(evalConfig), false, false,
            isGBTNotConvertToProb(evalConfig));
    List<PerformanceResult> prList = new ArrayList<PerformanceResult>();
    prList.add(challengeModelPerformance);
    // 2. Get all champion model performance. names.get(i) labels prList.get(i), so both lists are always
    // appended together and stay index-aligned.
    List<String> names = new ArrayList<String>();
    names.add(modelConfig.getBasic().getName() + "-" + evalConfig.getName());
    for (String metaScoreColumn : scoreMetaColumns) {
        if (StringUtils.isBlank(metaScoreColumn)) {
            continue;
        }
        names.add(metaScoreColumn);
        LOG.info("Model score sort for {} in eval {} is started.", metaScoreColumn, evalConfig.getName());
        ScoreStatus newScoreStatus = runDistMetaScore(evalConfig, metaScoreColumn);
        // NOTE(review): the meaning of the trailing 0, 1, 2 arguments is defined by runConfusionMatrix -- confirm
        // there before changing them.
        PerformanceResult championModelPerformance = runConfusionMatrix(evalConfig, newScoreStatus,
                pathFinder.getEvalMetaScorePath(evalConfig, metaScoreColumn),
                pathFinder.getEvalMetaPerformancePath(evalConfig, metaScoreColumn), false, false, 0, 1, 2);
        prList.add(championModelPerformance);
    }
    // All local chart/CSV output below is produced while holding this instance's monitor.
    synchronized (this) {
        GainChart gc = new GainChart();
        boolean hasWeight = StringUtils.isNotBlank(evalConfig.getDataSet().getWeightColumnName());
        // 3. Compute gain chart and other eval performance files only in local.
        String htmlGainChart = pathFinder.getEvalFilePath(evalConfig.getName(), evalConfig.getName()
                + "_gainchart.html", SourceType.LOCAL);
        LOG.info("Gain chart is generated in {}.", htmlGainChart);
        gc.generateHtml(evalConfig, modelConfig, htmlGainChart, prList, names);
        String htmlPrRoc = pathFinder.getEvalFilePath(evalConfig.getName(), evalConfig.getName() + "_prroc.html",
                SourceType.LOCAL);
        LOG.info("PR & ROC chart is generated in {}.", htmlPrRoc);
        gc.generateHtml4PrAndRoc(evalConfig, modelConfig, htmlPrRoc, prList, names);
        for (int i = 0; i < names.size(); i++) {
            writePerModelEvalCsvs(gc, evalConfig, names.get(i), prList.get(i), hasWeight);
        }
        LOG.info("Performance Evaluation is done for {}.", evalConfig.getName());
    }
}

/**
 * Writes the CSV exports for one named model result of {@code evalConfig}, in this order:
 * <ol>
 * <li>unit-wise gain chart, then weighted gain chart if {@code hasWeight}</li>
 * <li>unit-wise PR, then weighted PR if {@code hasWeight}</li>
 * <li>unit-wise ROC, then weighted ROC if {@code hasWeight}</li>
 * <li>model-score gain chart</li>
 * </ol>
 * Each file is resolved as a local eval file prefixed with {@code name}.
 *
 * @param gc
 *            chart generator used to write every CSV
 * @param evalConfig
 *            the evaluation instance the files belong to
 * @param name
 *            the model label used as the file-name prefix
 * @param pr
 *            the performance result whose data is exported
 * @param hasWeight
 *            whether the weighted CSV variants should also be written
 * @throws IOException
 *             if any of the path lookups or CSV generation calls made here throws it
 */
private void writePerModelEvalCsvs(GainChart gc, EvalConfig evalConfig, String name, PerformanceResult pr,
        boolean hasWeight) throws IOException {
    String evalName = evalConfig.getName();
    String unitGainChartCsv = pathFinder.getEvalFilePath(evalName, name + "_unit_wise_gainchart.csv",
            SourceType.LOCAL);
    LOG.info("Unit-wise gain chart data is generated in {} for eval {} and name {}.", unitGainChartCsv, evalName,
            name);
    gc.generateCsv(evalConfig, modelConfig, unitGainChartCsv, pr.gains);
    if (hasWeight) {
        String weightedGainChartCsv = pathFinder.getEvalFilePath(evalName, name + "_weighted_gainchart.csv",
                SourceType.LOCAL);
        LOG.info("Weighted gain chart data is generated in {} for eval {} and name {}.", weightedGainChartCsv,
                evalName, name);
        gc.generateCsv(evalConfig, modelConfig, weightedGainChartCsv, pr.weightedGains);
    }
    String prCsvFile = pathFinder.getEvalFilePath(evalName, name + "_unit_wise_pr.csv", SourceType.LOCAL);
    LOG.info("Unit-wise pr data is generated in {} for eval {} and name {}.", prCsvFile, evalName, name);
    gc.generateCsv(evalConfig, modelConfig, prCsvFile, pr.pr);
    if (hasWeight) {
        String weightedPrCsvFile = pathFinder.getEvalFilePath(evalName, name + "_weighted_pr.csv", SourceType.LOCAL);
        LOG.info("Weighted pr data is generated in {} for eval {} and name {}.", weightedPrCsvFile, evalName, name);
        gc.generateCsv(evalConfig, modelConfig, weightedPrCsvFile, pr.weightedPr);
    }
    String rocCsvFile = pathFinder.getEvalFilePath(evalName, name + "_unit_wise_roc.csv", SourceType.LOCAL);
    LOG.info("Unit-wise roc data is generated in {} for eval {} and name {}.", rocCsvFile, evalName, name);
    gc.generateCsv(evalConfig, modelConfig, rocCsvFile, pr.roc);
    if (hasWeight) {
        String weightedRocCsvFile = pathFinder.getEvalFilePath(evalName, name + "_weighted_roc.csv",
                SourceType.LOCAL);
        LOG.info("Weighted roc data is generated in {} for eval {} and name {}.", weightedRocCsvFile, evalName,
                name);
        gc.generateCsv(evalConfig, modelConfig, weightedRocCsvFile, pr.weightedRoc);
    }
    String modelScoreGainChartCsv = pathFinder.getEvalFilePath(evalName, name + "_modelscore_gainchart.csv",
            SourceType.LOCAL);
    LOG.info("Model score gain chart data is generated in {} for eval {} and name {}.", modelScoreGainChartCsv,
            evalName, name);
    gc.generateCsv(evalConfig, modelConfig, modelScoreGainChartCsv, pr.modelScoreList);
}
Also used: PerformanceResult (ml.shifu.shifu.container.obj.PerformanceResult), ArrayList (java.util.ArrayList), GainChart (ml.shifu.shifu.core.eval.GainChart)

Aggregations

GainChart (ml.shifu.shifu.core.eval.GainChart): 2 uses; ArrayList (java.util.ArrayList): 1 use; PerformanceResult (ml.shifu.shifu.container.obj.PerformanceResult): 1 use