use of ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy in project shifu by ShifuML.
the class ModelInspector method checkTrainSetting.
/**
* Check the setting for model training.
* It will make sure (num_of_layers > 0
* && num_of_layers = hidden_nodes_size
* && num_of_layers = active_func_size)
*
* @param train
* - @ModelTrainConf to check
* @return @ValidateResult
*/
@SuppressWarnings("unchecked")
private ValidateResult checkTrainSetting(ModelConfig modelConfig, ModelTrainConf train) {
ValidateResult result = new ValidateResult(true);
if (train.getBaggingNum() == null || train.getBaggingNum() <= 0) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Bagging number should be greater than zero in train configuration");
result = ValidateResult.mergeResult(result, tmpResult);
}
if (train.getNumKFold() != null && train.getNumKFold() > 20) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("numKFold should be in (0, 20] or <=0 (not dp k-crossValidation)");
result = ValidateResult.mergeResult(result, tmpResult);
}
if (train.getBaggingSampleRate() == null || train.getBaggingSampleRate().compareTo(Double.valueOf(0)) <= 0 || train.getBaggingSampleRate().compareTo(Double.valueOf(1)) > 0) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Bagging sample rate number should be in (0, 1].");
result = ValidateResult.mergeResult(result, tmpResult);
}
if (train.getValidSetRate() == null || train.getValidSetRate().compareTo(Double.valueOf(0)) < 0 || train.getValidSetRate().compareTo(Double.valueOf(1)) >= 0) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Validation set rate number should be in [0, 1).");
result = ValidateResult.mergeResult(result, tmpResult);
}
if (train.getNumTrainEpochs() == null || train.getNumTrainEpochs() <= 0) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Epochs should be larger than 0.");
result = ValidateResult.mergeResult(result, tmpResult);
}
if (train.getEpochsPerIteration() != null && train.getEpochsPerIteration() <= 0) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("'epochsPerIteration' should be larger than 0 if set.");
result = ValidateResult.mergeResult(result, tmpResult);
}
if (train.getWorkerThreadCount() != null && (train.getWorkerThreadCount() <= 0 || train.getWorkerThreadCount() > 32)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("'workerThreadCount' should be in (0, 32] if set.");
result = ValidateResult.mergeResult(result, tmpResult);
}
if (train.getConvergenceThreshold() != null && train.getConvergenceThreshold().compareTo(0.0) < 0) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("'threshold' should be larger than or equal to 0.0 if set.");
result = ValidateResult.mergeResult(result, tmpResult);
}
if (modelConfig.isClassification() && train.isOneVsAll() && !CommonUtils.isTreeModel(train.getAlgorithm()) && !train.getAlgorithm().equalsIgnoreCase("nn")) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("'one vs all' or 'one vs rest' is only enabled with 'RF' or 'GBT' or 'NN' algorithm");
result = ValidateResult.mergeResult(result, tmpResult);
}
if (modelConfig.isClassification() && train.getMultiClassifyMethod() == MultipleClassification.NATIVE && train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
Object impurity = train.getParams().get("Impurity");
if (impurity != null && !"entropy".equalsIgnoreCase(impurity.toString()) && !"gini".equalsIgnoreCase(impurity.toString())) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Impurity should be in [entropy,gini] if native mutiple classification in RF.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
GridSearch gs = new GridSearch(train.getParams(), train.getGridConfigFileContent());
// such parameter validation only in regression and not grid search mode
if (modelConfig.isRegression() && !gs.hasHyperParam()) {
if (train.getAlgorithm().equalsIgnoreCase("nn")) {
Map<String, Object> params = train.getParams();
Object loss = params.get("Loss");
if (loss != null && !"log".equalsIgnoreCase(loss.toString()) && !"squared".equalsIgnoreCase(loss.toString()) && !"absolute".equalsIgnoreCase(loss.toString())) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Loss should be in [log,squared,absolute].");
result = ValidateResult.mergeResult(result, tmpResult);
}
Object TFloss = params.get("TF.loss");
if (TFloss != null && !"squared".equalsIgnoreCase(TFloss.toString()) && !"absolute".equalsIgnoreCase(TFloss.toString()) && !"log".equalsIgnoreCase(TFloss.toString())) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Loss should be in [log,squared,absolute].");
result = ValidateResult.mergeResult(result, tmpResult);
}
Object TFOptimizer = params.get("TF.optimizer");
if (TFOptimizer != null && !"adam".equalsIgnoreCase(TFOptimizer.toString()) && !"gradientDescent".equalsIgnoreCase(TFOptimizer.toString()) && !"RMSProp".equalsIgnoreCase(TFOptimizer.toString())) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("tensorflow optimizer should be in [RMSProp,gradientDescent,adam].");
result = ValidateResult.mergeResult(result, tmpResult);
}
int layerCnt = (Integer) params.get(CommonConstants.NUM_HIDDEN_LAYERS);
if (layerCnt < 0) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("The number of hidden layers should be >= 0 in train configuration");
result = ValidateResult.mergeResult(result, tmpResult);
}
List<Integer> hiddenNode = (List<Integer>) params.get(CommonConstants.NUM_HIDDEN_NODES);
List<String> activateFucs = (List<String>) params.get(CommonConstants.ACTIVATION_FUNC);
if (hiddenNode.size() != activateFucs.size() || layerCnt != activateFucs.size()) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add(CommonConstants.NUM_HIDDEN_LAYERS + "/SIZE(" + CommonConstants.NUM_HIDDEN_NODES + ")" + "/SIZE(" + CommonConstants.ACTIVATION_FUNC + ")" + " should be equal in train configuration");
result = ValidateResult.mergeResult(result, tmpResult);
}
Double learningRate = Double.valueOf(params.get(CommonConstants.LEARNING_RATE).toString());
if (learningRate != null && (learningRate.compareTo(Double.valueOf(0)) <= 0)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Learning rate should be larger than 0.");
result = ValidateResult.mergeResult(result, tmpResult);
}
Object learningDecayO = params.get(CommonConstants.LEARNING_DECAY);
if (learningDecayO != null) {
Double learningDecay = Double.valueOf(learningDecayO.toString());
if (learningDecay != null && ((learningDecay.compareTo(Double.valueOf(0)) < 0) || (learningDecay.compareTo(Double.valueOf(1)) >= 0))) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Learning decay should be in [0, 1) if set.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
Object dropoutObj = params.get(CommonConstants.DROPOUT_RATE);
if (dropoutObj != null) {
Double dropoutRate = Double.valueOf(dropoutObj.toString());
if (dropoutRate != null && (dropoutRate < 0d || dropoutRate >= 1d)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Dropout rate should be in [0, 1).");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
Object fixedLayersObj = params.get(CommonConstants.FIXED_LAYERS);
if (fixedLayersObj != null) {
List<Integer> fixedLayers = (List<Integer>) fixedLayersObj;
for (int layer : fixedLayers) {
if (layer <= 0 || layer > (layerCnt + 1)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Fixed layer id " + layer + " is invaild. It should be between 0 and hidden layer cnt + output layer:" + (layerCnt + 1));
result = ValidateResult.mergeResult(result, tmpResult);
}
}
}
Object miniBatchsO = params.get(CommonConstants.MINI_BATCH);
if (miniBatchsO != null) {
Integer miniBatchs = Integer.valueOf(miniBatchsO.toString());
if (miniBatchs != null && (miniBatchs <= 0 || miniBatchs > 1000)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("MiniBatchs should be in (0, 1000] if set.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
Object momentumO = params.get("Momentum");
if (momentumO != null) {
Double momentum = Double.valueOf(momentumO.toString());
if (momentum != null && momentum <= 0d) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Momentum should be in (0, ) if set.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
Object adamBeta1O = params.get("AdamBeta1");
if (adamBeta1O != null) {
Double adamBeta1 = Double.valueOf(adamBeta1O.toString());
if (adamBeta1 != null && (adamBeta1 <= 0d || adamBeta1 >= 1d)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("AdamBeta1 should be in (0, 1) if set.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
Object adamBeta2O = params.get("AdamBeta2");
if (adamBeta2O != null) {
Double adamBeta2 = Double.valueOf(adamBeta2O.toString());
if (adamBeta2 != null && (adamBeta2 <= 0d || adamBeta2 >= 1d)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("AdamBeta2 should be in (0, 1) if set.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
}
if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME) || train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME) || train.getAlgorithm().equalsIgnoreCase(NNConstants.NN_ALG_NAME)) {
Map<String, Object> params = train.getParams();
Object fssObj = params.get("FeatureSubsetStrategy");
if (fssObj == null) {
if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME) || train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("FeatureSubsetStrategy is not set in RF/GBT algorithm.");
result = ValidateResult.mergeResult(result, tmpResult);
}
} else {
boolean isNumber = false;
double doubleFss = 0;
try {
doubleFss = Double.parseDouble(fssObj.toString());
isNumber = true;
} catch (Exception e) {
isNumber = false;
}
if (isNumber) {
// fail if not in (0, 1]
if (doubleFss <= 0d || doubleFss > 1d) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("FeatureSubsetStrategy if double should be in (0, 1]");
result = ValidateResult.mergeResult(result, tmpResult);
}
} else {
boolean fssInEnum = false;
for (FeatureSubsetStrategy fss : FeatureSubsetStrategy.values()) {
if (fss.toString().equalsIgnoreCase(fssObj.toString())) {
fssInEnum = true;
break;
}
}
if (!fssInEnum) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("FeatureSubsetStrategy if string should be in ['ALL', 'HALF', 'ONETHIRD' , 'TWOTHIRDS' , 'AUTO' , 'SQRT' , 'LOG2']");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
}
}
if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME) || train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
Map<String, Object> params = train.getParams();
if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)) {
Object loss = params.get("Loss");
if (loss != null && !"log".equalsIgnoreCase(loss.toString()) && !"squared".equalsIgnoreCase(loss.toString()) && !"halfgradsquared".equalsIgnoreCase(loss.toString()) && !"absolute".equalsIgnoreCase(loss.toString())) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Loss should be in [log,squared,halfgradsquared,absolute].");
result = ValidateResult.mergeResult(result, tmpResult);
}
if (loss == null) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("'Loss' parameter isn't being set in train#parameters in GBT training.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
Object maxDepthObj = params.get("MaxDepth");
if (maxDepthObj != null) {
int maxDepth = Integer.valueOf(maxDepthObj.toString());
if (maxDepth <= 0 || maxDepth > 20) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("MaxDepth should in [1, 20].");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
Object vtObj = params.get("ValidationTolerance");
if (vtObj != null) {
double validationTolerance = Double.valueOf(vtObj.toString());
if (validationTolerance < 0d || validationTolerance >= 1d) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("ValidationTolerance should in [0, 1).");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
Object maxLeavesObj = params.get("MaxLeaves");
if (maxLeavesObj != null) {
int maxLeaves = Integer.valueOf(maxLeavesObj.toString());
if (maxLeaves <= 0) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("MaxLeaves should in [1, Integer.MAX_VALUE].");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
if (maxDepthObj == null && maxLeavesObj == null) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("'MaxDepth' or 'MaxLeaves' parameters at least one of both should be set in train#parameters in GBT training.");
result = ValidateResult.mergeResult(result, tmpResult);
}
Object maxStatsMemoryMBObj = params.get("MaxStatsMemoryMB");
if (maxStatsMemoryMBObj != null) {
int maxStatsMemoryMB = Integer.valueOf(maxStatsMemoryMBObj.toString());
if (maxStatsMemoryMB <= 0) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("MaxStatsMemoryMB should > 0.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
Object dropoutObj = params.get(CommonConstants.DROPOUT_RATE);
if (dropoutObj != null) {
Double dropoutRate = Double.valueOf(dropoutObj.toString());
if (dropoutRate != null && (dropoutRate < 0d || dropoutRate >= 1d)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Dropout rate should be in [0, 1).");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)) {
Object learningRateObj = params.get(CommonConstants.LEARNING_RATE);
if (learningRateObj != null) {
Double learningRate = Double.valueOf(learningRateObj.toString());
if (learningRate != null && (learningRate.compareTo(Double.valueOf(0)) <= 0)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Learning rate should be larger than 0.");
result = ValidateResult.mergeResult(result, tmpResult);
}
} else {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("'LearningRate' parameter isn't being set in train#parameters in GBT training.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
Object minInstancesPerNodeObj = params.get("MinInstancesPerNode");
if (minInstancesPerNodeObj != null) {
int minInstancesPerNode = Integer.valueOf(minInstancesPerNodeObj.toString());
if (minInstancesPerNode <= 0) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("MinInstancesPerNode should > 0.");
result = ValidateResult.mergeResult(result, tmpResult);
}
} else {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("'MinInstancesPerNode' parameter isn't be set in train#parameters in GBT/RF training.");
result = ValidateResult.mergeResult(result, tmpResult);
}
Object treeNumObj = params.get("TreeNum");
if (treeNumObj != null) {
int treeNum = Integer.valueOf(treeNumObj.toString());
if (treeNum <= 0 || treeNum > 10000) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("TreeNum should be in [1, 10000].");
result = ValidateResult.mergeResult(result, tmpResult);
}
} else {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("'TreeNum' parameter isn't being set in train#parameters in GBT/RF training.");
result = ValidateResult.mergeResult(result, tmpResult);
}
Object minInfoGainObj = params.get("MinInfoGain");
if (minInfoGainObj != null) {
Double minInfoGain = Double.valueOf(minInfoGainObj.toString());
if (minInfoGain != null && (minInfoGain.compareTo(Double.valueOf(0)) < 0)) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("MinInfoGain should be >= 0.");
result = ValidateResult.mergeResult(result, tmpResult);
}
} else {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("'MinInfoGain' parameter isn't be set in train#parameters in GBT/RF training.");
result = ValidateResult.mergeResult(result, tmpResult);
}
Object impurityObj = params.get("Impurity");
if (impurityObj == null) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("Impurity is not set in RF/GBT algorithm.");
result = ValidateResult.mergeResult(result, tmpResult);
} else {
if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)) {
if (impurityObj != null && !"variance".equalsIgnoreCase(impurityObj.toString()) && !"friedmanmse".equalsIgnoreCase(impurityObj.toString())) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("GBDT only supports 'variance|friedmanmse' impurity type.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
if (impurityObj != null && !"friedmanmse".equalsIgnoreCase(impurityObj.toString()) && !"entropy".equalsIgnoreCase(impurityObj.toString()) && !"variance".equalsIgnoreCase(impurityObj.toString()) && !"gini".equalsIgnoreCase(impurityObj.toString())) {
ValidateResult tmpResult = new ValidateResult(true);
tmpResult.setStatus(false);
tmpResult.getCauses().add("RF supports 'variance|entropy|gini|friedmanmse' impurity types.");
result = ValidateResult.mergeResult(result, tmpResult);
}
}
}
}
}
return result;
}
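Every rule above follows the same four-line pattern: create a passing ValidateResult, flip its status, record a cause, and merge it into the running result. A small helper could collapse each check to a single call; the sketch below is illustrative only (the addFailure helper is an assumption, not part of shifu), written against the ValidateResult API used above.
// Hypothetical helper (not in shifu): wraps the repeated
// create / setStatus(false) / add-cause / merge pattern from checkTrainSetting.
private static ValidateResult addFailure(ValidateResult result, String cause) {
    ValidateResult tmpResult = new ValidateResult(true);
    tmpResult.setStatus(false);
    tmpResult.getCauses().add(cause);
    return ValidateResult.mergeResult(result, tmpResult);
}
With such a helper, a check like the epochs rule would reduce to result = addFailure(result, "Epochs should be larger than 0.");.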
use of ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy in project shifu by ShifuML.
the class TrainModelProcessor method runDistributedTrain.
protected int runDistributedTrain() throws IOException, InterruptedException, ClassNotFoundException {
LOG.info("Started {}distributed training.", isDryTrain ? "dry " : "");
int status = 0;
Configuration conf = new Configuration();
SourceType sourceType = super.getModelConfig().getDataSet().getSource();
final List<String> args = new ArrayList<String>();
GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
prepareCommonParams(gs.hasHyperParam(), args, sourceType);
String alg = super.getModelConfig().getTrain().getAlgorithm();
// add tmp models folder to config
FileSystem fileSystem = ShifuFileUtils.getFileSystemBySourceType(sourceType);
Path tmpModelsPath = fileSystem.makeQualified(new Path(super.getPathFinder().getPathBySourceType(new Path(Constants.TMP, Constants.DEFAULT_MODELS_TMP_FOLDER), sourceType)));
args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TMP_MODELS_FOLDER, tmpModelsPath.toString()));
int baggingNum = isForVarSelect ? 1 : super.getModelConfig().getBaggingNum();
if (modelConfig.isClassification()) {
int classes = modelConfig.getTags().size();
if (classes == 2) {
// binary classification, only need one job
baggingNum = 1;
} else {
if (modelConfig.getTrain().isOneVsAll()) {
// one vs all multiple classification, we need multiple bagging jobs to do ONEVSALL
baggingNum = modelConfig.getTags().size();
} else {
// native multi-class classification: use the bagging number from the configuration, nothing to override here
}
}
if (baggingNum != super.getModelConfig().getBaggingNum()) {
LOG.warn("'train:baggingNum' is set to {} because of ONEVSALL multiple classification.", baggingNum);
}
}
boolean isKFoldCV = false;
Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
if (kCrossValidation != null && kCrossValidation > 0) {
isKFoldCV = true;
baggingNum = modelConfig.getTrain().getNumKFold();
if (baggingNum != super.getModelConfig().getBaggingNum() && gs.hasHyperParam()) {
// if it is grid search mode, then kfold mode is disabled
LOG.warn("'train:baggingNum' is set to {} because of k-fold cross validation is enabled by 'numKFold' not -1.", baggingNum);
}
}
long start = System.currentTimeMillis();
boolean isParallel = Boolean.valueOf(Environment.getProperty(Constants.SHIFU_DTRAIN_PARALLEL, SHIFU_DEFAULT_DTRAIN_PARALLEL)).booleanValue();
GuaguaMapReduceClient guaguaClient;
int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
int inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
int candidateCount = inputOutputIndex[2];
boolean isAfterVarSelect = (inputOutputIndex[0] != 0);
// cache all feature list for sampling features
List<Integer> allFeatures = NormalUtils.getAllFeatureList(this.columnConfigList, isAfterVarSelect);
if (modelConfig.getNormalize().getIsParquet()) {
guaguaClient = new GuaguaParquetMapReduceClient();
// set required field list to make sure we only load selected columns.
RequiredFieldList requiredFieldList = new RequiredFieldList();
boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
for (ColumnConfig columnConfig : super.columnConfigList) {
if (columnConfig.isTarget()) {
requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
} else {
if (inputNodeCount == candidateCount) {
// no any variables are selected
if (!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
}
} else {
if (!columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
}
}
}
}
// weight is added manually
requiredFieldList.add(new RequiredField("weight", columnConfigList.size(), null, DataType.DOUBLE));
args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.required.fields", serializeRequiredFieldList(requiredFieldList)));
args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.column.index.access", "true"));
} else {
guaguaClient = new GuaguaMapReduceClient();
}
int parallelNum = Integer.parseInt(Environment.getProperty(CommonConstants.SHIFU_TRAIN_BAGGING_INPARALLEL, "5"));
int parallelGroups = 1;
if (gs.hasHyperParam()) {
parallelGroups = (gs.getFlattenParams().size() % parallelNum == 0 ? gs.getFlattenParams().size() / parallelNum : gs.getFlattenParams().size() / parallelNum + 1);
baggingNum = gs.getFlattenParams().size();
LOG.warn("'train:baggingNum' is set to {} because of grid search enabled by settings in 'train#params'.", gs.getFlattenParams().size());
} else {
parallelGroups = baggingNum % parallelNum == 0 ? baggingNum / parallelNum : baggingNum / parallelNum + 1;
}
LOG.info("Distributed trainning with baggingNum: {}", baggingNum);
List<String> progressLogList = new ArrayList<String>(baggingNum);
boolean isOneJobNotContinuous = false;
for (int j = 0; j < parallelGroups; j++) {
int currBags = baggingNum;
if (gs.hasHyperParam()) {
if (j == parallelGroups - 1) {
currBags = gs.getFlattenParams().size() % parallelNum == 0 ? parallelNum : gs.getFlattenParams().size() % parallelNum;
} else {
currBags = parallelNum;
}
} else {
if (j == parallelGroups - 1) {
currBags = baggingNum % parallelNum == 0 ? parallelNum : baggingNum % parallelNum;
} else {
currBags = parallelNum;
}
}
for (int k = 0; k < currBags; k++) {
int i = j * parallelNum + k;
if (gs.hasHyperParam()) {
LOG.info("Start the {}th grid search job with params: {}", i, gs.getParams(i));
} else if (isKFoldCV) {
LOG.info("Start the {}th k-fold cross validation job with params.", i);
}
List<String> localArgs = new ArrayList<String>(args);
// set name for each bagging job.
localArgs.add("-n");
localArgs.add(String.format("Shifu Master-Workers %s Training Iteration: %s id:%s", alg, super.getModelConfig().getModelSetName(), i));
LOG.info("Start trainer with id: {}", i);
String modelName = getModelName(i);
Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
Path bModelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getNNBinaryModelsPath(sourceType), modelName));
// check if job is continuous training, this can be set multiple times and we only get last one
boolean isContinuous = false;
if (gs.hasHyperParam()) {
isContinuous = false;
} else {
int intContinuous = checkContinuousTraining(fileSystem, localArgs, modelPath, modelConfig.getTrain().getParams());
if (intContinuous == -1) {
LOG.warn("Model with index {} with size of trees is over treeNum, such training will not be started.", i);
continue;
} else {
isContinuous = (intContinuous == 1);
}
}
// training
if (gs.hasHyperParam() || isKFoldCV) {
isContinuous = false;
}
if (!isContinuous && !isOneJobNotContinuous) {
isOneJobNotContinuous = true;
// delete all old models if not continuous
String srcModelPath = super.getPathFinder().getModelsPath(sourceType);
String mvModelPath = srcModelPath + "_" + System.currentTimeMillis();
LOG.info("Old model path has been moved to {}", mvModelPath);
fileSystem.rename(new Path(srcModelPath), new Path(mvModelPath));
fileSystem.mkdirs(new Path(srcModelPath));
FileSystem.getLocal(conf).delete(new Path(super.getPathFinder().getModelsPath(SourceType.LOCAL)), true);
}
if (NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) {
// feature subset strategy related parameters initialization
Map<String, Object> params = gs.hasHyperParam() ? gs.getParams(i) : this.modelConfig.getTrain().getParams();
Object fssObj = params.get("FeatureSubsetStrategy");
FeatureSubsetStrategy featureSubsetStrategy = null;
double featureSubsetRate = 0d;
if (fssObj != null) {
try {
featureSubsetRate = Double.parseDouble(fssObj.toString());
// no need validate featureSubsetRate is in (0,1], as already validated in ModelInspector
featureSubsetStrategy = null;
} catch (NumberFormatException ee) {
featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
}
} else {
LOG.warn("FeatureSubsetStrategy is not set, set to ALL by default.");
featureSubsetStrategy = FeatureSubsetStrategy.ALL;
featureSubsetRate = 0;
}
Set<Integer> subFeatures = null;
if (isContinuous) {
BasicFloatNetwork existingModel = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource())));
if (existingModel == null) {
subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
} else {
subFeatures = existingModel.getFeatureSet();
}
} else {
subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
}
if (subFeatures == null || subFeatures.size() == 0) {
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, ""));
} else {
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, StringUtils.join(subFeatures, ',')));
LOG.debug("Size: {}, list: {}.", subFeatures.size(), StringUtils.join(subFeatures, ','));
}
}
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GUAGUA_OUTPUT, modelPath.toString()));
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, Constants.SHIFU_NN_BINARY_MODEL_PATH, bModelPath.toString()));
if (gs.hasHyperParam() || isKFoldCV) {
// k-fold cv need val error
Path valErrPath = fileSystem.makeQualified(new Path(super.getPathFinder().getValErrorPath(sourceType), "val_error_" + i));
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GS_VALIDATION_ERROR, valErrPath.toString()));
}
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TRAINER_ID, String.valueOf(i)));
final String progressLogFile = getProgressLogFile(i);
progressLogList.add(progressLogFile);
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE, progressLogFile));
String hdpVersion = HDPUtils.getHdpVersionForHDP224();
if (StringUtils.isNotBlank(hdpVersion)) {
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "hdp.version", hdpVersion));
HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf);
HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf);
HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf);
HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf);
}
if (isParallel) {
guaguaClient.addJob(localArgs.toArray(new String[0]));
} else {
TailThread tailThread = startTailThread(new String[] { progressLogFile });
boolean ret = guaguaClient.createJob(localArgs.toArray(new String[0])).waitForCompletion(true);
status += (ret ? 0 : 1);
stopTailThread(tailThread);
}
}
if (isParallel) {
TailThread tailThread = startTailThread(progressLogList.toArray(new String[0]));
status += guaguaClient.run();
stopTailThread(tailThread);
}
}
if (isKFoldCV) {
// k-fold we also copy model files at last, such models can be used for evaluation
for (int i = 0; i < baggingNum; i++) {
String modelName = getModelName(i);
Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath)) {
copyModelToLocal(modelName, modelPath, sourceType);
} else {
LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
status += 1;
}
}
List<Double> valErrs = readAllValidationErrors(sourceType, fileSystem, kCrossValidation);
double sum = 0d;
for (Double err : valErrs) {
sum += err;
}
LOG.info("Average validation error for current k-fold cross validation is {}.", sum / valErrs.size());
LOG.info("K-fold cross validation on distributed training finished in {}ms.", System.currentTimeMillis() - start);
} else if (gs.hasHyperParam()) {
// select the best parameter composite in grid search
LOG.info("Original grid search params: {}", modelConfig.getParams());
Map<String, Object> params = findBestParams(sourceType, fileSystem, gs);
// temp copy all models for evaluation
for (int i = 0; i < baggingNum; i++) {
String modelName = getModelName(i);
Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
copyModelToLocal(modelName, modelPath, sourceType);
} else {
LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
}
}
LOG.info("The best parameters in grid search is {}", params);
LOG.info("Grid search on distributed training finished in {}ms.", System.currentTimeMillis() - start);
} else {
// copy model files at last.
for (int i = 0; i < baggingNum; i++) {
String modelName = getModelName(i);
Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
copyModelToLocal(modelName, modelPath, sourceType);
} else {
LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
}
}
// copy temp model files; for RF/GBT tmp models are not copied to local because they would need much more
// space, while for other algorithms tmp models are copied to local by default
boolean copyTmpModelsToLocal = Boolean.TRUE.toString().equalsIgnoreCase(Environment.getProperty(Constants.SHIFU_TMPMODEL_COPYTOLOCAL, "true"));
if (copyTmpModelsToLocal) {
copyTmpModelsToLocal(tmpModelsPath, sourceType);
} else {
LOG.info("Tmp models are not copied into local, please find them in hdfs path: {}", tmpModelsPath);
}
LOG.info("Distributed training finished in {}ms.", System.currentTimeMillis() - start);
}
if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(this.modelConfig, null);
// compute feature importance and write to local file after models are trained
Map<Integer, MutablePair<String, Double>> featureImportances = CommonUtils.computeTreeModelFeatureImportance(models);
String localFsFolder = pathFinder.getLocalFeatureImportanceFolder();
String localFIPath = pathFinder.getLocalFeatureImportancePath();
processRollupForFIFiles(localFsFolder, localFIPath);
CommonUtils.writeFeatureImportance(localFIPath, featureImportances);
}
if (status != 0) {
LOG.error("Error may occurred. There is no model generated. Please check!");
}
return status;
}
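Both usages parse the 'FeatureSubsetStrategy' train parameter the same way: try it as a number first, and fall back to the enum constants otherwise. Below is a minimal standalone sketch of that validation, assuming only the FeatureSubsetStrategy enum shown in these usages; the method itself is hypothetical and not part of shifu.
// Hypothetical validation sketch mirroring checkTrainSetting: a numeric
// FeatureSubsetStrategy must lie in (0, 1]; a string one must match a
// FeatureSubsetStrategy enum constant (ALL, HALF, ONETHIRD, TWOTHIRDS,
// AUTO, SQRT, LOG2).
private static boolean isValidFeatureSubsetStrategy(Object fssObj) {
    if (fssObj == null) {
        return false; // required for RF/GBT; NN defaults to ALL at training time
    }
    try {
        double rate = Double.parseDouble(fssObj.toString());
        return rate > 0d && rate <= 1d;
    } catch (NumberFormatException e) {
        for (FeatureSubsetStrategy fss : FeatureSubsetStrategy.values()) {
            if (fss.toString().equalsIgnoreCase(fssObj.toString())) {
                return true;
            }
        }
        return false;
    }
}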