Usage of ml.shifu.shifu.core.posttrain.FeatureStatsWritable.BinStats in the shifu project by ShifuML.
From class PostTrainMapper, method map.
/**
 * Scores one input record and accumulates per-bin score statistics for post-train analysis.
 *
 * <p>Blank records, records for which {@code dataPurifier.isFilter} is false, and records whose
 * tag is not in {@code tags} are skipped (the last case also increments the {@code INVALID_TAG}
 * counter). For every remaining record this method:
 * <ol>
 * <li>computes scores with {@code modelRunner} and writes, via {@code mos} under
 * {@code Constants.POST_TRAIN_OUTPUT_SCORE}, one line made of the average, max and min score,
 * each individual score, then the raw value of every meta column, joined by
 * {@code Constants.DEFAULT_DELIMITER};</li>
 * <li>adds the record's average score (and a count of one) to the {@link BinStats} of the bin
 * the record falls into, for every final-selected column that is neither meta nor target.</li>
 * </ol>
 */
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
    String valueStr = value.toString();
    // StringUtils.isBlank is not used here to avoid import new jar
    if (valueStr == null || valueStr.trim().length() == 0) {
        LOG.warn("Empty input.");
        return;
    }
    // Only records accepted by the purifier are processed.
    if (!this.dataPurifier.isFilter(valueStr)) {
        return;
    }

    String[] units = CommonUtils.split(valueStr, this.modelConfig.getDataSetDelimiter());
    // tagColumnNum should be in units array, if not IndexOutofBoundException
    String tag = CommonUtils.trimTag(units[this.tagColumnNum]);
    if (!this.tags.contains(tag)) {
        // Logs only when the wall-clock millis happen to be a multiple of 20, presumably to
        // throttle log volume for large numbers of bad records.
        if (System.currentTimeMillis() % 20 == 0) {
            LOG.warn("Data with invalid tag is ignored in post train, invalid tag: {}.", tag);
        }
        context.getCounter(Constants.SHIFU_GROUP_COUNTER, "INVALID_TAG").increment(1L);
        return;
    }

    Map<String, String> rawDataMap = buildRawDataMap(units);
    CaseScoreResult csr = this.modelRunner.compute(rawDataMap);

    // store score value: avg, max, min, every model score, then meta column values
    StringBuilder sb = new StringBuilder(500);
    sb.append(csr.getAvgScore()).append(Constants.DEFAULT_DELIMITER)
            .append(csr.getMaxScore()).append(Constants.DEFAULT_DELIMITER)
            .append(csr.getMinScore()).append(Constants.DEFAULT_DELIMITER);
    for (Double score : csr.getScores()) {
        sb.append(score).append(Constants.DEFAULT_DELIMITER);
    }
    List<String> metaList = modelConfig.getMetaColumnNames();
    for (String meta : metaList) {
        sb.append(rawDataMap.get(meta)).append(Constants.DEFAULT_DELIMITER);
    }
    // Drop the whole trailing delimiter; deleteCharAt would remove only one character of a
    // multi-character delimiter.
    sb.setLength(sb.length() - Constants.DEFAULT_DELIMITER.length());
    this.outputValue.set(sb.toString());
    this.mos.write(Constants.POST_TRAIN_OUTPUT_SCORE, NullWritable.get(), this.outputValue);

    for (int i = 0; i < headers.length; i++) {
        ColumnConfig config = this.columnConfigList.get(i);
        if (config.isMeta() || config.isTarget() || !config.isFinalSelect()) {
            continue;
        }
        int binNum = BinUtils.getBinNum(config, units[i]);
        List<BinStats> featureStats = this.variableStatsMap.get(config.getColumnNum());
        // if -1, means invalid numeric value like null or empty, last one is for empty stats.
        BinStats bs = featureStats.get(binNum == -1 ? featureStats.size() - 1 : binNum);
        // bs should not be null as already initialized in setup
        bs.setBinSum(csr.getAvgScore() + bs.getBinSum());
        bs.setBinCnt(1L + bs.getBinCnt());
    }
}
Usage of ml.shifu.shifu.core.posttrain.FeatureStatsWritable.BinStats in the shifu project by ShifuML.
From class PostTrainMapper, method initFeatureStats.
/**
 * Allocates zeroed {@link BinStats} accumulators for every final-selected column that is neither
 * meta nor target, and stores them in {@code variableStatsMap} keyed by column number.
 *
 * <p>Each column gets one accumulator per bin plus one trailing slot. {@code map} routes values
 * whose bin number is -1 (invalid or empty) to the last element of the list. Numerical and
 * categorical columns therefore both need that extra slot. Without it, missing categorical
 * values would be counted in the last real category.
 */
