Use of ml.shifu.shifu.core.eval.GainChart in project shifu by ShifuML.
Class ConfusionMatrix, method generateChartAndJsonPerfFiles.
/**
 * Writes the local chart and CSV performance files for one performance result of the
 * current evaluation: the gain-chart HTML page, the PR/ROC HTML page, and one CSV per
 * metric list (gains, pr, roc, model score list). Every target path is resolved through
 * {@code pathFinder.getEvalFilePath} with {@code SourceType.LOCAL} and is logged right
 * before the corresponding {@code GainChart} call.
 *
 * @param hasWeight
 *            when {@code true}, the weighted gain, PR and ROC CSV files are written too
 * @param result
 *            the performance result whose metric lists are exported
 * @throws IOException
 *             declared by this method; presumably raised by the GainChart writers
 */
private void generateChartAndJsonPerfFiles(boolean hasWeight, PerformanceResult result) throws IOException {
    final GainChart chart = new GainChart();
    // Every output file lives under, and is prefixed with, the evaluation name.
    final String evalName = evalConfig.getName();

    // HTML pages.
    final String gainHtmlPath = pathFinder.getEvalFilePath(evalName, evalName + "_gainchart.html",
            SourceType.LOCAL);
    LOG.info("Gain chart is generated in {}.", gainHtmlPath);
    chart.generateHtml(evalConfig, modelConfig, gainHtmlPath, result);

    final String prRocHtmlPath = pathFinder.getEvalFilePath(evalName, evalName + "_prroc.html",
            SourceType.LOCAL);
    LOG.info("PR&ROC chart is generated in {}.", prRocHtmlPath);
    chart.generateHtml4PrAndRoc(evalConfig, modelConfig, prRocHtmlPath, result);

    // Gain chart CSVs.
    final String unitGainCsvPath = pathFinder.getEvalFilePath(evalName,
            evalName + "_unit_wise_gainchart.csv", SourceType.LOCAL);
    LOG.info("Unit-wise gain chart data is generated in {}.", unitGainCsvPath);
    chart.generateCsv(evalConfig, modelConfig, unitGainCsvPath, result.gains);
    if (hasWeight) {
        final String weightedGainCsvPath = pathFinder.getEvalFilePath(evalName,
                evalName + "_weighted_gainchart.csv", SourceType.LOCAL);
        LOG.info("Weighted gain chart data is generated in {}.", weightedGainCsvPath);
        chart.generateCsv(evalConfig, modelConfig, weightedGainCsvPath, result.weightedGains);
    }

    // Precision/recall CSVs.
    final String unitPrCsvPath = pathFinder.getEvalFilePath(evalName, evalName + "_unit_wise_pr.csv",
            SourceType.LOCAL);
    LOG.info("Unit-wise pr data is generated in {}.", unitPrCsvPath);
    chart.generateCsv(evalConfig, modelConfig, unitPrCsvPath, result.pr);
    if (hasWeight) {
        final String weightedPrCsvPath = pathFinder.getEvalFilePath(evalName,
                evalName + "_weighted_pr.csv", SourceType.LOCAL);
        LOG.info("Weighted pr data is generated in {}.", weightedPrCsvPath);
        chart.generateCsv(evalConfig, modelConfig, weightedPrCsvPath, result.weightedPr);
    }

    // ROC CSVs.
    final String unitRocCsvPath = pathFinder.getEvalFilePath(evalName, evalName + "_unit_wise_roc.csv",
            SourceType.LOCAL);
    LOG.info("Unit-wise roc data is generated in {}.", unitRocCsvPath);
    chart.generateCsv(evalConfig, modelConfig, unitRocCsvPath, result.roc);
    if (hasWeight) {
        final String weightedRocCsvPath = pathFinder.getEvalFilePath(evalName,
                evalName + "_weighted_roc.csv", SourceType.LOCAL);
        LOG.info("Weighted roc data is generated in {}.", weightedRocCsvPath);
        chart.generateCsv(evalConfig, modelConfig, weightedRocCsvPath, result.weightedRoc);
    }

    // Model-score based gain chart CSV (written regardless of hasWeight).
    final String scoreGainCsvPath = pathFinder.getEvalFilePath(evalName,
            evalName + "_modelscore_gainchart.csv", SourceType.LOCAL);
    LOG.info("Model score gain chart data is generated in {}.", scoreGainCsvPath);
    chart.generateCsv(evalConfig, modelConfig, scoreGainCsvPath, result.modelScoreList);
}
Use of ml.shifu.shifu.core.eval.GainChart in project shifu by ShifuML.
Class EvalModelProcessor, method runDistEval.
/**
 * Runs the distributed scoring for an evaluation and then produces its performance files.
 * <p>
 * When no score meta columns are configured, or the model is not a regression model, only
 * the confusion matrix for the model itself is run. Otherwise the performance of the model
 * and of every non-blank score meta column is collected, and combined HTML charts plus
 * per-name CSV files are written to the local file system while holding the lock on
 * {@code this}.
 *
 * @param evalConfig
 *            the evaluation instance
 * @throws IOException
 *             declared by this method; per the original note, e.g. on failures while
 *             deleting old tmp files
 */
