Search in sources :

Example 1 with CountAndFrequentItems

use of ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems in project shifu by ShifuML.

the class UpdateBinningInfoMapper method populateStats.

/**
 * Folds one raw column value of the current record into the per-column statistics held by this mapper.
 *
 * <p>The value is always offered to the column's {@code CountAndFrequentItems} (created on first use). If a
 * {@code BinningInfoWritable} is registered for {@code newCCIndex}, its total count is incremented and the
 * value is then routed to the hybrid, categorical or numerical handler according to the column config.
 * Columns that are none of the three only get their total count updated.
 *
 * @param units       raw column values of the current record; only {@code units[columnIndex]} is read
 * @param tag         tag of the current record, matched against {@code posTags} / {@code negTags}
 * @param weight      record weight added to the per-bin weight arrays; it is unboxed only when a bin weight
 *                    is actually updated, so {@code null} would fail with a NullPointerException there
 * @param columnIndex index of the column in {@code units} and in {@code columnConfigList}
 * @param newCCIndex  key of the column in {@code variableCountMap}, {@code columnBinningInfo} and
 *                    {@code categoricalBinMap}
 * @throws RuntimeException if a valid numeric value cannot be assigned to a bin ({@code getBinNum} returns -1)
 */
private void populateStats(String[] units, String tag, Double weight, int columnIndex, int newCCIndex) {
    ColumnConfig columnConfig = this.columnConfigList.get(columnIndex);

    CountAndFrequentItems countAndFrequentItems = this.variableCountMap.get(newCCIndex);
    if (countAndFrequentItems == null) {
        countAndFrequentItems = new CountAndFrequentItems();
        this.variableCountMap.put(newCCIndex, countAndFrequentItems);
    }
    String rawValue = units[columnIndex];
    countAndFrequentItems.offer(this.missingOrInvalidValues, rawValue);

    BinningInfoWritable binningInfoWritable = this.columnBinningInfo.get(newCCIndex);
    if (binningInfoWritable == null) {
        // no binning info registered for this column: only the count/frequent-item tracking above applies
        return;
    }
    binningInfoWritable.setTotalCount(binningInfoWritable.getTotalCount() + 1L);

    if (columnConfig.isHybrid()) {
        populateHybridStats(columnConfig, binningInfoWritable, rawValue, tag, weight, newCCIndex);
    } else if (columnConfig.isCategorical()) {
        populateCategoricalStats(binningInfoWritable, rawValue, tag, weight, newCCIndex);
    } else if (columnConfig.isNumerical()) {
        populateNumericalStats(binningInfoWritable, rawValue, tag, weight);
    }
}

/**
 * Updates bin counts for a hybrid column, whose bin arrays hold the numerical bins first, then the
 * categorical bins, then one trailing bin for missing/invalid values.
 */
private void populateHybridStats(ColumnConfig columnConfig, BinningInfoWritable binningInfoWritable,
        String rawValue, String tag, Double weight, int newCCIndex) {
    boolean isMissingValue = rawValue == null || missingOrInvalidValues.contains(rawValue.toLowerCase());
    double douVal = BinUtils.parseNumber(rawValue);
    Double hybridThreshold = columnConfig.getHybridThreshold();
    if (hybridThreshold == null) {
        hybridThreshold = Double.NEGATIVE_INFINITY;
    }
    // a number below hybridThreshold is also treated as a category
    boolean isCategory = Double.isNaN(douVal) || douVal < hybridThreshold;

    final int binNum;
    if (isMissingValue) {
        binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
        binNum = binningInfoWritable.getBinCategories().size() + binningInfoWritable.getBinBoundaries().size();
    } else if (isCategory) {
        // locate the value in the category list of this column
        int categoryBin = quickLocateCategoricalBin(this.categoricalBinMap.get(newCCIndex), rawValue);
        if (categoryBin < 0) {
            // unknown category: counted the same way as a missing value
            binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
            binNum = binningInfoWritable.getBinCategories().size()
                    + binningInfoWritable.getBinBoundaries().size();
        } else {
            // real category: categorical bins are stored after the numerical bins
            binNum = categoryBin + binningInfoWritable.getBinBoundaries().size();
        }
    } else {
        // not missing and not a category, so douVal is a real number at or above the threshold
        binNum = getBinNum(binningInfoWritable.getBinBoundaries(), douVal);
        if (binNum == -1) {
            throw new RuntimeException("binNum should not be -1 to this step.");
        }
        // other stats are treated as numerical features
        accumulateMoments(binningInfoWritable, douVal);
    }
    tallyTag(binningInfoWritable, binNum, tag, weight);
}

/**
 * Updates bin counts for a categorical column; missing values and unknown categories go to the bin whose
 * index equals the number of categories.
 */
