Use of ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems in the project shifu by ShifuML.
From the class UpdateBinningInfoMapper, method populateStats.
/**
 * Folds a single field of one input record into the running statistics of its column.
 *
 * <p>The value is always offered to the column's {@code CountAndFrequentItems} tracker (created on
 * first use). If binning info is registered for {@code newCCIndex}, the total count is incremented
 * and the value is dispatched to the hybrid, categorical or numerical handler according to the
 * column configuration; columns of any other kind only get the total count updated.
 *
 * @param units       the split fields of the current record
 * @param tag         the target tag of the current record
 * @param weight      the weight added to the per-bin weight arrays
 * @param columnIndex index into {@code units} and {@code columnConfigList}
 * @param newCCIndex  key used for {@code variableCountMap}, {@code columnBinningInfo} and
 *                    {@code categoricalBinMap}
 */
private void populateStats(String[] units, String tag, Double weight, int columnIndex, int newCCIndex) {
    final ColumnConfig config = this.columnConfigList.get(columnIndex);

    // Lazily create the count / frequent-item tracker for this column.
    CountAndFrequentItems counter = this.variableCountMap.get(newCCIndex);
    if (counter == null) {
        counter = new CountAndFrequentItems();
        this.variableCountMap.put(newCCIndex, counter);
    }
    final String raw = units[columnIndex];
    counter.offer(this.missingOrInvalidValues, raw);

    final BinningInfoWritable stats = this.columnBinningInfo.get(newCCIndex);
    if (stats == null) {
        // No binning info registered for this column: nothing else to update.
        return;
    }
    stats.setTotalCount(stats.getTotalCount() + 1L);

    if (config.isHybrid()) {
        populateHybridStats(config, stats, raw, tag, weight, newCCIndex);
    } else if (config.isCategorical()) {
        populateCategoricalStats(stats, raw, tag, weight, newCCIndex);
    } else if (config.isNumerical()) {
        populateNumericalStats(stats, raw, tag, weight);
    }
}

/**
 * Hybrid column: numeric bins come first, categorical bins follow, and the bin after both groups
 * collects missing or unknown values.
 */
private void populateHybridStats(ColumnConfig config, BinningInfoWritable stats, String raw, String tag,
        Double weight, int newCCIndex) {
    final boolean missing = raw == null || missingOrInvalidValues.contains(raw.toLowerCase());
    final double number = BinUtils.parseNumber(raw);
    final Double configuredThreshold = config.getHybridThreshold();
    final double threshold = configuredThreshold == null ? Double.NEGATIVE_INFINITY : configuredThreshold;

    final int binNum;
    if (missing) {
        stats.setMissingCount(stats.getMissingCount() + 1L);
        binNum = stats.getBinCategories().size() + stats.getBinBoundaries().size();
    } else if (Double.isNaN(number) || number < threshold) {
        // Unparseable values and numbers below the hybrid threshold are looked up as categories.
        final int categoryBin = quickLocateCategoricalBin(this.categoricalBinMap.get(newCCIndex), raw);
        if (categoryBin < 0) {
            // Unknown category: counted the same way as a missing value.
            stats.setMissingCount(stats.getMissingCount() + 1L);
            binNum = stats.getBinCategories().size() + stats.getBinBoundaries().size();
        } else {
            // Known category: shift past the numeric bins.
            binNum = categoryBin + stats.getBinBoundaries().size();
        }
    } else {
        // Remaining case: a parseable number that is not below the threshold.
        binNum = getBinNum(stats.getBinBoundaries(), number);
        if (binNum == -1) {
            throw new RuntimeException("binNum should not be -1 to this step.");
        }
        addNumericMoments(stats, number);
    }
    addToTagBin(stats, tag, weight, binNum);
}

/**
 * Categorical column: the bin after the last category collects missing or unknown values.
 */