private void runDistEval(EvalConfig evalConfig) throws IOException {
    ScoreStatus scoreStatus = runDistScore(evalConfig);
    List<String> scoreMetaColumns = evalConfig.getScoreMetaColumns(modelConfig);

    boolean hasChampionColumns = scoreMetaColumns != null && !scoreMetaColumns.isEmpty()
            && modelConfig.isRegression();
    if (!hasChampionColumns) {
        // No champion score column to compare against: fall back to the plain evaluation
        // of the challenge model only.
        runConfusionMatrix(evalConfig, scoreStatus, isGBTNotConvertToProb(evalConfig));
        return;
    }

    // 1. Performance of the challenge model.
    String challengeScorePath = pathFinder.getEvalScorePath(evalConfig);
    String challengePerfPath = pathFinder.getEvalPerformancePath(evalConfig);
    PerformanceResult challengePerformance = runConfusionMatrix(evalConfig, scoreStatus, challengeScorePath,
            challengePerfPath, false, false, isGBTNotConvertToProb(evalConfig));
    List<PerformanceResult> performances = new ArrayList<PerformanceResult>();
    performances.add(challengePerformance);

    // 2. Performance of every champion model; modelNames stays index-aligned with performances.
    List<String> modelNames = new ArrayList<String>();
    modelNames.add(modelConfig.getBasic().getName() + "-" + evalConfig.getName());
    for (String championColumn : scoreMetaColumns) {
        if (!StringUtils.isBlank(championColumn)) {
            modelNames.add(championColumn);
            LOG.info("Model score sort for {} in eval {} is started.", championColumn, evalConfig.getName());
            ScoreStatus championStatus = runDistMetaScore(evalConfig, championColumn);
            String championScorePath = pathFinder.getEvalMetaScorePath(evalConfig, championColumn);
            String championPerfPath = pathFinder.getEvalMetaPerformancePath(evalConfig, championColumn);
            // NOTE(review): the trailing 0, 1, 2 pick a different runConfusionMatrix overload
            // than the challenge call above; their meaning is not visible here - confirm.
            PerformanceResult championPerformance = runConfusionMatrix(evalConfig, championStatus,
                    championScorePath, championPerfPath, false, false, 0, 1, 2);
            performances.add(championPerformance);
        }
    }

    synchronized (this) {
        GainChart chart = new GainChart();
        boolean hasWeight = StringUtils.isNotBlank(evalConfig.getDataSet().getWeightColumnName());
        String evalName = evalConfig.getName();

        // 3. Gain chart and the other performance files are written locally only.
        String gainHtmlPath = pathFinder.getEvalFilePath(evalName, evalName + "_gainchart.html",
                SourceType.LOCAL);
        LOG.info("Gain chart is generated in {}.", gainHtmlPath);
        chart.generateHtml(evalConfig, modelConfig, gainHtmlPath, performances, modelNames);

        String prRocHtmlPath = pathFinder.getEvalFilePath(evalName, evalName + "_prroc.html",
                SourceType.LOCAL);
        LOG.info("PR & ROC chart is generated in {}.", prRocHtmlPath);
        chart.generateHtml4PrAndRoc(evalConfig, modelConfig, prRocHtmlPath, performances, modelNames);

        // One set of CSV files per model name, prefixed with that name.
        for (int idx = 0; idx < modelNames.size(); idx++) {
            String modelName = modelNames.get(idx);
            PerformanceResult perf = performances.get(idx);

            String unitGainCsvPath = pathFinder.getEvalFilePath(evalName,
                    modelName + "_unit_wise_gainchart.csv", SourceType.LOCAL);
            LOG.info("Unit-wise gain chart data is generated in {} for eval {} and name {}.", unitGainCsvPath,
                    evalName, modelName);
            chart.generateCsv(evalConfig, modelConfig, unitGainCsvPath, perf.gains);
            if (hasWeight) {
                String weightedGainCsvPath = pathFinder.getEvalFilePath(evalName,
                        modelName + "_weighted_gainchart.csv", SourceType.LOCAL);
                LOG.info("Weighted gain chart data is generated in {} for eval {} and name {}.",
                        weightedGainCsvPath, evalName, modelName);
                chart.generateCsv(evalConfig, modelConfig, weightedGainCsvPath, perf.weightedGains);
            }

            String unitPrCsvPath = pathFinder.getEvalFilePath(evalName, modelName + "_unit_wise_pr.csv",
                    SourceType.LOCAL);
            LOG.info("Unit-wise pr data is generated in {} for eval {} and name {}.", unitPrCsvPath, evalName,
                    modelName);
            chart.generateCsv(evalConfig, modelConfig, unitPrCsvPath, perf.pr);
            if (hasWeight) {
                String weightedPrCsvPath = pathFinder.getEvalFilePath(evalName,
                        modelName + "_weighted_pr.csv", SourceType.LOCAL);
                LOG.info("Weighted pr data is generated in {} for eval {} and name {}.", weightedPrCsvPath,
                        evalName, modelName);
                chart.generateCsv(evalConfig, modelConfig, weightedPrCsvPath, perf.weightedPr);
            }

            String unitRocCsvPath = pathFinder.getEvalFilePath(evalName, modelName + "_unit_wise_roc.csv",
                    SourceType.LOCAL);
            LOG.info("Unit-wise roc data is generated in {} for eval {} and name {}.", unitRocCsvPath, evalName,
                    modelName);
            chart.generateCsv(evalConfig, modelConfig, unitRocCsvPath, perf.roc);
            if (hasWeight) {
                String weightedRocCsvPath = pathFinder.getEvalFilePath(evalName,
                        modelName + "_weighted_roc.csv", SourceType.LOCAL);
                LOG.info("Weighted roc data is generated in {} for eval {} and name {}.", weightedRocCsvPath,
                        evalName, modelName);
                chart.generateCsv(evalConfig, modelConfig, weightedRocCsvPath, perf.weightedRoc);
            }

            String scoreGainCsvPath = pathFinder.getEvalFilePath(evalName,
                    modelName + "_modelscore_gainchart.csv", SourceType.LOCAL);
            LOG.info("Model score gain chart data is generated in {} for eval {} and name {}.",
                    scoreGainCsvPath, evalName, modelName);
            chart.generateCsv(evalConfig, modelConfig, scoreGainCsvPath, perf.modelScoreList);
        }
        LOG.info("Performance Evaluation is done for {}.", evalName);
    }
}
Aggregations