private void populateCategoricalStats(BinningInfoWritable binningInfoWritable, String rawValue, String tag,
        Double weight, int newCCIndex) {
    int lastBinIndex = binningInfoWritable.getBinCategories().size();
    int binNum = 0;
    boolean isMissingOrInvalid;
    if (rawValue == null || missingOrInvalidValues.contains(rawValue.toLowerCase())) {
        isMissingOrInvalid = true;
    } else {
        binNum = quickLocateCategoricalBin(this.categoricalBinMap.get(newCCIndex), rawValue);
        isMissingOrInvalid = binNum < 0;
    }
    if (isMissingOrInvalid) {
        binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
        binNum = lastBinIndex;
    }

    if (modelConfig.isRegression()) {
        tallyTag(binningInfoWritable, binNum, tag, weight);
    } else {
        // for multiple classification, set bin count to BinCountPos and leave BinCountNeg empty
        binningInfoWritable.getBinCountPos()[binNum] += 1L;
        binningInfoWritable.getBinWeightPos()[binNum] += weight;
    }
}

/**
 * Updates bin counts and numeric moments for a numerical column; missing or invalid values go to the bin
 * whose index equals the number of bin boundaries and do not touch sum/min/max. Bin counts are only
 * maintained when the model config reports regression.
 */
private void populateNumericalStats(BinningInfoWritable binningInfoWritable, String rawValue, String tag,
        Double weight) {
    int lastBinIndex = binningInfoWritable.getBinBoundaries().size();
    double douVal = 0.0;
    boolean isMissingValue = false;
    boolean isInvalidValue = false;
    if (rawValue == null || rawValue.length() == 0 || missingOrInvalidValues.contains(rawValue.toLowerCase())) {
        isMissingValue = true;
    } else {
        try {
            douVal = Double.parseDouble(rawValue.trim());
        } catch (NumberFormatException ignored) {
            // rawValue is non-null here, so a malformed number is the only possible failure
            isInvalidValue = true;
        }
    }
    // add logic the same as CalculateNewStatsUDF: values above the configured threshold are invalid.
    // Double.compare orders NaN above ordinary numbers, so a parsed "NaN" is normally flagged here as well.
    if (Double.compare(douVal, modelConfig.getNumericalValueThreshold()) > 0) {
        isInvalidValue = true;
    }

    if (isInvalidValue || isMissingValue) {
        // for invalid or missing values there is no need to update sum, squaredSum, max, min ...
        binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
        if (modelConfig.isRegression()) {
            tallyTag(binningInfoWritable, lastBinIndex, tag, weight);
        }
    } else {
        int binNum = getBinNum(binningInfoWritable.getBinBoundaries(), rawValue);
        if (binNum == -1) {
            throw new RuntimeException("binNum should not be -1 to this step.");
        }
        if (modelConfig.isRegression()) {
            tallyTag(binningInfoWritable, binNum, tag, weight);
        }
        accumulateMoments(binningInfoWritable, douVal);
    }
}

/**
 * Adds one record and its weight to the positive bin arrays when {@code tag} is in {@code posTags}, else to
 * the negative bin arrays when it is in {@code negTags}; any other tag leaves the bins untouched.
 */
private void tallyTag(BinningInfoWritable binningInfoWritable, int binNum, String tag, Double weight) {
    if (posTags.contains(tag)) {
        binningInfoWritable.getBinCountPos()[binNum] += 1L;
        binningInfoWritable.getBinWeightPos()[binNum] += weight;
    } else if (negTags.contains(tag)) {
        binningInfoWritable.getBinCountNeg()[binNum] += 1L;
        binningInfoWritable.getBinWeightNeg()[binNum] += weight;
    }
}

/**
 * Adds {@code douVal} to the running sum, squared sum, triple sum and quartic sum, and widens max/min.
 */
private void accumulateMoments(BinningInfoWritable binningInfoWritable, double douVal) {
    binningInfoWritable.setSum(binningInfoWritable.getSum() + douVal);
    double squaredVal = douVal * douVal;
    binningInfoWritable.setSquaredSum(binningInfoWritable.getSquaredSum() + squaredVal);
    binningInfoWritable.setTripleSum(binningInfoWritable.getTripleSum() + squaredVal * douVal);
    binningInfoWritable.setQuarticSum(binningInfoWritable.getQuarticSum() + squaredVal * squaredVal);
    if (Double.compare(binningInfoWritable.getMax(), douVal) < 0) {
        binningInfoWritable.setMax(douVal);
    }
    if (Double.compare(binningInfoWritable.getMin(), douVal) > 0) {
        binningInfoWritable.setMin(douVal);
    }
}
Also used : CountAndFrequentItems(ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) IOException(java.io.IOException) FileNotFoundException(java.io.FileNotFoundException)

Example 2 with CountAndFrequentItems

use of ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems in project shifu by ShifuML.

the class UpdateBinningInfoMapper method cleanup.

/**
 * Emits every column's binning info to the reducer for merging, keyed by the column key.
 *
 * <p>Before a column is written, its count/frequent-item statistics are attached when present; otherwise
 * the binning info is flagged as empty and a warning is logged.
 */
