Use of ml.shifu.shifu.core.VariableSelector in project shifu by ShifuML.
The run method of the class VarSelectModelProcessor.
/**
 * Runs the variable selection step.
 *
 * <p>After {@code setUp} and {@code validateParameters}, exactly one mode is executed, checked in
 * this order:
 * <ol>
 * <li>reset: {@code getIsToReset()} - clears all final selections;</li>
 * <li>list: {@code getIsToList()} - logs the columns currently marked as final-selected;</li>
 * <li>auto filter: {@code getIsToAutoFilter()} - runs the variable auto filter only;</li>
 * <li>recover: {@code getIsRecoverAuto()} - restores auto filtered variables from the local
 * history file, if that file exists;</li>
 * <li>otherwise: the regular selection (filter based when {@code modelConfig.isRegression()},
 * candidate/force-select based when not).</li>
 * </ol>
 * Whichever mode ran, {@code clearUp(ModelStep.VARSELECT)} is invoked afterwards.
 *
 * @return {@code 0} when the step completes; {@code -1} when any exception is raised between
 *         {@code setUp} and {@code clearUp} (the exception is logged, not rethrown)
 * @throws Exception
 *             declared by the overridden signature; exceptions raised inside the step body are
 *             caught and reported through the return code instead
 */
@Override
public int run() throws Exception {
    log.info("Step Start: varselect");
    long start = System.currentTimeMillis();
    try {
        setUp(ModelStep.VARSELECT);
        validateParameters();

        if (getIsToReset()) {
            // reset all selections if user specify or select by absolute number
            log.info("Reset all selections data including type final select etc!");
            resetAllFinalSelect();
        } else if (getIsToList()) {
            logFinalSelectedVariables();
        } else if (getIsToAutoFilter()) {
            log.info("Start to run variable auto filter.");
            runAutoVarFilter();
            log.info("----- Done -----");
        } else if (getIsRecoverAuto()) {
            recoverAutoFilterFromHistory();
        } else {
            runRegularVariableSelection();
        }

        // runs for every mode above, including the read-only "list" mode
        clearUp(ModelStep.VARSELECT);
    } catch (ShifuException e) {
        log.error("Error:" + e.getError().toString() + "; msg:" + e.getMessage(), e);
        return -1;
    } catch (Exception e) {
        log.error("Error:" + e.getMessage(), e);
        return -1;
    }
    log.info("Step Finished: varselect with {} ms", (System.currentTimeMillis() - start));
    return 0;
}

/**
 * Logs the name of every column whose {@code isFinalSelect()} flag is set.
 */
private void logFinalSelectedVariables() throws Exception {
    log.info("Below variables are selected - ");
    for (ColumnConfig columnConfig : this.columnConfigList) {
        if (columnConfig.isFinalSelect()) {
            log.info(columnConfig.getColumnName());
        }
    }
    log.info("----- Done -----");
}

/**
 * Recovers auto filtered variables from the local var-select history file, or warns when no
 * such file exists.
 */
private void recoverAutoFilterFromHistory() throws Exception {
    String varselHistory = pathFinder.getVarSelHistory();
    if (ShifuFileUtils.isFileExists(varselHistory, SourceType.LOCAL)) {
        log.info("!!! Auto filtered variables will be recovered from history.");
        recoverVarselStatusFromHist(varselHistory);
        log.info("----- Done -----");
    } else {
        log.warn("No variables auto filter history is found.");
    }
}

/**
 * Performs the regular selection: syncs config, handles segment expressions, selects variables,
 * cleans shadow targets and optionally runs the auto filter.
 */
private void runRegularVariableSelection() throws Exception {
    // sync to make sure load from hdfs config is consistent with local configuration
    syncDataToHdfs(super.modelConfig.getDataSet().getSource());

    String filterExpressions = super.modelConfig.getSegmentFilterExpressionsAsString();
    // NOTE(review): if getProperties() is a java.util.Properties, put() rejects a null value;
    // confirm getSegmentFilterExpressionsAsString() never returns null.
    Environment.getProperties().put("shifu.segment.expressions", filterExpressions);
    if (StringUtils.isNotBlank(filterExpressions)) {
        forceRemoveSegmentTargetCopies(filterExpressions);
        this.saveColumnConfigList();
        // sync to make sure load from hdfs config is consistent with local configuration
        syncDataToHdfs(super.modelConfig.getDataSet().getSource());
    }

    if (modelConfig.isRegression()) {
        selectByConfiguredFilter();
    } else {
        selectCandidatesOrForceSelected();
    }

    // clean shadow targets for multi-segments
    cleanShadowTargetsForSegments();
    if (modelConfig.getVarSelect().getAutoFilterEnable()) {
        runAutoVarFilter();
    }
}

/**
 * Marks the per-segment copies of the first target column as ForceRemove and not final-selected.
 *
 * <p>The copy for segment {@code j} is looked up at index {@code (j + 1) * rawSize + i}, where
 * {@code i} is the index of the target and {@code rawSize} is the list size divided by
 * {@code 1 + number of segment expressions}.
 */
private void forceRemoveSegmentTargetCopies(String filterExpressions) throws Exception {
    String[] splits = CommonUtils.split(filterExpressions,
            Constants.SHIFU_STATS_FILTER_EXPRESSIONS_DELIMETER);
    // loop invariant: number of columns in one block of the column list
    int rawSize = super.columnConfigList.size() / (1 + splits.length);
    for (int i = 0; i < super.columnConfigList.size(); i++) {
        ColumnConfig config = super.columnConfigList.get(i);
        if (config.isTarget()) {
            for (int j = 0; j < splits.length; j++) {
                ColumnConfig otherConfig = super.columnConfigList.get((j + 1) * rawSize + i);
                otherConfig.setColumnFlag(ColumnFlag.ForceRemove);
                otherConfig.setFinalSelect(false);
            }
            // only the first target column is processed
            return;
        }
    }
}

