Search in sources:

Example 1 with VariableSelector

Use of ml.shifu.shifu.core.VariableSelector in project shifu by ShifuML.

From the class VarSelectModelProcessor, method run.

/**
 * Run the variable selection step.
 *
 * <p>Exactly one mode is executed, checked in this order:
 * <ol>
 * <li>reset ({@code getIsToReset()}): clear all final selections;</li>
 * <li>list ({@code getIsToList()}): log the names of the currently final-selected columns;</li>
 * <li>auto filter ({@code getIsToAutoFilter()}): run only the automatic variable filter;</li>
 * <li>recover ({@code getIsRecoverAuto()}): restore auto-filtered variables from the local varsel
 * history file, if that file exists;</li>
 * <li>otherwise, a full selection:
 * <ul>
 * <li>If segment filter expressions are set, force-remove the per-segment copies of the target column.</li>
 * <li>If {@code modelConfig.isRegression()}, select by KS/IV/Pareto/Mix, by feature importance (tree
 * models only), by sensitivity SE/ST (NN/LR only), or by voting.</li>
 * <li>Otherwise, select the force-selected columns (when force selection is enabled) or all good
 * candidate columns.</li>
 * <li>Finally, clean shadow targets and optionally run the auto filter.</li>
 * </ul>
 * </li>
 * </ol>
 *
 * @return 0 on success; -1 if any exception was raised inside the step (the error is logged, not rethrown)
 * @throws Exception declared by the overridden signature; exceptions raised inside the step are caught
 */