private void populateCategoricalStats(BinningInfoWritable stats, String raw, String tag, Double weight,
        int newCCIndex) {
    final int missingBin = stats.getBinCategories().size();
    int binNum = 0;
    boolean unusable = raw == null || missingOrInvalidValues.contains(raw.toLowerCase());
    if (!unusable) {
        binNum = quickLocateCategoricalBin(this.categoricalBinMap.get(newCCIndex), raw);
        unusable = binNum < 0;
    }
    if (unusable) {
        stats.setMissingCount(stats.getMissingCount() + 1L);
        binNum = missingBin;
    }

    if (modelConfig.isRegression()) {
        addToTagBin(stats, tag, weight, binNum);
    } else {
        // for multiple classification, set bin count to BinCountPos and leave BinCountNeg empty
        stats.getBinCountPos()[binNum] += 1L;
        stats.getBinWeightPos()[binNum] += weight;
    }
}

/**
 * Numerical column: the bin after the last boundary collects missing or invalid values; bin
 * counters are only touched when the model is a regression model.
 */
private void populateNumericalStats(BinningInfoWritable stats, String raw, String tag, Double weight) {
    final int missingBin = stats.getBinBoundaries().size();
    double number = 0.0;
    boolean unusable = false;
    if (raw == null || raw.length() == 0 || missingOrInvalidValues.contains(raw.toLowerCase())) {
        unusable = true;
    } else {
        try {
            number = Double.parseDouble(raw.trim());
        } catch (Exception e) {
            // Not a parseable number: handled as an invalid value below.
            unusable = true;
        }
    }
    // add logic the same as CalculateNewStatsUDF
    if (Double.compare(number, modelConfig.getNumericalValueThreshold()) > 0) {
        unusable = true;
    }

    if (unusable) {
        // For invalid or missing values, no need update sum, squaredSum, max, min ...
        stats.setMissingCount(stats.getMissingCount() + 1L);
        if (modelConfig.isRegression()) {
            addToTagBin(stats, tag, weight, missingBin);
        }
        return;
    }

    final int binNum = getBinNum(stats.getBinBoundaries(), raw);
    if (binNum == -1) {
        throw new RuntimeException("binNum should not be -1 to this step.");
    }
    if (modelConfig.isRegression()) {
        addToTagBin(stats, tag, weight, binNum);
    }
    addNumericMoments(stats, number);
}

/**
 * Adds one record to the positive or negative count/weight of the given bin, depending on which
 * tag set contains {@code tag}; a tag found in neither set leaves the bins untouched.
 */
private void addToTagBin(BinningInfoWritable stats, String tag, Double weight, int binNum) {
    if (posTags.contains(tag)) {
        stats.getBinCountPos()[binNum] += 1L;
        stats.getBinWeightPos()[binNum] += weight;
    } else if (negTags.contains(tag)) {
        stats.getBinCountNeg()[binNum] += 1L;
        stats.getBinWeightNeg()[binNum] += weight;
    }
}

/**
 * Accumulates sum, squared/triple/quartic sums and the running max/min with one numeric value.
 */
private void addNumericMoments(BinningInfoWritable stats, double value) {
    final double square = value * value;
    stats.setSum(stats.getSum() + value);
    stats.setSquaredSum(stats.getSquaredSum() + square);
    stats.setTripleSum(stats.getTripleSum() + square * value);
    stats.setQuarticSum(stats.getQuarticSum() + square * square);
    if (Double.compare(stats.getMax(), value) < 0) {
        stats.setMax(value);
    }
    if (Double.compare(stats.getMin(), value) > 0) {
        stats.setMin(value);
    }
}
Use of ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems in the project shifu by ShifuML.
From the class UpdateBinningInfoMapper, method cleanup.
/**
 * Writes the collected binning info of every column to the reducer for merging.
 *
 * <p>Before a column is written, its count / frequent-item tracker (if any) is attached as a
 * {@code CountAndFrequentItemsWritable}; a column without a tracker is flagged as empty instead.
 *
 * @param context the task context the (column key, binning info) pairs are written to
 * @throws IOException          if building the writable or writing an output record fails
 * @throws InterruptedException if writing an output record is interrupted
 */