/**
 * Dispatches on {@code modelConfig.getVarSelectFilterBy()} to the matching selection strategy;
 * used when {@code modelConfig.isRegression()} is true.
 */
private void selectByConfiguredFilter() throws Exception {
    String filterBy = this.modelConfig.getVarSelectFilterBy();
    if (filterBy.equalsIgnoreCase(Constants.FILTER_BY_KS)
            || filterBy.equalsIgnoreCase(Constants.FILTER_BY_IV)
            || filterBy.equalsIgnoreCase(Constants.FILTER_BY_PARETO)
            || filterBy.equalsIgnoreCase(Constants.FILTER_BY_MIX)) {
        VariableSelector selector = new VariableSelector(this.modelConfig, this.columnConfigList);
        this.columnConfigList = selector.selectByFilter();
    } else if (filterBy.equalsIgnoreCase(Constants.FILTER_BY_FI)) {
        if (!CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            throw new IllegalArgumentException(
                    "Filter by FI only works well in GBT/RF. Please check your modelconfig::train.");
        }
        selectByFeatureImportance();
    } else if (filterBy.equalsIgnoreCase(Constants.FILTER_BY_SE)
            || filterBy.equalsIgnoreCase(Constants.FILTER_BY_ST)) {
        if (!Constants.NN.equalsIgnoreCase(modelConfig.getAlgorithm())
                && !Constants.LR.equalsIgnoreCase(modelConfig.getAlgorithm())) {
            throw new IllegalArgumentException(
                    "Filter by SE/ST only works well in NN/LR. Please check your modelconfig::train.");
        }
        runRecursiveSensitivityRounds();
    } else if (filterBy.equalsIgnoreCase(Constants.FILTER_BY_VOTED)) {
        votedVariablesSelection();
    } else {
        // previously a silent no-op: make it visible that nothing was selected by filter
        log.warn("Unrecognized varSelect filterBy '{}'; no filter-based selection was performed.",
                filterBy);
    }
}

/**
 * Runs {@code getRecursiveCnt()} SE/ST rounds, archiving the train log, the MSE output and a
 * ColumnConfig snapshot for each round under the local var-select directory.
 */
private void runRecursiveSensitivityRounds() throws Exception {
    int recursiveCnt = getRecursiveCnt();
    // create varsel directory and write original copy of ColumnConfig.json
    ShifuFileUtils.createDirIfNotExists(pathFinder.getVarSelDir(), SourceType.LOCAL);
    super.saveColumnConfigList(pathFinder.getVarSelColumnConfig(0), this.columnConfigList);

    for (int round = 0; round < recursiveCnt; round++) {
        String trainLogFile = TRAIN_LOG_PREFIX + "-" + round + ".log";
        distributedSEWrapper(trainLogFile);

        // move the training log of this round into the varsel directory
        ShifuFileUtils.move(trainLogFile,
                new File(pathFinder.getVarSelDir(), trainLogFile).getPath(), SourceType.LOCAL);

        String varSelectMSEOutputPath = pathFinder
                .getVarSelectMSEOutputPath(modelConfig.getDataSet().getSource());
        // even fail to run SE, still to create an empty se.x file
        String varSelMSEHistPath = pathFinder.getVarSelMSEHistPath(round);
        ShifuFileUtils.createFileIfNotExists(varSelMSEHistPath, SourceType.LOCAL);
        ShifuFileUtils.copyToLocal(
                new SourceFile(varSelectMSEOutputPath, modelConfig.getDataSet().getSource()),
                Constants.SHIFU_VARSELECT_SE_OUTPUT_NAME, varSelMSEHistPath);

        // save as backup; snapshot 0 is the original copy, so round N is stored as N + 1
        super.saveColumnConfigList(pathFinder.getVarSelColumnConfig(round + 1), this.columnConfigList);
        // save as current copy
        super.saveColumnConfigList();
    }
}

/**
 * Marks either the force-select columns or all good candidates as final-selected; used when
 * {@code modelConfig.isRegression()} is false.
 */
private void selectCandidatesOrForceSelected() throws Exception {
    boolean hasCandidates = CommonUtils.hasCandidateColumns(this.columnConfigList);
    if (this.modelConfig.getVarSelect().getForceEnable()
            && CollectionUtils.isNotEmpty(this.modelConfig.getListForceSelect())) {
        log.info("Force Selection is enabled ... for multi-classification, currently only use it to selected variables.");
        for (ColumnConfig config : this.columnConfigList) {
            if (config.isForceSelect()) {
                if (!CommonUtils.isGoodCandidate(config, hasCandidates, modelConfig.isRegression())) {
                    log.warn("!! Variable - {} is not a good candidate. But it is in forceselect list",
                            config.getColumnName());
                }
                config.setFinalSelect(true);
            }
        }
        // reports the size of the configured force-select list, not the number of columns marked
        log.info("{} variables are selected by force.", this.modelConfig.getListForceSelect().size());
    } else {
        // multiple classification, select all candidate at first, TODO add SE for multi-classification
        for (ColumnConfig config : this.columnConfigList) {
            if (CommonUtils.isGoodCandidate(config, hasCandidates, modelConfig.isRegression())) {
                config.setFinalSelect(true);
            }
        }
    }
}
Aggregations