Search in sources :

Example 1 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

In the class DTWorker, the method load:

@Override
public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue, WorkerContext<DTMasterParams, DTWorkerParams> context) {
    this.count += 1;
    if((this.count) % 5000 == 0) {
        LOG.info("Read {} records.", this.count);
    }
    // hashcode is used for a fixed record split between train and validation sets
    long hashcode = 0;
    short[] inputs = new short[this.inputCount];
    float ideal = 0f;
    float significance = 1f;
    // use guava Splitter to iterate only once; NNConstants.NN_DEFAULT_COLUMN_SEPARATOR replaces
    // getModelConfig().getDataSetDelimiter(), following the function in akka mode.
    int index = 0, inputIndex = 0;
    for(String input : this.splitter.split(currentValue.getWritable().toString())) {
        if(index == this.columnConfigList.size()) {
            // the last field (beyond all configured columns) is the record weight
            significance = parseSignificance(input);
            break;
        }
        ColumnConfig columnConfig = this.columnConfigList.get(index);
        if(columnConfig != null && columnConfig.isTarget()) {
            ideal = getFloatValue(input);
        } else if(isInputColumnSelected(columnConfig)) {
            // null columnConfig is skipped consistently in both var-select modes now
            // (previously the pre-var-select path could NPE on a null entry)
            setBinInput(columnConfig, input, inputs, inputIndex);
            hashcode = hashcode * 31 + input.hashCode();
            inputIndex += 1;
        }
        index += 1;
    }
    // fail fast on column-count mismatch; a wrong delimiter configuration is the common cause,
    // so include it in the message to make the issue quick to find.
    if(inputIndex != inputs.length) {
        String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER, Constants.DEFAULT_DELIMITER);
        throw new RuntimeException("Input length is inconsistent with parsing size. Input original size: " + inputs.length + ", parsing size:" + inputIndex + ", delimiter:" + delimiter + ".");
    }
    if(this.isOneVsAll) {
        // if one vs all, update target value according to index of target
        ideal = updateOneVsAllTargetValue(ideal);
    }
    // sample-negative-only logic: optionally drop negative records before training
    if(modelConfig.getTrain().getSampleNegOnly()) {
        boolean isNegative = (int) (ideal + 0.01d) == 0;
        boolean isRegressionOrOneVsAll = modelConfig.isRegression() || this.isOneVsAll;
        if(this.modelConfig.isFixInitialInput()) {
            // if fixInitialInput, drop negative records whose hashcode falls in the 1-sampleRate range
            int startHashCode = (100 / this.modelConfig.getBaggingNum()) * this.trainerId;
            // BaggingSampleRate means how much data is used in training and validation; if it is 0.8,
            // take 1 - 0.8 to compute endHashCode
            int endHashCode = startHashCode
                    + Double.valueOf((1d - this.modelConfig.getBaggingSampleRate()) * 100).intValue();
            if(isRegressionOrOneVsAll && isNegative && isInRange(hashcode, startHashCode, endHashCode)) {
                return;
            }
        } else {
            // random sampling: drop negative records with probability 1 - baggingSampleRate
            if(isRegressionOrOneVsAll && isNegative && Double.compare(this.sampelNegOnlyRandom.nextDouble(),
                    this.modelConfig.getBaggingSampleRate()) >= 0) {
                return;
            }
        }
    }
    float output = ideal;
    float predict = ideal;
    // up-sampling logic: just add more weight for positive records while bagging sampling rate is unchanged
    if(modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal, 1d) == 0) {
        // Double.compare(ideal, 1d) == 0 means positive tags; sample + 1 to avoid sample count to 0
        significance = significance * (this.upSampleRng.sample() + 1);
    }
    Data data = new Data(inputs, predict, output, output, significance);
    boolean isValidation = false;
    if(context.getAttachment() != null && context.getAttachment() instanceof Boolean) {
        isValidation = (Boolean) context.getAttachment();
    }
    // split into validation and training data set according to validation rate
    boolean isInTraining = this.addDataPairToDataSet(hashcode, data, isValidation);
    // do bagging sampling only for training data
    if(isInTraining) {
        data.subsampleWeights = sampleWeights(data.label);
        // if gbdt, only the 1st sampling value is used; if rf, the 1st denotes some information, no need all
        if(isPositive(data.label)) {
            this.positiveSelectedTrainCount += data.subsampleWeights[0] * 1L;
        } else {
            this.negativeSelectedTrainCount += data.subsampleWeights[0] * 1L;
        }
    }
    // for validation data, according to bagging sampling logic, we may need to sample the validation set,
    // while validation data are only used to compute validation error, so not doing real sampling is OK.
}