@Override
protected void cleanup(Context context) throws IOException, InterruptedException {
    LOG.debug("Column binning info: {}", this.columnBinningInfo);
    LOG.debug("Column count info: {}", this.variableCountMap);

    for (Map.Entry<Integer, BinningInfoWritable> columnEntry : this.columnBinningInfo.entrySet()) {
        final Integer columnKey = columnEntry.getKey();
        final BinningInfoWritable binningInfo = columnEntry.getValue();
        final CountAndFrequentItems counter = this.variableCountMap.get(columnKey);

        if (counter == null) {
            // No value was ever offered for this column: mark the record as empty.
            binningInfo.setEmpty(true);
            LOG.warn("cci is null for column {}", columnKey);
        } else {
            binningInfo.setCfiw(new CountAndFrequentItemsWritable(
                    counter.getCount(),
                    counter.getInvalidCount(),
                    counter.getValidNumCount(),
                    counter.getHyper().getBytes(),
                    counter.getFrequentItems()));
        }

        this.outputKey.set(columnKey);
        context.write(this.outputKey, binningInfo);
    }
}
Use of ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems in the project shifu by ShifuML.
From the class UpdateBinningInfoMapper, method setup.
/**
* Initialization for column statistics in mapper.
*
* Loads the configuration, builds the data purifiers, allocates the per-column maps and caches
* the tag / missing-value sets used while processing records.
*
* NOTE(review): the helper calls below (loadConfigFiles, loadWeightColumnNum, loadTagWeightNum,
* loadColumnBinningInfo) are not visible here and presumably read fields assigned in this
* method - keep the statement order unless that is confirmed to be safe.
*
* @param context the task context; its configuration supplies the filter expressions, the output
*                delimiter and the "shifu.weight.exception" flag
* @throws IOException declared by the overridden method
* @throws InterruptedException declared by the overridden method
*/
@Override
protected void setup(Context context) throws IOException, InterruptedException {
// Presumably populates modelConfig and columnConfigList, which are used right below - TODO confirm.
loadConfigFiles(context);
this.dataSetDelimiter = this.modelConfig.getDataSetDelimiter();
this.dataPurifier = new DataPurifier(this.modelConfig, false);
// Optional filter expressions: when configured, one DataPurifier is created per expression.
String filterExpressions = context.getConfiguration().get(Constants.SHIFU_STATS_FILTER_EXPRESSIONS);
if (StringUtils.isNotBlank(filterExpressions)) {
this.isForExpressions = true;
String[] splits = CommonUtils.split(filterExpressions, Constants.SHIFU_STATS_FILTER_EXPRESSIONS_DELIMETER);
this.expressionDataPurifiers = new ArrayList<DataPurifier>(splits.length);
for (String split : splits) {
this.expressionDataPurifiers.add(new DataPurifier(modelConfig, split, false));
}
}
loadWeightColumnNum();
loadTagWeightNum();
// Both maps are created with one slot per configured column and a load factor of 1.
this.columnBinningInfo = new HashMap<Integer, BinningInfoWritable>(this.columnConfigList.size(), 1f);
this.categoricalBinMap = new HashMap<Integer, Map<String, Integer>>(this.columnConfigList.size(), 1f);
// create Splitter
String delimiter = context.getConfiguration().get(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
// Presumably fills columnBinningInfo / categoricalBinMap created above - TODO confirm.
loadColumnBinningInfo();
this.outputKey = new IntWritable();
this.variableCountMap = new HashMap<Integer, CountAndFrequentItems>();
// Copy the configured tag and missing-value lists into hash sets for contains() checks.
this.posTags = new HashSet<String>(modelConfig.getPosTags());
this.negTags = new HashSet<String>(modelConfig.getNegTags());
this.tags = new HashSet<String>(modelConfig.getFlattenTags());
this.missingOrInvalidValues = new HashSet<String>(this.modelConfig.getDataSet().getMissingOrInvalidValues());
// Flag is on only when the job configuration sets "shifu.weight.exception" to "true" (case-insensitive); default is "false".
this.isThrowforWeightException = "true".equalsIgnoreCase(context.getConfiguration().get("shifu.weight.exception", "false"));
LOG.debug("Column binning info: {}", this.columnBinningInfo);
// Linear target: no tags are configured and the target column is numerical.
this.isLinearTarget = (CollectionUtils.isEmpty(modelConfig.getTags()) && CommonUtils.getTargetColumnConfig(columnConfigList).isNumerical());
}
Aggregations