private void initFeatureStats() {
    this.variableStatsMap = new HashMap<Integer, List<BinStats>>();
    for (ColumnConfig config : this.columnConfigList) {
        if (config.isMeta() || config.isTarget() || !config.isFinalSelect()) {
            continue;
        }
        int binSize = 0;
        if (config.isNumerical()) {
            // one slot per bin + one for invalid/empty values
            binSize = config.getBinBoundary().size() + 1;
        }
        if (config.isCategorical()) {
            // one slot per category + one for invalid/empty values
            binSize = config.getBinCategory().size() + 1;
        }
        List<BinStats> featureStats = new ArrayList<BinStats>(binSize);
        for (int i = 0; i < binSize; i++) {
            featureStats.add(new BinStats(0, 0));
        }
        this.variableStatsMap.put(config.getColumnNum(), featureStats);
    }
}
Usage of ml.shifu.shifu.core.posttrain.FeatureStatsWritable.BinStats in the shifu project by ShifuML.
From class PostTrainReducer, method reduce.
/**
 * Merges the partial bin statistics emitted for one column and writes each bin's average score.
 *
 * <p>Bin sums and counts are added element-wise across all values for {@code key}. The output
 * value is a comma-separated list with one entry per bin. Each entry is
 * {@code (int) (binSum / binCnt)}, truncated toward zero, or {@code 0} for an empty bin.
 *
 * @throws IllegalStateException if two partials for the same key have a different number of bins
 */
@Override
protected void reduce(IntWritable key, Iterable<FeatureStatsWritable> values, Context context) throws IOException, InterruptedException {
    // Hadoop may reuse the value instance across iterations, so accumulate into our own BinStats
    // objects rather than holding on to (and mutating) the first value's list.
    BinStats[] merged = null;
    for (FeatureStatsWritable fsw : values) {
        List<BinStats> partial = fsw.getBinStats();
        if (merged == null) {
            merged = new BinStats[partial.size()];
            for (int i = 0; i < merged.length; i++) {
                merged[i] = new BinStats(0, 0);
            }
        } else if (partial.size() != merged.length) {
            throw new IllegalStateException("Inconsistent bin size for column " + key + ": expected "
                    + merged.length + " but got " + partial.size());
        }
        for (int i = 0; i < merged.length; i++) {
            BinStats acc = merged[i];
            BinStats bs = partial.get(i);
            acc.setBinSum(acc.getBinSum() + bs.getBinSum());
            acc.setBinCnt(acc.getBinCnt() + bs.getBinCnt());
        }
    }
    if (merged == null) {
        // No values for this key: nothing to write.
        return;
    }

    StringBuilder sb = new StringBuilder(150);
    for (int i = 0; i < merged.length; i++) {
        if (i > 0) {
            sb.append(',');
        }
        BinStats bs = merged[i];
        int avgScore = 0;
        if (bs.getBinCnt() != 0L) {
            avgScore = (int) (bs.getBinSum() / bs.getBinCnt());
        }
        sb.append(avgScore);
    }
    LOG.info(key.toString() + " " + sb.toString());
    this.outputValue.set(sb.toString());
    context.write(key, this.outputValue);
}
Aggregations