Search in sources :

Example 16 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class VarSelectModelProcessor method autoVarSelCondition.

/**
 * Runs the automatic variable-selection filters, de-selecting columns and recording the reason for each.
 * <p>
 * Only columns that are currently final-selected and are neither target, meta nor force-selected are
 * examined. The filters run in this order:
 * <ol>
 * <li>columns flagged by {@code isHighMissingRateColumn} are de-selected;</li>
 * <li>columns whose IV or KS is below the configured minimum are de-selected (an unset threshold is
 * treated as 0, a column without an IV/KS value is not filtered by that metric);</li>
 * <li>if the local correlation CSV exists, {@code varSelectByCorrelation} is applied.</li>
 * </ol>
 *
 * @param varSelDescList
 *            receives one {@link VarSelDesc} per de-selection reason; a column that fails both the IV and
 *            the KS check is recorded twice
 * @throws IOException
 *             any IO exception
 */
private void autoVarSelCondition(List<VarSelDesc> varSelDescList) throws IOException {
    // Pass 1: drop columns with a high missing rate.
    for (ColumnConfig column : columnConfigList) {
        if (column.isTarget() || column.isMeta() || column.isForceSelect() || !column.isFinalSelect() || !isHighMissingRateColumn(column)) {
            continue;
        }
        log.warn("Column {} is with very high missing rate, set final select to false. " + "If not, you can check it manually in ColumnConfig.json", column.getColumnName());
        column.setFinalSelect(false);
        varSelDescList.add(new VarSelDesc(column, VarSelReason.HIGH_MISSING_RATE));
    }

    // Pass 2: drop columns whose IV or KS falls below the configured minimum.
    for (ColumnConfig column : columnConfigList) {
        if (column.isTarget() || column.isMeta() || column.isForceSelect() || !column.isFinalSelect()) {
            continue;
        }

        float minIvThreshold = (super.modelConfig.getVarSelect().getMinIvThreshold() == null ? 0f : super.modelConfig.getVarSelect().getMinIvThreshold());
        boolean ivTooLow = column.getIv() != null && column.getIv() < minIvThreshold;
        if (ivTooLow) {
            log.warn("IV of column {} is less than minimal IV threshold, set final select to false. " + "If not, you can check it manually in ColumnConfig.json", column.getColumnName());
            column.setFinalSelect(false);
            varSelDescList.add(new VarSelDesc(column, VarSelReason.IV_TOO_LOW));
        }

        // The KS check runs even when the IV check has already de-selected the column.
        float minKsThreshold = (super.modelConfig.getVarSelect().getMinKsThreshold() == null ? 0f : super.modelConfig.getVarSelect().getMinKsThreshold());
        boolean ksTooLow = column.getKs() != null && column.getKs() < minKsThreshold;
        if (ksTooLow) {
            log.warn("KS of column {} is less than minimal KS threshold, set final select to false. " + "If not, you can check it manually in ColumnConfig.json", column.getColumnName());
            column.setFinalSelect(false);
            varSelDescList.add(new VarSelDesc(column, VarSelReason.KS_TOO_LOW));
        }
    }

    // Pass 3: correlation-based selection, only possible when the correlation CSV has been produced.
    if (ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
        varSelectByCorrelation(varSelDescList);
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) VarSelDesc(ml.shifu.shifu.core.history.VarSelDesc)

Example 17 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class VarSelectModelProcessor method postProcessFIVarSelect.

/**
 * Marks columns as final-selected based on feature importance, after honouring the force-select list.
 * <p>
 * Force-selected columns are selected first and counted. Candidates are then taken in the iteration
 * order of {@code importances.keySet()} until the selected count reaches the configured filter number,
 * the number of examined candidates reaches that same number, or the candidates run out.
 *
 * @param importances
 *            feature importances keyed by column id; each key is used as an index into
 *            {@code columnConfigList}. Presumably iterated from most to least important - confirm the
 *            map ordering with the caller.
 * @throws IOException
 *             if the user candidate column list cannot be loaded
 */
private void postProcessFIVarSelect(Map<Integer, MutablePair<String, Double>> importances) throws IOException {
    int selectCnt = 0;
    for (ColumnConfig config : super.columnConfigList) {
        // enable ForceSelect
        if (config.isForceSelect()) {
            config.setFinalSelect(true);
            selectCnt++;
            log.info("Variable {} is selected, since it is in ForceSelect list.", config.getColumnName());
        }
    }

    VariableSelector.setFilterNumberByFilterOutRatio(this.modelConfig, this.columnConfigList);
    int targetCnt = this.modelConfig.getVarSelectFilterNum();
    List<Integer> candidateColumnIdList = new ArrayList<Integer>(importances.keySet());
    int candidateCount = candidateColumnIdList.size();

    // NOTE(review): columns that were already final-selected before this call stay selected and are not
    // counted in selectCnt, so the final number of selected columns can exceed targetCnt - confirm this
    // is intended. (A former loop here only re-set finalSelect to true on already-selected columns,
    // which had no effect, and was removed.)

    Set<NSColumn> userCandidateColumns = CommonUtils.loadCandidateColumns(modelConfig);
    int i = 0;
    // NOTE(review): the scan is bounded by "i < targetCnt", so skipped candidates still consume the
    // budget and fewer than targetCnt columns may end up selected - confirm this is intended.
    while (selectCnt < targetCnt && i < targetCnt) {
        if (i >= candidateCount) {
            log.warn("Var select finish due to feature importance count {} is less than target var count {}", candidateCount, targetCnt);
            break;
        }
        Integer columnId = candidateColumnIdList.get(i++);
        ColumnConfig columnConfig = this.columnConfigList.get(columnId);
        if (CollectionUtils.isNotEmpty(userCandidateColumns) && !userCandidateColumns.contains(new NSColumn(columnConfig.getColumnName()))) {
            log.info("Variable {} is not in user's candidate list. Skip it.", columnConfig.getColumnName());
        } else if (!columnConfig.isForceSelect() && !columnConfig.isForceRemove()) {
            // Force-selected columns were counted above; force-removed columns are never selected here.
            columnConfig.setFinalSelect(true);
            selectCnt++;
            log.info("Variable {} is selected.", columnConfig.getColumnName());
        }
    }
    log.info("{} variables are selected.", selectCnt);
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) NSColumn(ml.shifu.shifu.column.NSColumn)

Example 18 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class MapReducerStatsWorker method scanStatsResult.

/**
 * Scan the stats result and save them into column configure.
 * <p>
 * Every line is trimmed and split on {@code '|'}; lines that yield a single field are skipped. Field 0 is
 * the column number, fields 1-24 are read unconditionally and fields 25 onwards only when present. A row
 * that fails while being applied is logged and skipped; values set before the failure are kept.
 * <p>
 * Note that {@link String#split(String)} drops trailing empty fields, so a row whose last fields are
 * empty is shorter than the number of delimiters suggests.
 *
 * @param scanner
 *            the scanner to read the stats lines from
 * @param ccInitSize
 *            column numbers at or above this value are treated as derived columns: their base column is
 *            {@code columnNum % ccInitSize} and a new {@link ColumnConfig} named
 *            {@code <baseName>_<columnNum / ccInitSize>} is appended to {@code columnConfigList}.
 *            Presumably the size of the column config list before scanning - confirm with the caller.
 */
private void scanStatsResult(Scanner scanner, int ccInitSize) {
    while (scanner.hasNextLine()) {
        String[] raw = scanner.nextLine().trim().split("\\|");
        if (raw.length == 1) {
            continue;
        }
        if (raw.length < 25) {
            // Short rows are only reported here; reading the missing fields below throws inside the
            // try block and the row is then handled by the catch clause.
            log.info("The stats data has " + raw.length + " fields.");
            log.info("The stats data is - " + Arrays.toString(raw));
        }
        int columnNum = Integer.parseInt(raw[0]);
        int corrColumnNum = columnNum;
        if (columnNum >= ccInitSize) {
            corrColumnNum = columnNum % ccInitSize;
        }
        try {
            ColumnConfig basicConfig = this.columnConfigList.get(corrColumnNum);
            log.debug("basicConfig is - " + basicConfig.getColumnName() + " corrColumnNum:" + corrColumnNum);
            ColumnConfig config = null;
            if (columnNum >= ccInitSize) {
                // Derived column: clone the identifying attributes of the base column. A derived
                // column never stays a target; a target base is downgraded to meta.
                config = new ColumnConfig();
                config.setColumnNum(columnNum);
                config.setColumnName(basicConfig.getColumnName() + "_" + (columnNum / ccInitSize));
                config.setVersion(basicConfig.getVersion());
                config.setColumnType(basicConfig.getColumnType());
                config.setColumnFlag(basicConfig.getColumnFlag() == ColumnFlag.Target ? ColumnFlag.Meta : basicConfig.getColumnFlag());
                log.debug("basicConfig is - " + basicConfig.getColumnName() + " corrColumnNum:" + corrColumnNum + ", currColumnName: " + columnNum + ", currColumnType:" + config.getColumnType());
                this.columnConfigList.add(config);
            } else {
                config = basicConfig;
            }
            // Field 1 holds the binning; its encoding depends on the column type known so far.
            if (config.isHybrid()) {
                String[] splits = CommonUtils.split(raw[1], Constants.HYBRID_BIN_STR_DILIMETER);
                config.setBinBoundary(CommonUtils.stringToDoubleList(splits[0]));
                String binCategory = Base64Utils.base64Decode(splits[1]);
                config.setBinCategory(CommonUtils.stringToStringList(binCategory, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR));
            } else if (config.isCategorical()) {
                String binCategory = Base64Utils.base64Decode(raw[1]);
                config.setBinCategory(CommonUtils.stringToStringList(binCategory, CalculateStatsUDF.CATEGORY_VAL_SEPARATOR));
                config.setBinBoundary(null);
            } else {
                config.setBinBoundary(CommonUtils.stringToDoubleList(raw[1]));
                config.setBinCategory(null);
            }
            config.setBinCountNeg(CommonUtils.stringToIntegerList(raw[2]));
            config.setBinCountPos(CommonUtils.stringToIntegerList(raw[3]));
            // Field 4 is not consumed.
            config.setBinPosCaseRate(CommonUtils.stringToDoubleList(raw[5]));
            config.setBinLength(config.getBinCountNeg().size());
            config.setKs(parseDouble(raw[6]));
            config.setIv(parseDouble(raw[7]));
            config.setMax(parseDouble(raw[8]));
            config.setMin(parseDouble(raw[9]));
            config.setMean(parseDouble(raw[10]));
            config.setStdDev(parseDouble(raw[11], Double.NaN));
            // The column type is overwritten from the stats row only after the binning above was
            // decoded with the previously known type.
            config.setColumnType(ColumnType.of(raw[12]));
            config.setMedian(parseDouble(raw[13]));
            config.setMissingCnt(parseLong(raw[14]));
            config.setTotalCount(parseLong(raw[15]));
            config.setMissingPercentage(parseDouble(raw[16]));
            config.setBinWeightedNeg(CommonUtils.stringToDoubleList(raw[17]));
            config.setBinWeightedPos(CommonUtils.stringToDoubleList(raw[18]));
            config.getColumnStats().setWoe(parseDouble(raw[19]));
            config.getColumnStats().setWeightedWoe(parseDouble(raw[20]));
            config.getColumnStats().setWeightedKs(parseDouble(raw[21]));
            config.getColumnStats().setWeightedIv(parseDouble(raw[22]));
            config.getColumnBinning().setBinCountWoe(CommonUtils.stringToDoubleList(raw[23]));
            config.getColumnBinning().setBinWeightedWoe(CommonUtils.stringToDoubleList(raw[24]));
            // Optional trailing fields; fields 27 and 28 are not consumed.
            if (raw.length >= 26) {
                config.getColumnStats().setSkewness(parseDouble(raw[25]));
            }
            if (raw.length >= 27) {
                config.getColumnStats().setKurtosis(parseDouble(raw[26]));
            }
            if (raw.length >= 30) {
                config.getColumnStats().setValidNumCount(parseLong(raw[29]));
            }
            if (raw.length >= 31) {
                config.getColumnStats().setDistinctCount(parseLong(raw[30]));
            }
            if (raw.length >= 32) {
                // split() never produces null elements, so no null check is needed here.
                List<String> sampleValues = Arrays.asList(Base64Utils.base64Decode(raw[31]).split(","));
                config.setSampleValues(sampleValues);
            }
            if (raw.length >= 33) {
                config.getColumnStats().set25th(parseDouble(raw[32]));
            }
            if (raw.length >= 34) {
                config.getColumnStats().set75th(parseDouble(raw[33]));
            }
        } catch (Exception e) {
            // Resolve the name defensively: the failure may have been the list lookup itself, and
            // repeating it unguarded here would throw out of the handler and abort the whole scan.
            String columnName = "<unknown>";
            if (corrColumnNum >= 0 && corrColumnNum < this.columnConfigList.size()) {
                columnName = this.columnConfigList.get(corrColumnNum).getColumnName();
            }
            log.error(String.format("Fail to process following column : %s name: %s error: %s", columnNum, columnName, e.getMessage()), e);
        }
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ShifuException(ml.shifu.shifu.exception.ShifuException) JexlException(org.apache.commons.jexl2.JexlException) IOException(java.io.IOException)

Example 19 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class ExportModelProcessor method exportVariableCorr.

/**
 * Exports pairwise variable correlations from the local correlation CSV to the correlation export path.
 * <p>
 * The first two lines of the CSV (indexes and names) are skipped. A data row is used only when it has
 * exactly {@code columnConfigList.size() + 2} comma-separated fields; other rows are ignored. Field 0 is
 * the index of the "from" column. For a "from" column that is the target or a good candidate, one
 * {@link VarCorrInfo} is produced for every other column that is neither target nor meta. The collected
 * entries are de-duplicated, sorted by their natural order and written out.
 *
 * @return always 0
 * @throws IOException
 *             if the correlation file cannot be read or the export cannot be written
 */
private int exportVariableCorr() throws IOException {
    Set<VarCorrInfo> varCorrInfoSet = new HashSet<VarCorrInfo>();
    PostCorrelationMetric metric = this.modelConfig.getVarSelect().getPostCorrelationMetric();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    // Open the reader last, immediately before the try/finally, so that no failure in the setup
    // above can leave it unclosed.
    BufferedReader reader = ShifuFileUtils.getReader(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL);
    try {
        int lineNum = 0;
        String line = null;
        while ((line = reader.readLine()) != null) {
            lineNum += 1;
            if (lineNum <= 2) {
                // skip first 2 lines which are indexes and names
                continue;
            }
            String[] columns = CommonUtils.split(line, ",");
            if (columns != null && columns.length == columnConfigList.size() + 2) {
                int columnIndex = Integer.parseInt(columns[0].trim());
                ColumnConfig fromConfig = this.columnConfigList.get(columnIndex);
                if (fromConfig.isTarget() || CommonUtils.isGoodCandidate(fromConfig, hasCandidates)) {
                    double[] corrArray = getCorrArray(columns);
                    for (int i = 0; i < corrArray.length; i++) {
                        ColumnConfig toConfig = this.columnConfigList.get(i);
                        // Skip the self-correlation and columns that are not model inputs.
                        if (i != columnIndex && !toConfig.isTarget() && !toConfig.isMeta()) {
                            varCorrInfoSet.add(new VarCorrInfo(fromConfig.getColumnName(), toConfig.getColumnName(), corrArray[i], getColumnMetric(fromConfig, metric), getColumnMetric(toConfig, metric)));
                        }
                    }
                }
            }
        }
    } finally {
        IOUtils.closeQuietly(reader);
    }
    List<VarCorrInfo> varCorrInfoList = new ArrayList<VarCorrInfo>(varCorrInfoSet);
    Collections.sort(varCorrInfoList);
    String corrExportPath = this.pathFinder.getCorrExportPath();
    ShifuFileUtils.writeLines(varCorrInfoList, corrExportPath, SourceType.LOCAL);
    log.info("Done. The correlations are exported to {}", corrExportPath);
    return 0;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BufferedReader(java.io.BufferedReader) PostCorrelationMetric(ml.shifu.shifu.container.obj.ModelVarSelectConf.PostCorrelationMetric)

Example 20 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class ExportModelProcessor method saveColumnStatus.

/**
 * Writes one CSV row of statistics per column config to the local column-stats file.
 * <p>
 * An existing file at that path is deleted first. The header gains extra unit-stats columns when
 * {@code loadColumnConfigUnitStats()} returns a non-empty map (taken from the map's first entry), and
 * each row gains the matching unit-stats cells when that first entry is itself non-empty.
 *
 * @throws IOException
 *             if the file cannot be deleted or written
 */
private void saveColumnStatus() throws IOException {
    Path statsPath = new Path(pathFinder.getLocalColumnStatsPath());
    log.info("Saving ColumnStatus to local file system: {}.", statsPath);
    if (HDFSUtils.getLocalFS().exists(statsPath)) {
        HDFSUtils.getLocalFS().delete(statsPath, true);
    }

    BufferedWriter out = null;
    try {
        out = ShifuFileUtils.getWriter(statsPath.toString(), SourceType.LOCAL);
        Map<Integer, List<String>> unitStatsByColumn = loadColumnConfigUnitStats();

        // Header: fixed columns, plus one block of unit-stats columns when unit stats are available.
        List<String> headerUnits = null;
        if (MapUtils.isNotEmpty(unitStatsByColumn)) {
            headerUnits = unitStatsByColumn.entrySet().iterator().next().getValue();
            out.write("dataSet,columnFlag,columnName,columnNum,iv,ks,max,mean,median,min,missingCount," + "missingPercentage,stdDev,totalCount,distinctCount,weightedIv,weightedKs,weightedWoe,woe," + "skewness,kurtosis,columnType,finalSelect,psi,unitstats,version," + unitsToHeader(headerUnits) + "\n");
        } else {
            out.write("dataSet,columnFlag,columnName,columnNum,iv,ks,max,mean,median,min,missingCount," + "missingPercentage,stdDev,totalCount,distinctCount,weightedIv,weightedKs,weightedWoe,woe," + "skewness,kurtosis,columnType,finalSelect,psi,unitstats,version\n");
        }

        StringBuilder row = new StringBuilder(500);
        for (ColumnConfig column : columnConfigList) {
            // The cells that precede the version column, in header order.
            Object[] leadingCells = {
                    modelConfig.getBasic().getName(),
                    column.getColumnFlag(),
                    column.getColumnName(),
                    column.getColumnNum(),
                    column.getIv(),
                    column.getKs(),
                    column.getColumnStats().getMax(),
                    column.getColumnStats().getMean(),
                    column.getColumnStats().getMedian(),
                    column.getColumnStats().getMin(),
                    column.getColumnStats().getMissingCount(),
                    column.getColumnStats().getMissingPercentage(),
                    column.getColumnStats().getStdDev(),
                    column.getColumnStats().getTotalCount(),
                    column.getColumnStats().getDistinctCount(),
                    column.getColumnStats().getWeightedIv(),
                    column.getColumnStats().getWeightedKs(),
                    column.getColumnStats().getWeightedWoe(),
                    column.getColumnStats().getWoe(),
                    column.getColumnStats().getSkewness(),
                    column.getColumnStats().getKurtosis(),
                    column.getColumnType(),
                    column.isFinalSelect(),
                    column.getPSI(),
                    StringUtils.join(column.getUnitStats(), '|') };

            row.setLength(0);
            for (Object cell : leadingCells) {
                row.append(cell).append(',');
            }
            row.append(modelConfig.getBasic().getVersion());
            if (CollectionUtils.isNotEmpty(headerUnits)) {
                row.append(",");
                row.append(splitUnitStatsToColumn(unitStatsByColumn.get(column.getColumnNum()), headerUnits.size()));
            }
            row.append("\n");
            out.write(row.toString());
        }
    } finally {
        if (out != null) {
            out.close();
        }
    }
}
Also used : Path(org.apache.hadoop.fs.Path) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BufferedWriter(java.io.BufferedWriter)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)131 ArrayList (java.util.ArrayList)36 Test (org.testng.annotations.Test)17 IOException (java.io.IOException)16 HashMap (java.util.HashMap)12 Tuple (org.apache.pig.data.Tuple)10 File (java.io.File)8 NSColumn (ml.shifu.shifu.column.NSColumn)8 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)8 ShifuException (ml.shifu.shifu.exception.ShifuException)8 Path (org.apache.hadoop.fs.Path)8 List (java.util.List)7 Scanner (java.util.Scanner)7 DataBag (org.apache.pig.data.DataBag)7 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)5 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)5 TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet)5 BasicMLData (org.encog.ml.data.basic.BasicMLData)5 BufferedWriter (java.io.BufferedWriter)3 FileInputStream (java.io.FileInputStream)3