/**
 * Parses the weight (significance) field, which is the last field of a record.
 *
 * @param input the raw weight field text
 * @return 1f when no weight column is configured, the field is empty, or the parsed weight is negative
 *         (a warning is logged in that case); otherwise the parsed float value
 */
private float parseSignificance(String input) {
    if(StringUtils.isBlank(modelConfig.getWeightColumnName())) {
        return 1f;
    }
    // check emptiness first to avoid bad performance in failed NumberFormatUtils.getFloat(input, 1f)
    float significance = input.length() == 0 ? 1f : NumberFormatUtils.getFloat(input, 1f);
    // if invalid weight, set it to 1f and warn in log
    if(Float.compare(significance, 0f) < 0) {
        LOG.warn("The {} record in current worker weight {} is less than 0f, it is invalid, set it to 1.",
                count, significance);
        return 1f;
    }
    return significance;
}

/**
 * Decides whether a column contributes an input feature: never meta/target/null columns; after variable
 * selection only final-selected columns; before it only good candidate columns.
 */
private boolean isInputColumnSelected(ColumnConfig columnConfig) {
    if(columnConfig == null || columnConfig.isMeta() || columnConfig.isTarget()) {
        return false;
    }
    return isAfterVarSelect ? columnConfig.isFinalSelect()
            : CommonUtils.isGoodCandidate(columnConfig, this.hasCandidates);
}

/**
 * Writes the bin index of one column value into {@code inputs[inputIndex]} and records the column's
 * position in {@code inputIndexMap} on first sight. Columns that are neither numerical nor categorical
 * leave the slot at its default (same as the previous duplicated branches did).
 */
private void setBinInput(ColumnConfig columnConfig, String input, short[] inputs, int inputIndex) {
    if(columnConfig.isNumerical()) {
        float floatValue = getFloatValue(input);
        // cast is safe as we limit max bin to Short.MAX_VALUE
        inputs[inputIndex] = (short) getBinIndex(floatValue, columnConfig.getBinBoundary());
    } else if(columnConfig.isCategorical()) {
        inputs[inputIndex] = getCategoricalBinValue(columnConfig, input);
    } else {
        return;
    }
    if(!this.inputIndexMap.containsKey(columnConfig.getColumnNum())) {
        this.inputIndexMap.put(columnConfig.getColumnNum(), inputIndex);
    }
}

/**
 * Maps a raw categorical value to its bin index. Empty or unknown values map to the extra
 * "missing" bin, which is {@code binCategory.size()} (the last index).
 */
private short getCategoricalBinValue(ColumnConfig columnConfig, String input) {
    short missingBin = (short) (columnConfig.getBinCategory().size());
    if(input.length() == 0) {
        // empty value goes to the missing bin
        return missingBin;
    }
    Integer categoricalIndex = this.columnCategoryIndexMapping.get(columnConfig.getColumnNum()).get(input);
    if(categoricalIndex == null) {
        // invalid category, treated as missing
        return missingBin;
    }
    // cast is safe as we limit max bin to Short.MAX_VALUE
    short shortValue = (short) (categoricalIndex.intValue());
    if(shortValue == -1) {
        // not found, treated as missing
        shortValue = missingBin;
    }
    return shortValue;
}
Also used : GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 2 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

In the class DTWorker, the method getAllValidFeatures:

/**
 * Collects the column numbers of all features usable for tree building. After variable selection only
 * final-selected columns qualify; before it only good candidate columns do. In both cases the column
 * must carry meaningful bins: a numerical feature with more than one bin boundary, or a categorical
 * feature with at least one bin category. Meta and target columns are always excluded.
 *
 * @return list of valid feature column numbers, possibly empty
 */