@Override
protected void cleanup(Context context) throws IOException, InterruptedException {
    LOG.debug("Column binning info: {}", this.columnBinningInfo);
    LOG.debug("Column count info: {}", this.variableCountMap);
    for (Map.Entry<Integer, BinningInfoWritable> columnEntry : this.columnBinningInfo.entrySet()) {
        Integer columnKey = columnEntry.getKey();
        BinningInfoWritable binningInfo = columnEntry.getValue();
        CountAndFrequentItems counts = this.variableCountMap.get(columnKey);
        if (counts == null) {
            // nothing was counted for this column in this mapper
            binningInfo.setEmpty(true);
            LOG.warn("cci is null for column {}", columnKey);
        } else {
            binningInfo.setCfiw(new CountAndFrequentItemsWritable(counts.getCount(), counts.getInvalidCount(),
                    counts.getValidNumCount(), counts.getHyper().getBytes(), counts.getFrequentItems()));
        }
        this.outputKey.set(columnKey);
        context.write(this.outputKey, binningInfo);
    }
}
Also used : CountAndFrequentItems(ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems) CountAndFrequentItemsWritable(ml.shifu.shifu.core.autotype.CountAndFrequentItemsWritable) HashMap(java.util.HashMap) Map(java.util.Map)

Example 3 with CountAndFrequentItems

use of ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems in project shifu by ShifuML.

the class UpdateBinningInfoMapper method setup.

/**
 * Prepares mapper state before any record is processed: loads configuration, builds the data purifiers,
 * allocates the per-column maps, creates the splitter and caches tag / missing-value sets.
 */
@Override
protected void setup(Context context) throws IOException, InterruptedException {
    loadConfigFiles(context);
    this.dataSetDelimiter = this.modelConfig.getDataSetDelimiter();
    this.dataPurifier = new DataPurifier(this.modelConfig, false);

    // optional filter expressions: one purifier per expression
    String rawFilters = context.getConfiguration().get(Constants.SHIFU_STATS_FILTER_EXPRESSIONS);
    if (StringUtils.isNotBlank(rawFilters)) {
        this.isForExpressions = true;
        String[] expressions = CommonUtils.split(rawFilters, Constants.SHIFU_STATS_FILTER_EXPRESSIONS_DELIMETER);
        this.expressionDataPurifiers = new ArrayList<DataPurifier>(expressions.length);
        for (String expression : expressions) {
            this.expressionDataPurifiers.add(new DataPurifier(this.modelConfig, expression, false));
        }
    }

    loadWeightColumnNum();
    loadTagWeightNum();

    // both maps are sized to the number of configured columns, with load factor 1
    int columnCount = this.columnConfigList.size();
    this.columnBinningInfo = new HashMap<Integer, BinningInfoWritable>(columnCount, 1f);
    this.categoricalBinMap = new HashMap<Integer, Map<String, Integer>>(columnCount, 1f);

    // create Splitter
    String outputDelimiter = context.getConfiguration().get(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(outputDelimiter);

    loadColumnBinningInfo();

    this.outputKey = new IntWritable();
    this.variableCountMap = new HashMap<Integer, CountAndFrequentItems>();
    this.posTags = new HashSet<String>(this.modelConfig.getPosTags());
    this.negTags = new HashSet<String>(this.modelConfig.getNegTags());
    this.tags = new HashSet<String>(this.modelConfig.getFlattenTags());
    this.missingOrInvalidValues = new HashSet<String>(this.modelConfig.getDataSet().getMissingOrInvalidValues());

    String throwOnWeightError = context.getConfiguration().get("shifu.weight.exception", "false");
    this.isThrowforWeightException = "true".equalsIgnoreCase(throwOnWeightError);

    LOG.debug("Column binning info: {}", this.columnBinningInfo);

    // the target column is only inspected when no tags are configured
    boolean noTagsConfigured = CollectionUtils.isEmpty(this.modelConfig.getTags());
    this.isLinearTarget = noTagsConfigured
            && CommonUtils.getTargetColumnConfig(this.columnConfigList).isNumerical();
}
Also used : CountAndFrequentItems(ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems) DataPurifier(ml.shifu.shifu.core.DataPurifier) HashMap(java.util.HashMap) Map(java.util.Map) IntWritable(org.apache.hadoop.io.IntWritable)

Aggregations

CountAndFrequentItems (ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems): 3; HashMap (java.util.HashMap): 2; Map (java.util.Map): 2; FileNotFoundException (java.io.FileNotFoundException): 1; IOException (java.io.IOException): 1; ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1; DataPurifier (ml.shifu.shifu.core.DataPurifier): 1; CountAndFrequentItemsWritable (ml.shifu.shifu.core.autotype.CountAndFrequentItemsWritable): 1; IntWritable (org.apache.hadoop.io.IntWritable): 1