@Override
public int run() throws Exception {
    log.info("Step Start: varselect");
    long start = System.currentTimeMillis();
    try {
        setUp(ModelStep.VARSELECT);
        validateParameters();
        // reset all selections if user specify or select by absolute number
        if (getIsToReset()) {
            log.info("Reset all selections data including type final select etc!");
            resetAllFinalSelect();
        } else if (getIsToList()) {
            log.info("Below variables are selected - ");
            for (ColumnConfig columnConfig : this.columnConfigList) {
                if (columnConfig.isFinalSelect()) {
                    log.info(columnConfig.getColumnName());
                }
            }
            log.info("-----  Done -----");
        } else if (getIsToAutoFilter()) {
            log.info("Start to run variable auto filter.");
            runAutoVarFilter();
            log.info("-----  Done -----");
        } else if (getIsRecoverAuto()) {
            String varselHistory = pathFinder.getVarSelHistory();
            if (ShifuFileUtils.isFileExists(varselHistory, SourceType.LOCAL)) {
                log.info("!!! Auto filtered variables will be recovered from history.");
                recoverVarselStatusFromHist(varselHistory);
                log.info("-----  Done -----");
            } else {
                log.warn("No variables auto filter history is found.");
            }
        } else {
            // sync to make sure load from hdfs config is consistent with local configuration
            syncDataToHdfs(super.modelConfig.getDataSet().getSource());
            String filterExpressions = super.modelConfig.getSegmentFilterExpressionsAsString();
            Environment.getProperties().put("shifu.segment.expressions", filterExpressions);
            if (StringUtils.isNotBlank(filterExpressions)) {
                String[] splits = CommonUtils.split(filterExpressions, Constants.SHIFU_STATS_FILTER_EXPRESSIONS_DELIMETER);
                // NOTE(review): the index arithmetic below assumes columnConfigList is the raw columns followed
                // by one equally sized copy per segment expression, so the copy of raw column i for segment j
                // sits at (j + 1) * rawSize + i. The list size does not change in this loop, so compute it once.
                int rawSize = super.columnConfigList.size() / (1 + splits.length);
                for (int i = 0; i < super.columnConfigList.size(); i++) {
                    ColumnConfig config = super.columnConfigList.get(i);
                    if (config.isTarget()) {
                        // per-segment copies of the target must never be used as model inputs
                        for (int j = 0; j < splits.length; j++) {
                            ColumnConfig otherConfig = super.columnConfigList.get((j + 1) * rawSize + i);
                            otherConfig.setColumnFlag(ColumnFlag.ForceRemove);
                            otherConfig.setFinalSelect(false);
                        }
                        break;
                    }
                }
                this.saveColumnConfigList();
                // sync to make sure load from hdfs config is consistent with local configuration
                syncDataToHdfs(super.modelConfig.getDataSet().getSource());
            }
            if (modelConfig.isRegression()) {
                String filterBy = this.modelConfig.getVarSelectFilterBy();
                if (filterBy.equalsIgnoreCase(Constants.FILTER_BY_KS) || filterBy.equalsIgnoreCase(Constants.FILTER_BY_IV) || filterBy.equalsIgnoreCase(Constants.FILTER_BY_PARETO) || filterBy.equalsIgnoreCase(Constants.FILTER_BY_MIX)) {
                    VariableSelector selector = new VariableSelector(this.modelConfig, this.columnConfigList);
                    this.columnConfigList = selector.selectByFilter();
                } else if (filterBy.equalsIgnoreCase(Constants.FILTER_BY_FI)) {
                    if (!CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                        throw new IllegalArgumentException("Filter by FI only works well in GBT/RF. Please check your modelconfig::train.");
                    }
                    selectByFeatureImportance();
                } else if (filterBy.equalsIgnoreCase(Constants.FILTER_BY_SE) || filterBy.equalsIgnoreCase(Constants.FILTER_BY_ST)) {
                    String algorithm = modelConfig.getAlgorithm();
                    if (!Constants.NN.equalsIgnoreCase(algorithm) && !Constants.LR.equalsIgnoreCase(algorithm)) {
                        throw new IllegalArgumentException("Filter by SE/ST only works well in NN/LR. Please check your modelconfig::train.");
                    }
                    int recursiveCnt = getRecursiveCnt();
                    int i = 0;
                    // create varsel directory and write original copy of ColumnConfig.json
                    ShifuFileUtils.createDirIfNotExists(pathFinder.getVarSelDir(), SourceType.LOCAL);
                    super.saveColumnConfigList(pathFinder.getVarSelColumnConfig(i), this.columnConfigList);
                    // one round per iteration; i is 1-based inside the loop body, so round index is i - 1
                    while ((i++) < recursiveCnt) {
                        String trainLogFile = TRAIN_LOG_PREFIX + "-" + (i - 1) + ".log";
                        distributedSEWrapper(trainLogFile);
                        // copy training log to SE train.log
                        ShifuFileUtils.move(trainLogFile, new File(pathFinder.getVarSelDir(), trainLogFile).getPath(), SourceType.LOCAL);
                        String varSelectMSEOutputPath = pathFinder.getVarSelectMSEOutputPath(modelConfig.getDataSet().getSource());
                        // even fail to run SE, still to create an empty se.x file
                        String varSelMSEHistPath = pathFinder.getVarSelMSEHistPath(i - 1);
                        ShifuFileUtils.createFileIfNotExists(varSelMSEHistPath, SourceType.LOCAL);
                        ShifuFileUtils.copyToLocal(new SourceFile(varSelectMSEOutputPath, modelConfig.getDataSet().getSource()), Constants.SHIFU_VARSELECT_SE_OUTPUT_NAME, varSelMSEHistPath);
                        // save as backup
                        super.saveColumnConfigList(pathFinder.getVarSelColumnConfig(i), this.columnConfigList);
                        // save as current copy
                        super.saveColumnConfigList();
                    }
                } else if (filterBy.equalsIgnoreCase(Constants.FILTER_BY_VOTED)) {
                    votedVariablesSelection();
                }
            } else {
                boolean hasCandidates = CommonUtils.hasCandidateColumns(this.columnConfigList);
                if (this.modelConfig.getVarSelect().getForceEnable() && CollectionUtils.isNotEmpty(this.modelConfig.getListForceSelect())) {
                    log.info("Force Selection is enabled ... " + "for multi-classification, currently only use it to selected variables.");
                    // count the columns actually selected; the configured list may name columns that do not exist
                    int forceSelectedCnt = 0;
                    for (ColumnConfig config : this.columnConfigList) {
                        if (config.isForceSelect()) {
                            if (!CommonUtils.isGoodCandidate(config, hasCandidates, modelConfig.isRegression())) {
                                log.warn("!! Variable - {} is not a good candidate. But it is in forceselect list", config.getColumnName());
                            }
                            config.setFinalSelect(true);
                            forceSelectedCnt++;
                        }
                    }
                    log.info("{} variables are selected by force.", forceSelectedCnt);
                } else {
                    // multiple classification, select all candidate at first, TODO add SE for multi-classification
                    for (ColumnConfig config : this.columnConfigList) {
                        if (CommonUtils.isGoodCandidate(config, hasCandidates, modelConfig.isRegression())) {
                            config.setFinalSelect(true);
                        }
                    }
                }
            }
            // clean shadow targets for multi-segments
            cleanShadowTargetsForSegments();
            if (modelConfig.getVarSelect().getAutoFilterEnable()) {
                runAutoVarFilter();
            }
        }
        // finish the step (presumably persists ColumnConfig and syncs it; see clearUp)
        clearUp(ModelStep.VARSELECT);
    } catch (ShifuException e) {
        log.error("Error:" + e.getError().toString() + "; msg:" + e.getMessage(), e);
        return -1;
    } catch (Exception e) {
        log.error("Error:" + e.getMessage(), e);
        return -1;
    }
    log.info("Step Finished: varselect with {} ms", (System.currentTimeMillis() - start));
    return 0;
}
Also used: ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig), VariableSelector (ml.shifu.shifu.core.VariableSelector), SourceFile (ml.shifu.shifu.fs.SourceFile), File (java.io.File), ShifuException (ml.shifu.shifu.exception.ShifuException), JexlException (org.apache.commons.jexl2.JexlException), IOException (java.io.IOException)

Aggregations

File (java.io.File): 1 · IOException (java.io.IOException): 1 · ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1 · VariableSelector (ml.shifu.shifu.core.VariableSelector): 1 · ShifuException (ml.shifu.shifu.exception.ShifuException): 1 · SourceFile (ml.shifu.shifu.fs.SourceFile): 1 · JexlException (org.apache.commons.jexl2.JexlException): 1