private List<Integer> getAllValidFeatures() {
    List<Integer> features = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for(ColumnConfig config : columnConfigList) {
        boolean isSelectable = isAfterVarSelect
                ? (config.isFinalSelect() && !config.isTarget() && !config.isMeta())
                : (!config.isMeta() && !config.isTarget() && CommonUtils.isGoodCandidate(config, hasCandidates));
        // valid feature: numerical with > 1 bin boundaries, or categorical with bin category size > 0;
        // bin checks only run when selectable, preserving the original short-circuit behavior
        if(isSelectable && ((config.isNumerical() && config.getBinBoundary().size() > 1)
                || (config.isCategorical() && config.getBinCategory().size() > 0))) {
            features.add(config.getColumnNum());
        }
    }
    return features;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList)

Example 3 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

In the class UpdateBinningInfoReducer, the method reduce:

/**
 * Merges per-mapper binning statistics ({@link BinningInfoWritable}) for one column into a single
 * stats line (bin counts, weights, pos rate, KS/IV, moments, percentiles, cardinality, frequent items)
 * and writes it out delimited by {@code Constants.DEFAULT_DELIMITER}.
 *
 * @param key     column index; values beyond columnConfigList size wrap around via modulo
 * @param values  partial binning stats emitted by mappers for this column
 * @param context reducer context used for output
 */
@Override
protected void reduce(IntWritable key, Iterable<BinningInfoWritable> values, Context context) throws IOException, InterruptedException {
    long start = System.currentTimeMillis();
    // accumulated raw moments used for mean / stddev / skewness / kurtosis
    double sum = 0d;
    double squaredSum = 0d;
    double tripleSum = 0d;
    double quarticSum = 0d;
    double p25th = 0d;
    double median = 0d;
    double p75th = 0d;
    long count = 0L, missingCount = 0L;
    // FIX: max must start at -Double.MAX_VALUE (the most negative double); Double.MIN_VALUE is the
    // smallest POSITIVE double, which made max wrong for columns whose values are all negative.
    double min = Double.MAX_VALUE, max = -Double.MAX_VALUE;
    List<Double> binBoundaryList = null;
    List<String> binCategories = null;
    long[] binCountPos = null;
    long[] binCountNeg = null;
    double[] binWeightPos = null;
    double[] binWeightNeg = null;
    long[] binCountTotal = null;
    // wrap key into valid range; keys >= size are folded back by modulo
    int columnConfigIndex = key.get() >= this.columnConfigList.size() ? key.get() % this.columnConfigList.size() : key.get();
    ColumnConfig columnConfig = this.columnConfigList.get(columnConfigIndex);
    HyperLogLogPlus hyperLogLogPlus = null;
    Set<String> fis = new HashSet<String>();
    long totalCount = 0, invalidCount = 0, validNumCount = 0;
    int binSize = 0;
    for(BinningInfoWritable info : values) {
        if(info.isEmpty()) {
            // mapper has no stats, skip it
            continue;
        }
        CountAndFrequentItemsWritable cfiw = info.getCfiw();
        totalCount += cfiw.getCount();
        invalidCount += cfiw.getInvalidCount();
        validNumCount += cfiw.getValidNumCount();
        fis.addAll(cfiw.getFrequetItems());
        // merge HyperLogLog sketches for approximate distinct count
        if(hyperLogLogPlus == null) {
            hyperLogLogPlus = HyperLogLogPlus.Builder.build(cfiw.getHyperBytes());
        } else {
            try {
                hyperLogLogPlus = (HyperLogLogPlus) hyperLogLogPlus.merge(HyperLogLogPlus.Builder.build(cfiw.getHyperBytes()));
            } catch (CardinalityMergeException e) {
                throw new RuntimeException(e);
            }
        }
        // lazily size bin arrays from the first non-empty info; one extra slot for the missing bin
        if(columnConfig.isHybrid() && binBoundaryList == null && binCategories == null) {
            binBoundaryList = info.getBinBoundaries();
            binCategories = info.getBinCategories();
            binSize = binBoundaryList.size() + binCategories.size();
            binCountPos = new long[binSize + 1];
            binCountNeg = new long[binSize + 1];
            binWeightPos = new double[binSize + 1];
            binWeightNeg = new double[binSize + 1];
            binCountTotal = new long[binSize + 1];
        } else if(columnConfig.isNumerical() && binBoundaryList == null) {
            binBoundaryList = info.getBinBoundaries();
            binSize = binBoundaryList.size();
            binCountPos = new long[binSize + 1];
            binCountNeg = new long[binSize + 1];
            binWeightPos = new double[binSize + 1];
            binWeightNeg = new double[binSize + 1];
            binCountTotal = new long[binSize + 1];
        } else if(columnConfig.isCategorical() && binCategories == null) {
            binCategories = info.getBinCategories();
            binSize = binCategories.size();
            binCountPos = new long[binSize + 1];
            binCountNeg = new long[binSize + 1];
            binWeightPos = new double[binSize + 1];
            binWeightNeg = new double[binSize + 1];
            binCountTotal = new long[binSize + 1];
        }
        count += info.getTotalCount();
        missingCount += info.getMissingCount();
        // for numeric, such sums are OK; for categorical, such values are all 0, later recomputed
        // by using binCountPos and binCountNeg
        sum += info.getSum();
        squaredSum += info.getSquaredSum();
        tripleSum += info.getTripleSum();
        quarticSum += info.getQuarticSum();
        if(Double.compare(max, info.getMax()) < 0) {
            max = info.getMax();
        }
        if(Double.compare(min, info.getMin()) > 0) {
            min = info.getMin();
        }
        for(int i = 0; i < (binSize + 1); i++) {
            binCountPos[i] += info.getBinCountPos()[i];
            binCountNeg[i] += info.getBinCountNeg()[i];
            binWeightPos[i] += info.getBinWeightPos()[i];
            binWeightNeg[i] += info.getBinWeightNeg()[i];
            binCountTotal[i] += info.getBinCountPos()[i];
            binCountTotal[i] += info.getBinCountNeg()[i];
        }
    }
    // NOTE(review): if every info was empty, hyperLogLogPlus stays null and binBoundaryList/binCategories
    // stay null, so the code below would NPE — presumably the job guarantees at least one non-empty
    // stats record per key; confirm upstream.
    if(columnConfig.isNumerical()) {
        // estimate p25/median/p75 by linear interpolation inside the bin holding each quantile count
        long p25Count = count / 4;
        long medianCount = p25Count * 2;
        long p75Count = p25Count * 3;
        p25th = min;
        median = min;
        p75th = min;
        int currentCount = 0;
        for(int i = 0; i < binBoundaryList.size(); i++) {
            double left = getCutoffBoundary(binBoundaryList.get(i), max, min);
            double right = ((i == binBoundaryList.size() - 1) ? max : getCutoffBoundary(binBoundaryList.get(i + 1), max, min));
            if(p25Count >= currentCount && p25Count < currentCount + binCountTotal[i]) {
                p25th = ((p25Count - currentCount) / (double) binCountTotal[i]) * (right - left) + left;
            }
            if(medianCount >= currentCount && medianCount < currentCount + binCountTotal[i]) {
                median = ((medianCount - currentCount) / (double) binCountTotal[i]) * (right - left) + left;
            }
            if(p75Count >= currentCount && p75Count < currentCount + binCountTotal[i]) {
                p75th = ((p75Count - currentCount) / (double) binCountTotal[i]) * (right - left) + left;
                // when get 75 percentile stop it
                break;
            }
            currentCount += binCountTotal[i];
        }
        LOG.info("Coloumn num is {}, p25 value is {}, median value is {}, p75 value is {}", columnConfig.getColumnNum(), p25th, median, p75th);
    }
    LOG.info("Coloumn num is {}, columnType value is {}, cateMaxNumBin is {}, binCategory size is {}", columnConfig.getColumnNum(), columnConfig.getColumnType(), modelConfig.getStats().getCateMaxNumBin(), (CollectionUtils.isNotEmpty(columnConfig.getBinCategory()) ? columnConfig.getBinCategory().size() : 0));
    // merge categorical bins when category count exceeds the configured maximum bin number
    if(columnConfig.isCategorical() && modelConfig.getStats().getCateMaxNumBin() > 0 && CollectionUtils.isNotEmpty(binCategories) && binCategories.size() > modelConfig.getStats().getCateMaxNumBin()) {
        // only category size larger than expected max bin number
        CateBinningStats cateBinningStats = rebinCategoricalValues(new CateBinningStats(binCategories, binCountPos, binCountNeg, binWeightPos, binWeightNeg));
        LOG.info("For variable - {}, {} bins is rebined to {} bins", columnConfig.getColumnName(), binCategories.size(), cateBinningStats.binCategories.size());
        binCategories = cateBinningStats.binCategories;
        binCountPos = cateBinningStats.binCountPos;
        binCountNeg = cateBinningStats.binCountNeg;
        binWeightPos = cateBinningStats.binWeightPos;
        binWeightNeg = cateBinningStats.binWeightNeg;
    }
    double[] binPosRate;
    if(modelConfig.isRegression()) {
        binPosRate = computePosRate(binCountPos, binCountNeg);
    } else {
        // for multiple classification, use rate of categories to compute a value
        binPosRate = computeRateForMultiClassfication(binCountPos);
    }
    String binBounString = null;
    if(columnConfig.isHybrid()) {
        if(binCategories.size() > this.maxCateSize) {
            LOG.warn("Column {} {} with invalid bin category size.", key.get(), columnConfig.getColumnName(), binCategories.size());
            return;
        }
        binBounString = binBoundaryList.toString();
        binBounString += Constants.HYBRID_BIN_STR_DILIMETER + Base64Utils.base64Encode("[" + StringUtils.join(binCategories, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR) + "]");
    } else if(columnConfig.isCategorical()) {
        if(binCategories.size() > this.maxCateSize) {
            LOG.warn("Column {} {} with invalid bin category size.", key.get(), columnConfig.getColumnName(), binCategories.size());
            return;
        }
        binBounString = Base64Utils.base64Encode("[" + StringUtils.join(binCategories, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR) + "]");
        // recompute min/max/sum/moments over bin pos rates for categorical variables
        min = Double.MAX_VALUE;
        // FIX: same -Double.MAX_VALUE initialization as above (Double.MIN_VALUE is positive)
        max = -Double.MAX_VALUE;
        sum = 0d;
        squaredSum = 0d;
        for(int i = 0; i < binPosRate.length; i++) {
            if(!Double.isNaN(binPosRate[i])) {
                if(Double.compare(max, binPosRate[i]) < 0) {
                    max = binPosRate[i];
                }
                if(Double.compare(min, binPosRate[i]) > 0) {
                    min = binPosRate[i];
                }
                long binCount = binCountPos[i] + binCountNeg[i];
                sum += binPosRate[i] * binCount;
                double squaredVal = binPosRate[i] * binPosRate[i];
                squaredSum += squaredVal * binCount;
                tripleSum += squaredVal * binPosRate[i] * binCount;
                quarticSum += squaredVal * squaredVal * binCount;
            }
        }
    } else {
        if(binBoundaryList.size() == 0) {
            LOG.warn("Column {} {} with invalid bin boundary size.", key.get(), columnConfig.getColumnName(), binBoundaryList.size());
            return;
        }
        binBounString = binBoundaryList.toString();
    }
    // KS/IV metrics only exist for regression (binary target)
    ColumnMetrics columnCountMetrics = null;
    ColumnMetrics columnWeightMetrics = null;
    if(modelConfig.isRegression()) {
        columnCountMetrics = ColumnStatsCalculator.calculateColumnMetrics(binCountNeg, binCountPos);
        columnWeightMetrics = ColumnStatsCalculator.calculateColumnMetrics(binWeightNeg, binWeightPos);
    }
    // to make it consistent with SPDT, missingCount is excluded to compute mean, stddev ...
    long realCount = this.statsExcludeMissingValue ? (count - missingCount) : count;
    double mean = sum / realCount;
    double stdDev = Math.sqrt(Math.abs((squaredSum - (sum * sum) / realCount + EPS) / (realCount - 1)));
    double aStdDev = Math.sqrt(Math.abs((squaredSum - (sum * sum) / realCount + EPS) / realCount));
    double skewness = ColumnStatsCalculator.computeSkewness(realCount, mean, aStdDev, sum, squaredSum, tripleSum);
    double kurtosis = ColumnStatsCalculator.computeKurtosis(realCount, mean, aStdDev, sum, squaredSum, tripleSum, quarticSum);
    // assemble the output line; field order must stay stable as downstream parsing is positional
    sb.append(key.get()).append(Constants.DEFAULT_DELIMITER)
            .append(binBounString).append(Constants.DEFAULT_DELIMITER)
            .append(Arrays.toString(binCountNeg)).append(Constants.DEFAULT_DELIMITER)
            .append(Arrays.toString(binCountPos)).append(Constants.DEFAULT_DELIMITER)
            // placeholder field kept for format compatibility
            .append(Arrays.toString(new double[0])).append(Constants.DEFAULT_DELIMITER)
            .append(Arrays.toString(binPosRate)).append(Constants.DEFAULT_DELIMITER)
            .append(columnCountMetrics == null ? "" : df.format(columnCountMetrics.getKs())).append(Constants.DEFAULT_DELIMITER)
            .append(columnCountMetrics == null ? "" : df.format(columnCountMetrics.getIv())).append(Constants.DEFAULT_DELIMITER)
            .append(df.format(max)).append(Constants.DEFAULT_DELIMITER)
            .append(df.format(min)).append(Constants.DEFAULT_DELIMITER)
            .append(df.format(mean)).append(Constants.DEFAULT_DELIMITER)
            .append(df.format(stdDev)).append(Constants.DEFAULT_DELIMITER)
            .append(columnConfig.getColumnType().toString()).append(Constants.DEFAULT_DELIMITER)
            .append(median).append(Constants.DEFAULT_DELIMITER)
            .append(missingCount).append(Constants.DEFAULT_DELIMITER)
            .append(count).append(Constants.DEFAULT_DELIMITER)
            .append(missingCount * 1.0d / count).append(Constants.DEFAULT_DELIMITER)
            .append(Arrays.toString(binWeightNeg)).append(Constants.DEFAULT_DELIMITER)
            .append(Arrays.toString(binWeightPos)).append(Constants.DEFAULT_DELIMITER)
            .append(columnCountMetrics == null ? "" : columnCountMetrics.getWoe()).append(Constants.DEFAULT_DELIMITER)
            .append(columnWeightMetrics == null ? "" : columnWeightMetrics.getWoe()).append(Constants.DEFAULT_DELIMITER)
            .append(columnWeightMetrics == null ? "" : columnWeightMetrics.getKs()).append(Constants.DEFAULT_DELIMITER)
            .append(columnWeightMetrics == null ? "" : columnWeightMetrics.getIv()).append(Constants.DEFAULT_DELIMITER)
            // bin count WOE
            .append(columnCountMetrics == null ? Arrays.toString(new double[binSize + 1]) : columnCountMetrics.getBinningWoe().toString()).append(Constants.DEFAULT_DELIMITER)
            // bin weighted WOE
            .append(columnWeightMetrics == null ? Arrays.toString(new double[binSize + 1]) : columnWeightMetrics.getBinningWoe().toString()).append(Constants.DEFAULT_DELIMITER)
            .append(skewness).append(Constants.DEFAULT_DELIMITER)
            .append(kurtosis).append(Constants.DEFAULT_DELIMITER)
            .append(totalCount).append(Constants.DEFAULT_DELIMITER)
            .append(invalidCount).append(Constants.DEFAULT_DELIMITER)
            .append(validNumCount).append(Constants.DEFAULT_DELIMITER)
            // approximate cardinality from the merged HyperLogLog sketch
            .append(hyperLogLogPlus.cardinality()).append(Constants.DEFAULT_DELIMITER)
            // frequent items, base64-encoded
            .append(Base64Utils.base64Encode(limitedFrequentItems(fis))).append(Constants.DEFAULT_DELIMITER)
            .append(p25th).append(Constants.DEFAULT_DELIMITER)
            .append(p75th);
    outputValue.set(sb.toString());
    context.write(NullWritable.get(), outputValue);
    // reuse the shared StringBuilder across reduce calls
    sb.delete(0, sb.length());
    LOG.debug("Time:{}", (System.currentTimeMillis() - start));
}
Also used : CountAndFrequentItemsWritable(ml.shifu.shifu.core.autotype.CountAndFrequentItemsWritable) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) CardinalityMergeException(com.clearspring.analytics.stream.cardinality.CardinalityMergeException) HyperLogLogPlus(com.clearspring.analytics.stream.cardinality.HyperLogLogPlus) ColumnMetrics(ml.shifu.shifu.core.ColumnStatsCalculator.ColumnMetrics)

Example 4 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

In the class CorrelationMapper, the method map:

/**
 * Accumulates pairwise correlation statistics (sums, squared sums, cross sums, adjusted counts) for
 * every non-meta candidate column of one input record into the shared
 * {@code CorrelationMultithreadedMapper.finalCorrelationMap}. Records are filtered by
 * {@code dataPurifier} and sampled by {@code modelConfig.getStats().getSampleRate()}.
 */
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
    String valueStr = value.toString();
    // skip empty or whitespace-only lines
    if (valueStr == null || valueStr.length() == 0 || valueStr.trim().length() == 0) {
        LOG.warn("Empty input.");
        return;
    }
    double[] dValues = null;
    // skip records rejected by the configured data filter
    if (!this.dataPurifier.isFilter(valueStr)) {
        return;
    }
    long startO = System.currentTimeMillis();
    context.getCounter(Constants.SHIFU_GROUP_COUNTER, "CNT_AFTER_FILTER").increment(1L);
    // make sampling work in correlation: keep a record with probability sampleRate
    if (Math.random() >= modelConfig.getStats().getSampleRate()) {
        return;
    }
    context.getCounter(Constants.SHIFU_GROUP_COUNTER, "CORRELATION_CNT").increment(1L);
    dValues = getDoubleArrayByRawArray(CommonUtils.split(valueStr, this.dataSetDelimiter));
    count += 1L;
    if (count % 2000L == 0) {
        LOG.info("Current records: {} in thread {}.", count, Thread.currentThread().getName());
    }
    for (int i = 0; i < columnConfigList.size(); i++) {
        ColumnConfig columnConfig = columnConfigList.get(i);
        // skip meta columns, and non-candidate columns when candidates are declared
        if (columnConfig.getColumnFlag() == ColumnFlag.Meta || (hasCandidates && !ColumnFlag.Candidate.equals(columnConfig.getColumnFlag()))) {
            continue;
        }
        // NOTE(review): cw is assumed pre-populated for every eligible column; if the map has no
        // entry, the synchronized block below throws NPE — confirm initialization in the mapper setup.
        CorrelationWritable cw = CorrelationMultithreadedMapper.finalCorrelationMap.get(columnConfig.getColumnNum());
        // the map is shared across mapper threads, so updates are synchronized per column
        synchronized (cw) {
            cw.setColumnIndex(i);
            cw.setCount(cw.getCount() + 1d);
            cw.setSum(cw.getSum() + dValues[i]);
            // squared value of this column for the current record (reused as xxSum contribution below)
            double squaredSum = dValues[i] * dValues[i];
            cw.setSumSquare(cw.getSumSquare() + squaredSum);
            // lazily allocate per-column accumulator arrays on first use
            double[] xySum = cw.getXySum();
            if (xySum == null) {
                xySum = new double[columnConfigList.size()];
                cw.setXySum(xySum);
            }
            double[] xxSum = cw.getXxSum();
            if (xxSum == null) {
                xxSum = new double[columnConfigList.size()];
                cw.setXxSum(xxSum);
            }
            double[] yySum = cw.getYySum();
            if (yySum == null) {
                yySum = new double[columnConfigList.size()];
                cw.setYySum(yySum);
            }
            double[] adjustCount = cw.getAdjustCount();
            if (adjustCount == null) {
                adjustCount = new double[columnConfigList.size()];
                cw.setAdjustCount(adjustCount);
            }
            double[] adjustSumX = cw.getAdjustSumX();
            if (adjustSumX == null) {
                adjustSumX = new double[columnConfigList.size()];
                cw.setAdjustSumX(adjustSumX);
            }
            double[] adjustSumY = cw.getAdjustSumY();
            if (adjustSumY == null) {
                adjustSumY = new double[columnConfigList.size()];
                cw.setAdjustSumY(adjustSumY);
            }
            // when not computing all pairs, only the upper triangle (j >= i) is accumulated
            for (int j = (this.isComputeAll ? 0 : i); j < columnConfigList.size(); j++) {
                ColumnConfig otherColumnConfig = columnConfigList.get(j);
                // target column is always kept; meta/non-candidate columns are skipped
                if ((otherColumnConfig.getColumnFlag() != ColumnFlag.Target) && ((otherColumnConfig.getColumnFlag() == ColumnFlag.Meta) || (hasCandidates && !ColumnFlag.Candidate.equals(otherColumnConfig.getColumnFlag())))) {
                    continue;
                }
                // only do stats when both values are valid; Double.MIN_VALUE marks a missing value
                // here — presumably set by getDoubleArrayByRawArray, confirm against that helper
                if (dValues[i] != Double.MIN_VALUE && dValues[j] != Double.MIN_VALUE) {
                    xySum[j] += dValues[i] * dValues[j];
                    xxSum[j] += squaredSum;
                    yySum[j] += dValues[j] * dValues[j];
                    adjustCount[j] += 1d;
                    adjustSumX[j] += dValues[i];
                    adjustSumY[j] += dValues[j];
                }
            }
        }
        LOG.debug("running time is {}ms in thread {}", (System.currentTimeMillis() - startO), Thread.currentThread().getName());
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 5 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

In the class BinaryDTSerializer, the method getColumnMapping:

/**
 * Builds a mapping from column number to compact model input index. Before variable selection every
 * good candidate (non-meta, non-target) column gets an index; after variable selection only
 * final-selected (non-meta, non-target) columns do. Indices are assigned in column-list order.
 *
 * @param columnConfigList all column configurations of the model
 * @return map of columnNum -&gt; dense input index
 */
private static Map<Integer, Integer> getColumnMapping(List<ColumnConfig> columnConfigList) {
    // load factor 1f: final size is bounded by list size, avoid rehashing
    Map<Integer, Integer> columnMapping = new HashMap<Integer, Integer>(columnConfigList.size(), 1f);
    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(columnConfigList);
    // index 3 flags whether variable selection has been performed
    boolean isAfterVarSelect = inputOutputIndex[3] == 1;
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    int index = 0;
    for(int i = 0; i < columnConfigList.size(); i++) {
        ColumnConfig columnConfig = columnConfigList.get(i);
        // null-safe in both branches (previously only the after-var-select branch checked null)
        if(columnConfig == null) {
            continue;
        }
        if(!isAfterVarSelect) {
            if(!columnConfig.isMeta() && !columnConfig.isTarget()
                    && CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
                columnMapping.put(columnConfig.getColumnNum(), index);
                index += 1;
            }
        } else {
            if(!columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
                columnMapping.put(columnConfig.getColumnNum(), index);
                index += 1;
            }
        }
    }
    return columnMapping;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) HashMap(java.util.HashMap)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)131 ArrayList (java.util.ArrayList)36 Test (org.testng.annotations.Test)17 IOException (java.io.IOException)16 HashMap (java.util.HashMap)12 Tuple (org.apache.pig.data.Tuple)10 File (java.io.File)8 NSColumn (ml.shifu.shifu.column.NSColumn)8 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)8 ShifuException (ml.shifu.shifu.exception.ShifuException)8 Path (org.apache.hadoop.fs.Path)8 List (java.util.List)7 Scanner (java.util.Scanner)7 DataBag (org.apache.pig.data.DataBag)7 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)5 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)5 TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet)5 BasicMLData (org.encog.ml.data.basic.BasicMLData)5 BufferedWriter (java.io.BufferedWriter)3 FileInputStream (java.io.FileInputStream)3