Search in sources :

Example 1 with GridSearch

use of ml.shifu.shifu.core.dtrain.gs.GridSearch in project shifu by ShifuML.

the class DTMaster method init.

@Override
public void init(MasterContext<DTMasterParams, DTWorkerParams> context) {
    Properties props = context.getProps();
    // init model config and column config list at first
    SourceType sourceType;
    try {
        sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    // worker number is used to estimate nodes per iteration for stats
    this.workerNumber = NumberFormatUtils.getInt(props.getProperty(GuaguaConstants.GUAGUA_WORKER_NUMBER), true);
    // check if variables have been final selected
    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    this.inputNum = inputOutputIndex[0] + inputOutputIndex[1];
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    // cache all feature list for sampling features
    this.allFeatures = this.getAllFeatureList(columnConfigList, isAfterVarSelect);
    int trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    // For grid search, select valid params for this trainer; otherwise use the params from ModelConfig.json
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        validParams = gs.getParams(trainerId);
        LOG.info("Start grid search master with params: {}", validParams);
    }
    Object vtObj = validParams.get("ValidationTolerance");
    if (vtObj != null) {
        try {
            validationTolerance = Double.parseDouble(vtObj.toString());
            LOG.warn("Validation by tolerance is enabled with value {}.", validationTolerance);
        } catch (NumberFormatException ee) {
            validationTolerance = 0d;
            LOG.warn("Validation by tolerance isn't enabled because of non numerical value of ValidationTolerance: {}.", vtObj);
        }
    } else {
        LOG.warn("Validation by tolerance isn't enabled.");
    }
    // tree related parameters initialization
    Object fssObj = validParams.get("FeatureSubsetStrategy");
    if (fssObj != null) {
        try {
            this.featureSubsetRate = Double.parseDouble(fssObj.toString());
            // no need to validate that featureSubsetRate is in (0,1], as it is already validated in ModelInspector
            this.featureSubsetStrategy = null;
        } catch (NumberFormatException ee) {
            this.featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
        }
    } else {
        LOG.warn("FeatureSubsetStrategy is not set, set to TWOTHRIDS by default in DTMaster.");
        this.featureSubsetStrategy = FeatureSubsetStrategy.TWOTHIRDS;
        this.featureSubsetRate = 0;
    }
    // max depth
    Object maxDepthObj = validParams.get("MaxDepth");
    if (maxDepthObj != null) {
        this.maxDepth = Integer.valueOf(maxDepthObj.toString());
    } else {
        this.maxDepth = 10;
    }
    // max leaves, used for leaf-wise tree building; TODO add more benchmarks
    Object maxLeavesObj = validParams.get("MaxLeaves");
    if (maxLeavesObj != null) {
        this.maxLeaves = Integer.valueOf(maxLeavesObj.toString());
    } else {
        this.maxLeaves = -1;
    }
    // enable leaf-wise tree building once maxLeaves is configured
    if (this.maxLeaves > 0) {
        this.isLeafWise = true;
    }
    // maxBatchSplitSize is the max number of nodes to split in one batch
    Object maxBatchSplitSizeObj = validParams.get("MaxBatchSplitSize");
    if (maxBatchSplitSizeObj != null) {
        this.maxBatchSplitSize = Integer.valueOf(maxBatchSplitSizeObj.toString());
    } else {
        // by default split 32 at most in a batch
        this.maxBatchSplitSize = 32;
    }
    assert this.maxDepth > 0 && this.maxDepth <= 20;
    // hidden parameter, used to avoid OOM issues in each iteration
    Object maxStatsMemoryMB = validParams.get("MaxStatsMemoryMB");
    if (maxStatsMemoryMB != null) {
        this.maxStatsMemory = Long.valueOf(maxStatsMemoryMB.toString()) * 1024 * 1024;
        if (this.maxStatsMemory > ((2L * Runtime.getRuntime().maxMemory()) / 3)) {
            // if over 2/3 of max memory, cap at 2/3 of max memory to avoid OOM
            this.maxStatsMemory = ((2L * Runtime.getRuntime().maxMemory()) / 3);
        }
    } else {
        // by default it is 1/2 of the max heap, about 1.5G with current Shifu settings
        this.maxStatsMemory = Runtime.getRuntime().maxMemory() / 2L;
    }
    // assert this.maxStatsMemory <= Math.min(Runtime.getRuntime().maxMemory() * 0.6, 800 * 1024 * 1024L);
    this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());
    this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    if (this.isGBDT) {
        // learning rate is only effective in GBDT
        this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
    }
    // initialize impurity type according to regression or classification
    String imStr = validParams.get("Impurity").toString();
    int numClasses = 2;
    if (this.modelConfig.isClassification()) {
        numClasses = this.modelConfig.getTags().size();
    }
    // these two parameters control when tree growth stops
    int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
    double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
    if (imStr.equalsIgnoreCase("entropy")) {
        impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("gini")) {
        impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
    } else {
        impurity = new Variance(minInstancesPerNode, minInfoGain);
    }
    // checkpoint folder and interval (every # iterations to do checkpoint)
    this.checkpointInterval = NumberFormatUtils.getInt(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_INTERVAL, "20"));
    this.checkpointOutput = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));
    // cache conf to avoid creating a new instance each time
    this.conf = new Configuration();
    // if continuous model training is enabled
    this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
    this.dtEarlyStopDecider = new DTEarlyStopDecider(this.maxDepth);
    if (validParams.containsKey("EnableEarlyStop") && Boolean.valueOf(validParams.get("EnableEarlyStop").toString().toLowerCase())) {
        this.enableEarlyStop = true;
    }
    LOG.info("Master init params: isAfterVarSel={}, featureSubsetStrategy={}, featureSubsetRate={} maxDepth={}, maxStatsMemory={}, " + "treeNum={}, impurity={}, workerNumber={}, minInstancesPerNode={}, minInfoGain={}, isRF={}, " + "isGBDT={}, isContinuousEnabled={}, enableEarlyStop={}.", isAfterVarSelect, featureSubsetStrategy, this.featureSubsetRate, maxDepth, maxStatsMemory, treeNum, imStr, this.workerNumber, minInstancesPerNode, minInfoGain, this.isRF, this.isGBDT, this.isContinuousEnabled, this.enableEarlyStop);
    this.toDoQueue = new LinkedList<TreeNode>();
    if (this.isLeafWise) {
        this.toSplitQueue = new PriorityQueue<TreeNode>(64, new Comparator<TreeNode>() {

            @Override
            public int compare(TreeNode o1, TreeNode o2) {
                return Double.compare(o2.getNode().getWgtCntRatio() * o2.getNode().getGain(), o1.getNode().getWgtCntRatio() * o1.getNode().getGain());
            }
        });
    }
    // initialize trees
    if (context.isFirstIteration()) {
        if (this.isRF) {
            // for random forest, trees are trained in parallel
            this.trees = new CopyOnWriteArrayList<TreeNode>();
            for (int i = 0; i < treeNum; i++) {
                this.trees.add(new TreeNode(i, new Node(Node.ROOT_INDEX), 1d));
            }
        }
        if (this.isGBDT) {
            if (isContinuousEnabled) {
                TreeModel existingModel;
                try {
                    Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
                    existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
                    if (existingModel == null) {
                        // null means no existing model file or model file is in wrong format
                        this.trees = new CopyOnWriteArrayList<TreeNode>();
                        // learning rate is 1 for the first tree
                        this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1d));
                        LOG.info("Starting to train model from scratch and existing model is empty.");
                    } else {
                        this.trees = existingModel.getTrees();
                        this.existingTreeSize = this.trees.size();
                        // starting from existing models, first tree learning rate is current learning rate
                        this.trees.add(new TreeNode(this.existingTreeSize, new Node(Node.ROOT_INDEX), this.existingTreeSize == 0 ? 1d : this.learningRate));
                        LOG.info("Starting to train model from existing model {} with existing trees {}.", modelPath, existingTreeSize);
                    }
                } catch (IOException e) {
                    throw new GuaguaRuntimeException(e);
                }
            } else {
                this.trees = new CopyOnWriteArrayList<TreeNode>();
                // for GBDT, initialize the first tree; trees are trained sequentially and the first tree's learning rate is 1
                this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1.0d));
            }
        }
    } else {
        // recover all states once the master fails over
        LOG.info("Recover master status from checkpoint file {}", this.checkpointOutput);
        recoverMasterStatus(sourceType);
    }
}
Also used : Configuration(org.apache.hadoop.conf.Configuration) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) Properties(java.util.Properties) Comparator(java.util.Comparator) TreeModel(ml.shifu.shifu.core.TreeModel) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) Path(org.apache.hadoop.fs.Path) IOException(java.io.IOException) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch)
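
The examples on this page all follow the same pattern: build a GridSearch from train#params plus the optional grid config file content, then branch on hasHyperParam(). The sketch below iterates the flattened hyper-parameter combinations using only the methods visible in these examples (hasHyperParam, getFlattenParams, getParams); how list-valued parameters are expanded into combinations is an implementation detail not shown here.

import java.util.Map;

import ml.shifu.shifu.core.dtrain.gs.GridSearch;

public class GridSearchSketch {

    // 'gs' is built exactly as in the examples on this page, from
    // modelConfig.getTrain().getParams() and getGridConfigFileContent().
    static void printCombinations(GridSearch gs) {
        if (gs.hasHyperParam()) {
            // one trainer id per flattened combination, mirroring runDistributedTrain below
            for (int trainerId = 0; trainerId < gs.getFlattenParams().size(); trainerId++) {
                Map<String, Object> validParams = gs.getParams(trainerId);
                System.out.println("trainer " + trainerId + " -> " + validParams);
            }
        }
    }
}

This is why DTMaster above only needs the single integer SHIFU_TRAINER_ID to recover its own parameter set: each trainer indexes exactly one combination.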

Example 2 with GridSearch

use of ml.shifu.shifu.core.dtrain.gs.GridSearch in project shifu by ShifuML.

the class LogisticRegressionOutput method init.

private void init(MasterContext<LogisticRegressionParams, LogisticRegressionParams> context) {
    this.isDry = Boolean.TRUE.toString().equals(context.getProps().getProperty(CommonConstants.SHIFU_DRY_DTRAIN));
    if (this.isDry) {
        return;
    }
    if (isInit.compareAndSet(false, true)) {
        loadConfigFiles(context.getProps());
        this.trainerId = context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID);
        this.tmpModelsFolder = context.getProps().getProperty(CommonConstants.SHIFU_TMP_MODELS_FOLDER);
        Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
        if (kCrossValidation != null && kCrossValidation > 0) {
            isKFoldCV = true;
        }
        GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
        this.isGsMode = gs.hasHyperParam();
    }
    try {
        Path progressLog = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE));
        // we need to append to the log so that the client console keeps refreshing; otherwise the console appears stuck.
        if (ShifuFileUtils.isFileExists(progressLog, SourceType.HDFS)) {
            this.progressOutput = FileSystem.get(new Configuration()).append(progressLog);
        } else {
            this.progressOutput = FileSystem.get(new Configuration()).create(progressLog);
        }
    } catch (IOException e) {
        LOG.error("Error in create progress log:", e);
    }
}
Also used : Path(org.apache.hadoop.fs.Path) Configuration(org.apache.hadoop.conf.Configuration) IOException(java.io.IOException) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch)
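
The append-or-create branch above is what keeps the client console refreshing. Extracted as a small helper for clarity; a sketch assuming the target file system supports append (HDFS does when dfs.support.append is enabled):

import java.io.IOException;

import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

public final class ProgressLogs {

    // Append when the log already exists so earlier progress is preserved;
    // otherwise create it fresh.
    static FSDataOutputStream openForAppend(FileSystem fs, Path log) throws IOException {
        return fs.exists(log) ? fs.append(log) : fs.create(log);
    }
}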

Example 3 with GridSearch

use of ml.shifu.shifu.core.dtrain.gs.GridSearch in project shifu by ShifuML.

the class ModelInspector method checkTrainSetting.

/**
 * Check the settings for model training.
 * It will make sure (num_of_layers > 0
 * && num_of_layers == hidden_nodes_size
 * && num_of_layers == active_func_size)
 *
 * @param modelConfig
 *            - @ModelConfig to check against
 * @param train
 *            - @ModelTrainConf to check
 * @return @ValidateResult
 */
@SuppressWarnings("unchecked")
private ValidateResult checkTrainSetting(ModelConfig modelConfig, ModelTrainConf train) {
    ValidateResult result = new ValidateResult(true);
    if (train.getBaggingNum() == null || train.getBaggingNum() < 0) {
        ValidateResult tmpResult = new ValidateResult(true);
        tmpResult.setStatus(false);
        tmpResult.getCauses().add("Bagging number should be greater than zero in train configuration");
        result = ValidateResult.mergeResult(result, tmpResult);
    }
    if (train.getNumKFold() != null && train.getNumKFold() > 20) {
        ValidateResult tmpResult = new ValidateResult(true);
        tmpResult.setStatus(false);
        tmpResult.getCauses().add("numKFold should be in (0, 20] or <=0 (not dp k-crossValidation)");
        result = ValidateResult.mergeResult(result, tmpResult);
    }
    if (train.getBaggingSampleRate() == null || train.getBaggingSampleRate().compareTo(Double.valueOf(0)) <= 0 || train.getBaggingSampleRate().compareTo(Double.valueOf(1)) > 0) {
        ValidateResult tmpResult = new ValidateResult(true);
        tmpResult.setStatus(false);
        tmpResult.getCauses().add("Bagging sample rate number should be in (0, 1].");
        result = ValidateResult.mergeResult(result, tmpResult);
    }
    if (train.getValidSetRate() == null || train.getValidSetRate().compareTo(Double.valueOf(0)) < 0 || train.getValidSetRate().compareTo(Double.valueOf(1)) >= 0) {
        ValidateResult tmpResult = new ValidateResult(true);
        tmpResult.setStatus(false);
        tmpResult.getCauses().add("Validation set rate number should be in [0, 1).");
        result = ValidateResult.mergeResult(result, tmpResult);
    }
    if (train.getNumTrainEpochs() == null || train.getNumTrainEpochs() <= 0) {
        ValidateResult tmpResult = new ValidateResult(true);
        tmpResult.setStatus(false);
        tmpResult.getCauses().add("Epochs should be larger than 0.");
        result = ValidateResult.mergeResult(result, tmpResult);
    }
    if (train.getEpochsPerIteration() != null && train.getEpochsPerIteration() <= 0) {
        ValidateResult tmpResult = new ValidateResult(true);
        tmpResult.setStatus(false);
        tmpResult.getCauses().add("'epochsPerIteration' should be larger than 0 if set.");
        result = ValidateResult.mergeResult(result, tmpResult);
    }
    if (train.getWorkerThreadCount() != null && (train.getWorkerThreadCount() <= 0 || train.getWorkerThreadCount() > 32)) {
        ValidateResult tmpResult = new ValidateResult(true);
        tmpResult.setStatus(false);
        tmpResult.getCauses().add("'workerThreadCount' should be in (0, 32] if set.");
        result = ValidateResult.mergeResult(result, tmpResult);
    }
    if (train.getConvergenceThreshold() != null && train.getConvergenceThreshold().compareTo(0.0) < 0) {
        ValidateResult tmpResult = new ValidateResult(true);
        tmpResult.setStatus(false);
        tmpResult.getCauses().add("'threshold' should be larger than or equal to 0.0 if set.");
        result = ValidateResult.mergeResult(result, tmpResult);
    }
    if (modelConfig.isClassification() && train.isOneVsAll() && !CommonUtils.isTreeModel(train.getAlgorithm()) && !train.getAlgorithm().equalsIgnoreCase("nn")) {
        ValidateResult tmpResult = new ValidateResult(true);
        tmpResult.setStatus(false);
        tmpResult.getCauses().add("'one vs all' or 'one vs rest' is only enabled with 'RF' or 'GBT' or 'NN' algorithm");
        result = ValidateResult.mergeResult(result, tmpResult);
    }
    if (modelConfig.isClassification() && train.getMultiClassifyMethod() == MultipleClassification.NATIVE && train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
        Object impurity = train.getParams().get("Impurity");
        if (impurity != null && !"entropy".equalsIgnoreCase(impurity.toString()) && !"gini".equalsIgnoreCase(impurity.toString())) {
            ValidateResult tmpResult = new ValidateResult(true);
            tmpResult.setStatus(false);
            tmpResult.getCauses().add("Impurity should be in [entropy,gini] if native mutiple classification in RF.");
            result = ValidateResult.mergeResult(result, tmpResult);
        }
    }
    GridSearch gs = new GridSearch(train.getParams(), train.getGridConfigFileContent());
    // such parameter validation is done only in regression and non-grid-search mode
    if (modelConfig.isRegression() && !gs.hasHyperParam()) {
        if (train.getAlgorithm().equalsIgnoreCase("nn")) {
            Map<String, Object> params = train.getParams();
            Object loss = params.get("Loss");
            if (loss != null && !"log".equalsIgnoreCase(loss.toString()) && !"squared".equalsIgnoreCase(loss.toString()) && !"absolute".equalsIgnoreCase(loss.toString())) {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add("Loss should be in [log,squared,absolute].");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
            Object TFloss = params.get("TF.loss");
            if (TFloss != null && !"squared".equalsIgnoreCase(TFloss.toString()) && !"absolute".equalsIgnoreCase(TFloss.toString()) && !"log".equalsIgnoreCase(TFloss.toString())) {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add("Loss should be in [log,squared,absolute].");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
            Object TFOptimizer = params.get("TF.optimizer");
            if (TFOptimizer != null && !"adam".equalsIgnoreCase(TFOptimizer.toString()) && !"gradientDescent".equalsIgnoreCase(TFOptimizer.toString()) && !"RMSProp".equalsIgnoreCase(TFOptimizer.toString())) {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add("tensorflow optimizer should be in [RMSProp,gradientDescent,adam].");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
            int layerCnt = (Integer) params.get(CommonConstants.NUM_HIDDEN_LAYERS);
            if (layerCnt < 0) {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add("The number of hidden layers should be >= 0 in train configuration");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
            List<Integer> hiddenNode = (List<Integer>) params.get(CommonConstants.NUM_HIDDEN_NODES);
            List<String> activateFucs = (List<String>) params.get(CommonConstants.ACTIVATION_FUNC);
            if (hiddenNode.size() != activateFucs.size() || layerCnt != activateFucs.size()) {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add(CommonConstants.NUM_HIDDEN_LAYERS + "/SIZE(" + CommonConstants.NUM_HIDDEN_NODES + ")" + "/SIZE(" + CommonConstants.ACTIVATION_FUNC + ")" + " should be equal in train configuration");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
            Double learningRate = Double.valueOf(params.get(CommonConstants.LEARNING_RATE).toString());
            if (learningRate != null && (learningRate.compareTo(Double.valueOf(0)) <= 0)) {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add("Learning rate should be larger than 0.");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
            Object learningDecayO = params.get(CommonConstants.LEARNING_DECAY);
            if (learningDecayO != null) {
                Double learningDecay = Double.valueOf(learningDecayO.toString());
                if (learningDecay != null && ((learningDecay.compareTo(Double.valueOf(0)) < 0) || (learningDecay.compareTo(Double.valueOf(1)) >= 0))) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("Learning decay should be in [0, 1) if set.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            Object dropoutObj = params.get(CommonConstants.DROPOUT_RATE);
            if (dropoutObj != null) {
                Double dropoutRate = Double.valueOf(dropoutObj.toString());
                if (dropoutRate != null && (dropoutRate < 0d || dropoutRate >= 1d)) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("Dropout rate should be in [0, 1).");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            Object fixedLayersObj = params.get(CommonConstants.FIXED_LAYERS);
            if (fixedLayersObj != null) {
                List<Integer> fixedLayers = (List<Integer>) fixedLayersObj;
                for (int layer : fixedLayers) {
                    if (layer <= 0 || layer > (layerCnt + 1)) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("Fixed layer id " + layer + " is invaild. It should be between 0 and hidden layer cnt +  output layer:" + (layerCnt + 1));
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }
            }
            Object miniBatchsO = params.get(CommonConstants.MINI_BATCH);
            if (miniBatchsO != null) {
                Integer miniBatchs = Integer.valueOf(miniBatchsO.toString());
                if (miniBatchs != null && (miniBatchs <= 0 || miniBatchs > 1000)) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("MiniBatchs should be in (0, 1000] if set.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            Object momentumO = params.get("Momentum");
            if (momentumO != null) {
                Double momentum = Double.valueOf(momentumO.toString());
                if (momentum != null && momentum <= 0d) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("Momentum should be in (0, ) if set.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            Object adamBeta1O = params.get("AdamBeta1");
            if (adamBeta1O != null) {
                Double adamBeta1 = Double.valueOf(adamBeta1O.toString());
                if (adamBeta1 != null && (adamBeta1 <= 0d || adamBeta1 >= 1d)) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("AdamBeta1 should be in (0, 1) if set.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            Object adamBeta2O = params.get("AdamBeta2");
            if (adamBeta2O != null) {
                Double adamBeta2 = Double.valueOf(adamBeta2O.toString());
                if (adamBeta2 != null && (adamBeta2 <= 0d || adamBeta2 >= 1d)) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("AdamBeta2 should be in (0, 1) if set.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
        }
        if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME) || train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME) || train.getAlgorithm().equalsIgnoreCase(NNConstants.NN_ALG_NAME)) {
            Map<String, Object> params = train.getParams();
            Object fssObj = params.get("FeatureSubsetStrategy");
            if (fssObj == null) {
                if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME) || train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("FeatureSubsetStrategy is not set in RF/GBT algorithm.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            } else {
                boolean isNumber = false;
                double doubleFss = 0;
                try {
                    doubleFss = Double.parseDouble(fssObj.toString());
                    isNumber = true;
                } catch (Exception e) {
                    isNumber = false;
                }
                if (isNumber) {
                    // if not in [0, 1] failed
                    if (doubleFss <= 0d || doubleFss > 1d) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("FeatureSubsetStrategy if double should be in (0, 1]");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                } else {
                    boolean fssInEnum = false;
                    for (FeatureSubsetStrategy fss : FeatureSubsetStrategy.values()) {
                        if (fss.toString().equalsIgnoreCase(fssObj.toString())) {
                            fssInEnum = true;
                            break;
                        }
                    }
                    if (!fssInEnum) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("FeatureSubsetStrategy if string should be in ['ALL', 'HALF', 'ONETHIRD' , 'TWOTHIRDS' , 'AUTO' , 'SQRT' , 'LOG2']");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }
            }
        }
        if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME) || train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
            Map<String, Object> params = train.getParams();
            if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)) {
                Object loss = params.get("Loss");
                if (loss != null && !"log".equalsIgnoreCase(loss.toString()) && !"squared".equalsIgnoreCase(loss.toString()) && !"halfgradsquared".equalsIgnoreCase(loss.toString()) && !"absolute".equalsIgnoreCase(loss.toString())) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("Loss should be in [log,squared,halfgradsquared,absolute].");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
                if (loss == null) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("'Loss' parameter isn't being set in train#parameters in GBT training.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            Object maxDepthObj = params.get("MaxDepth");
            if (maxDepthObj != null) {
                int maxDepth = Integer.valueOf(maxDepthObj.toString());
                if (maxDepth <= 0 || maxDepth > 20) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("MaxDepth should in [1, 20].");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            Object vtObj = params.get("ValidationTolerance");
            if (vtObj != null) {
                double validationTolerance = Double.valueOf(vtObj.toString());
                if (validationTolerance < 0d || validationTolerance >= 1d) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("ValidationTolerance should in [0, 1).");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            Object maxLeavesObj = params.get("MaxLeaves");
            if (maxLeavesObj != null) {
                int maxLeaves = Integer.valueOf(maxLeavesObj.toString());
                if (maxLeaves <= 0) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("MaxLeaves should in [1, Integer.MAX_VALUE].");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            if (maxDepthObj == null && maxLeavesObj == null) {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add("'MaxDepth' or 'MaxLeaves' parameters at least one of both should be set in train#parameters in GBT training.");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
            Object maxStatsMemoryMBObj = params.get("MaxStatsMemoryMB");
            if (maxStatsMemoryMBObj != null) {
                int maxStatsMemoryMB = Integer.valueOf(maxStatsMemoryMBObj.toString());
                if (maxStatsMemoryMB <= 0) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("MaxStatsMemoryMB should > 0.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            Object dropoutObj = params.get(CommonConstants.DROPOUT_RATE);
            if (dropoutObj != null) {
                Double dropoutRate = Double.valueOf(dropoutObj.toString());
                if (dropoutRate != null && (dropoutRate < 0d || dropoutRate >= 1d)) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("Dropout rate should be in [0, 1).");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)) {
                Object learningRateObj = params.get(CommonConstants.LEARNING_RATE);
                if (learningRateObj != null) {
                    Double learningRate = Double.valueOf(learningRateObj.toString());
                    if (learningRate != null && (learningRate.compareTo(Double.valueOf(0)) <= 0)) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("Learning rate should be larger than 0.");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                } else {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("'LearningRate' parameter isn't being set in train#parameters in GBT training.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            }
            Object minInstancesPerNodeObj = params.get("MinInstancesPerNode");
            if (minInstancesPerNodeObj != null) {
                int minInstancesPerNode = Integer.valueOf(minInstancesPerNodeObj.toString());
                if (minInstancesPerNode <= 0) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("MinInstancesPerNode should > 0.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            } else {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add("'MinInstancesPerNode' parameter isn't be set in train#parameters in GBT/RF training.");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
            Object treeNumObj = params.get("TreeNum");
            if (treeNumObj != null) {
                int treeNum = Integer.valueOf(treeNumObj.toString());
                if (treeNum <= 0 || treeNum > 10000) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("TreeNum should be in [1, 10000].");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            } else {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add("'TreeNum' parameter isn't being set in train#parameters in GBT/RF training.");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
            Object minInfoGainObj = params.get("MinInfoGain");
            if (minInfoGainObj != null) {
                Double minInfoGain = Double.valueOf(minInfoGainObj.toString());
                if (minInfoGain != null && (minInfoGain.compareTo(Double.valueOf(0)) < 0)) {
                    ValidateResult tmpResult = new ValidateResult(true);
                    tmpResult.setStatus(false);
                    tmpResult.getCauses().add("MinInfoGain should be >= 0.");
                    result = ValidateResult.mergeResult(result, tmpResult);
                }
            } else {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add("'MinInfoGain' parameter isn't be set in train#parameters in GBT/RF training.");
                result = ValidateResult.mergeResult(result, tmpResult);
            }
            Object impurityObj = params.get("Impurity");
            if (impurityObj == null) {
                ValidateResult tmpResult = new ValidateResult(true);
                tmpResult.setStatus(false);
                tmpResult.getCauses().add("Impurity is not set in RF/GBT algorithm.");
                result = ValidateResult.mergeResult(result, tmpResult);
            } else {
                if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.GBT_ALG_NAME)) {
                    if (impurityObj != null && !"variance".equalsIgnoreCase(impurityObj.toString()) && !"friedmanmse".equalsIgnoreCase(impurityObj.toString())) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("GBDT only supports 'variance|friedmanmse' impurity type.");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }
                if (train.getAlgorithm().equalsIgnoreCase(CommonConstants.RF_ALG_NAME)) {
                    if (impurityObj != null && !"friedmanmse".equalsIgnoreCase(impurityObj.toString()) && !"entropy".equalsIgnoreCase(impurityObj.toString()) && !"variance".equalsIgnoreCase(impurityObj.toString()) && !"gini".equalsIgnoreCase(impurityObj.toString())) {
                        ValidateResult tmpResult = new ValidateResult(true);
                        tmpResult.setStatus(false);
                        tmpResult.getCauses().add("RF supports 'variance|entropy|gini|friedmanmse' impurity types.");
                        result = ValidateResult.mergeResult(result, tmpResult);
                    }
                }
            }
        }
    }
    return result;
}
Also used : ValidateResult(ml.shifu.shifu.container.meta.ValidateResult) FeatureSubsetStrategy(ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy) IOException(java.io.IOException) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch) List(java.util.List)
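
Reading the NN branch above in reverse gives the shape of a passing configuration: the hidden-node list and the activation-function list must both match the hidden-layer count, and the learning rate must be positive. Below is a hypothetical train#params map that satisfies those checks; the key strings are assumptions matching the CommonConstants names used in the code above.

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class ValidNNParamsSketch {

    static Map<String, Object> validNNParams() {
        Map<String, Object> params = new HashMap<String, Object>();
        // hypothetical key strings; in shifu these come from CommonConstants
        params.put("NumHiddenLayers", 2);                            // must be >= 0
        params.put("NumHiddenNodes", Arrays.asList(30, 20));         // size must equal layer count
        params.put("ActivationFunc", Arrays.asList("tanh", "tanh")); // size must equal layer count
        params.put("LearningRate", 0.1);                             // must be > 0
        params.put("FeatureSubsetStrategy", "ALL");                  // enum name, or a double in (0, 1]
        return params;
    }
}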

Example 4 with GridSearch

use of ml.shifu.shifu.core.dtrain.gs.GridSearch in project shifu by ShifuML.

the class TrainModelProcessor method runDistributedTrain.

protected int runDistributedTrain() throws IOException, InterruptedException, ClassNotFoundException {
    LOG.info("Started {}distributed training.", isDryTrain ? "dry " : "");
    int status = 0;
    Configuration conf = new Configuration();
    SourceType sourceType = super.getModelConfig().getDataSet().getSource();
    final List<String> args = new ArrayList<String>();
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    prepareCommonParams(gs.hasHyperParam(), args, sourceType);
    String alg = super.getModelConfig().getTrain().getAlgorithm();
    // add tmp models folder to config
    FileSystem fileSystem = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    Path tmpModelsPath = fileSystem.makeQualified(new Path(super.getPathFinder().getPathBySourceType(new Path(Constants.TMP, Constants.DEFAULT_MODELS_TMP_FOLDER), sourceType)));
    args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TMP_MODELS_FOLDER, tmpModelsPath.toString()));
    int baggingNum = isForVarSelect ? 1 : super.getModelConfig().getBaggingNum();
    if (modelConfig.isClassification()) {
        int classes = modelConfig.getTags().size();
        if (classes == 2) {
            // binary classification, only need one job
            baggingNum = 1;
        } else {
            if (modelConfig.getTrain().isOneVsAll()) {
                // one vs all multiple classification, we need multiple bagging jobs to do ONEVSALL
                baggingNum = modelConfig.getTags().size();
            } else {
            // native classification uses baggingNum from settings; no need to set it here
            }
        }
        if (baggingNum != super.getModelConfig().getBaggingNum()) {
            LOG.warn("'train:baggingNum' is set to {} because of ONEVSALL multiple classification.", baggingNum);
        }
    }
    boolean isKFoldCV = false;
    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        baggingNum = modelConfig.getTrain().getNumKFold();
        if (baggingNum != super.getModelConfig().getBaggingNum() && gs.hasHyperParam()) {
            // if it is grid search mode, then kfold mode is disabled
            LOG.warn("'train:baggingNum' is set to {} because of k-fold cross validation is enabled by 'numKFold' not -1.", baggingNum);
        }
    }
    long start = System.currentTimeMillis();
    boolean isParallel = Boolean.valueOf(Environment.getProperty(Constants.SHIFU_DTRAIN_PARALLEL, SHIFU_DEFAULT_DTRAIN_PARALLEL)).booleanValue();
    GuaguaMapReduceClient guaguaClient;
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
    int inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
    int candidateCount = inputOutputIndex[2];
    boolean isAfterVarSelect = (inputOutputIndex[0] != 0);
    // cache all feature list for sampling features
    List<Integer> allFeatures = NormalUtils.getAllFeatureList(this.columnConfigList, isAfterVarSelect);
    if (modelConfig.getNormalize().getIsParquet()) {
        guaguaClient = new GuaguaParquetMapReduceClient();
        // set required field list to make sure we only load selected columns.
        RequiredFieldList requiredFieldList = new RequiredFieldList();
        boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
        for (ColumnConfig columnConfig : super.columnConfigList) {
            if (columnConfig.isTarget()) {
                requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
            } else {
                if (inputNodeCount == candidateCount) {
                    // no variables have been selected yet
                    if (!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
                        requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
                    }
                } else {
                    if (!columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
                        requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
                    }
                }
            }
        }
        // weight is added manually
        requiredFieldList.add(new RequiredField("weight", columnConfigList.size(), null, DataType.DOUBLE));
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.required.fields", serializeRequiredFieldList(requiredFieldList)));
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.column.index.access", "true"));
    } else {
        guaguaClient = new GuaguaMapReduceClient();
    }
    int parallelNum = Integer.parseInt(Environment.getProperty(CommonConstants.SHIFU_TRAIN_BAGGING_INPARALLEL, "5"));
    int parallelGroups = 1;
    if (gs.hasHyperParam()) {
        parallelGroups = (gs.getFlattenParams().size() % parallelNum == 0 ? gs.getFlattenParams().size() / parallelNum : gs.getFlattenParams().size() / parallelNum + 1);
        baggingNum = gs.getFlattenParams().size();
        LOG.warn("'train:baggingNum' is set to {} because of grid search enabled by settings in 'train#params'.", gs.getFlattenParams().size());
    } else {
        parallelGroups = baggingNum % parallelNum == 0 ? baggingNum / parallelNum : baggingNum / parallelNum + 1;
    }
    LOG.info("Distributed trainning with baggingNum: {}", baggingNum);
    List<String> progressLogList = new ArrayList<String>(baggingNum);
    boolean isOneJobNotContinuous = false;
    for (int j = 0; j < parallelGroups; j++) {
        int currBags = baggingNum;
        if (gs.hasHyperParam()) {
            if (j == parallelGroups - 1) {
                currBags = gs.getFlattenParams().size() % parallelNum == 0 ? parallelNum : gs.getFlattenParams().size() % parallelNum;
            } else {
                currBags = parallelNum;
            }
        } else {
            if (j == parallelGroups - 1) {
                currBags = baggingNum % parallelNum == 0 ? parallelNum : baggingNum % parallelNum;
            } else {
                currBags = parallelNum;
            }
        }
        for (int k = 0; k < currBags; k++) {
            int i = j * parallelNum + k;
            if (gs.hasHyperParam()) {
                LOG.info("Start the {}th grid search job with params: {}", i, gs.getParams(i));
            } else if (isKFoldCV) {
                LOG.info("Start the {}th k-fold cross validation job with params.", i);
            }
            List<String> localArgs = new ArrayList<String>(args);
            // set name for each bagging job.
            localArgs.add("-n");
            localArgs.add(String.format("Shifu Master-Workers %s Training Iteration: %s id:%s", alg, super.getModelConfig().getModelSetName(), i));
            LOG.info("Start trainer with id: {}", i);
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            Path bModelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getNNBinaryModelsPath(sourceType), modelName));
            // check if job is continuous training; this can be set multiple times and only the last setting takes effect
            boolean isContinuous = false;
            if (gs.hasHyperParam()) {
                isContinuous = false;
            } else {
                int intContinuous = checkContinuousTraining(fileSystem, localArgs, modelPath, modelConfig.getTrain().getParams());
                if (intContinuous == -1) {
                    LOG.warn("Model with index {} with size of trees is over treeNum, such training will not be started.", i);
                    continue;
                } else {
                    isContinuous = (intContinuous == 1);
                }
            }
            // training
            if (gs.hasHyperParam() || isKFoldCV) {
                isContinuous = false;
            }
            if (!isContinuous && !isOneJobNotContinuous) {
                isOneJobNotContinuous = true;
                // move old models away (and delete local copies) if not continuous
                String srcModelPath = super.getPathFinder().getModelsPath(sourceType);
                String mvModelPath = srcModelPath + "_" + System.currentTimeMillis();
                LOG.info("Old model path has been moved to {}", mvModelPath);
                fileSystem.rename(new Path(srcModelPath), new Path(mvModelPath));
                fileSystem.mkdirs(new Path(srcModelPath));
                FileSystem.getLocal(conf).delete(new Path(super.getPathFinder().getModelsPath(SourceType.LOCAL)), true);
            }
            if (NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) {
                // feature subset strategy parameters initialization (same parsing as in DTMaster)
                Map<String, Object> params = gs.hasHyperParam() ? gs.getParams(i) : this.modelConfig.getTrain().getParams();
                Object fssObj = params.get("FeatureSubsetStrategy");
                FeatureSubsetStrategy featureSubsetStrategy = null;
                double featureSubsetRate = 0d;
                if (fssObj != null) {
                    try {
                        featureSubsetRate = Double.parseDouble(fssObj.toString());
                        // no need to validate that featureSubsetRate is in (0,1], as it is already validated in ModelInspector
                        featureSubsetStrategy = null;
                    } catch (NumberFormatException ee) {
                        featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
                    }
                } else {
                    LOG.warn("FeatureSubsetStrategy is not set, set to ALL by default.");
                    featureSubsetStrategy = FeatureSubsetStrategy.ALL;
                    featureSubsetRate = 0;
                }
                Set<Integer> subFeatures = null;
                if (isContinuous) {
                    BasicFloatNetwork existingModel = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource())));
                    if (existingModel == null) {
                        subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
                    } else {
                        subFeatures = existingModel.getFeatureSet();
                    }
                } else {
                    subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
                }
                if (subFeatures == null || subFeatures.size() == 0) {
                    localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, ""));
                } else {
                    localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, StringUtils.join(subFeatures, ',')));
                    LOG.debug("Size: {}, list: {}.", subFeatures.size(), StringUtils.join(subFeatures, ','));
                }
            }
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GUAGUA_OUTPUT, modelPath.toString()));
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, Constants.SHIFU_NN_BINARY_MODEL_PATH, bModelPath.toString()));
            if (gs.hasHyperParam() || isKFoldCV) {
                // grid search and k-fold CV need the validation error output
                Path valErrPath = fileSystem.makeQualified(new Path(super.getPathFinder().getValErrorPath(sourceType), "val_error_" + i));
                localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GS_VALIDATION_ERROR, valErrPath.toString()));
            }
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TRAINER_ID, String.valueOf(i)));
            final String progressLogFile = getProgressLogFile(i);
            progressLogList.add(progressLogFile);
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE, progressLogFile));
            String hdpVersion = HDPUtils.getHdpVersionForHDP224();
            if (StringUtils.isNotBlank(hdpVersion)) {
                localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "hdp.version", hdpVersion));
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf);
            }
            if (isParallel) {
                guaguaClient.addJob(localArgs.toArray(new String[0]));
            } else {
                TailThread tailThread = startTailThread(new String[] { progressLogFile });
                boolean ret = guaguaClient.createJob(localArgs.toArray(new String[0])).waitForCompletion(true);
                status += (ret ? 0 : 1);
                stopTailThread(tailThread);
            }
        }
        if (isParallel) {
            TailThread tailThread = startTailThread(progressLogList.toArray(new String[0]));
            status += guaguaClient.run();
            stopTailThread(tailThread);
        }
    }
    if (isKFoldCV) {
        // for k-fold, also copy model files at the end so that they can be used for evaluation
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
                status += 1;
            }
        }
        List<Double> valErrs = readAllValidationErrors(sourceType, fileSystem, kCrossValidation);
        double sum = 0d;
        for (Double err : valErrs) {
            sum += err;
        }
        LOG.info("Average validation error for current k-fold cross validation is {}.", sum / valErrs.size());
        LOG.info("K-fold cross validation on distributed training finished in {}ms.", System.currentTimeMillis() - start);
    } else if (gs.hasHyperParam()) {
        // select the best parameter combination in grid search
        LOG.info("Original grid search params: {}", modelConfig.getParams());
        Map<String, Object> params = findBestParams(sourceType, fileSystem, gs);
        // temp copy all models for evaluation
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
            }
        }
        LOG.info("The best parameters in grid search is {}", params);
        LOG.info("Grid search on distributed training finished in {}ms.", System.currentTimeMillis() - start);
    } else {
        // copy model files at last.
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
            }
        }
        // copy tmp model files; for RF/GBT, tmp models are not copied because they need much more space;
        // for other algorithms, tmp models are copied to local by default
        boolean copyTmpModelsToLocal = Boolean.TRUE.toString().equalsIgnoreCase(Environment.getProperty(Constants.SHIFU_TMPMODEL_COPYTOLOCAL, "true"));
        if (copyTmpModelsToLocal) {
            copyTmpModelsToLocal(tmpModelsPath, sourceType);
        } else {
            LOG.info("Tmp models are not copied into local, please find them in hdfs path: {}", tmpModelsPath);
        }
        LOG.info("Distributed training finished in {}ms.", System.currentTimeMillis() - start);
    }
    if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(this.modelConfig, null);
        // compute feature importance and write to local file after models are trained
        Map<Integer, MutablePair<String, Double>> featureImportances = CommonUtils.computeTreeModelFeatureImportance(models);
        String localFsFolder = pathFinder.getLocalFeatureImportanceFolder();
        String localFIPath = pathFinder.getLocalFeatureImportancePath();
        processRollupForFIFiles(localFsFolder, localFIPath);
        CommonUtils.writeFeatureImportance(localFIPath, featureImportances);
    }
    if (status != 0) {
        LOG.error("Error may occurred. There is no model generated. Please check!");
    }
    return status;
}
Also used : Configuration(org.apache.hadoop.conf.Configuration) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) FeatureSubsetStrategy(ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy) BasicML(org.encog.ml.BasicML) GuaguaMapReduceClient(ml.shifu.guagua.mapreduce.GuaguaMapReduceClient) MutablePair(org.apache.commons.lang3.tuple.MutablePair) RequiredField(org.apache.pig.LoadPushDown.RequiredField) FileSystem(org.apache.hadoop.fs.FileSystem) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) Path(org.apache.hadoop.fs.Path) GuaguaParquetMapReduceClient(ml.shifu.shifu.guagua.GuaguaParquetMapReduceClient) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch) RequiredFieldList(org.apache.pig.LoadPushDown.RequiredFieldList)
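
The parallelGroups arithmetic in runDistributedTrain is ceiling division spelled out with a modulo test. For non-negative job counts and a positive parallelNum it is equivalent to the compact idiom below, shown only to clarify the intent:

public class CeilDivSketch {

    // Equivalent to: n % per == 0 ? n / per : n / per + 1, for n >= 0 and per > 0
    static int parallelGroups(int jobs, int parallelNum) {
        return (jobs + parallelNum - 1) / parallelNum;
    }

    public static void main(String[] args) {
        System.out.println(parallelGroups(14, 5)); // 14 grid-search jobs, 5 in parallel -> 3 groups
        System.out.println(parallelGroups(10, 5)); // exact multiple -> 2 groups
    }
}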

Example 5 with GridSearch

use of ml.shifu.shifu.core.dtrain.gs.GridSearch in project shifu by ShifuML.

the class WDLOutput method init.

private void init(MasterContext<WDLParams, WDLParams> context) {
    if (isInit.compareAndSet(false, true)) {
        this.conf = new Configuration();
        loadConfigFiles(context.getProps());
        this.trainerId = context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID);
        GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
        this.isGsMode = gs.hasHyperParam();
        Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
        if (kCrossValidation != null && kCrossValidation > 0) {
            isKFoldCV = true;
        }
        this.tmpModelsFolder = context.getProps().getProperty(CommonConstants.SHIFU_TMP_MODELS_FOLDER);
        try {
            Path progressLog = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE));
            // we need to append to the log so that the client console keeps refreshing; otherwise the console appears stuck.
            if (ShifuFileUtils.isFileExists(progressLog, SourceType.HDFS)) {
                this.progressOutput = FileSystem.get(new Configuration()).append(progressLog);
            } else {
                this.progressOutput = FileSystem.get(new Configuration()).create(progressLog);
            }
        } catch (IOException e) {
            LOG.error("Error in create progress log:", e);
        }
    }
}
Also used : Path(org.apache.hadoop.fs.Path) Configuration(org.apache.hadoop.conf.Configuration) IOException(java.io.IOException) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch)
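
Both LogisticRegressionOutput and WDLOutput guard their init with isInit.compareAndSet(false, true), the standard JDK idiom for running initialization exactly once even if init is invoked concurrently. Stripped to its essence:

import java.util.concurrent.atomic.AtomicBoolean;

public class InitOnceSketch {

    private final AtomicBoolean isInit = new AtomicBoolean(false);

    void init() {
        // compareAndSet atomically flips false -> true for exactly one caller;
        // every other caller sees it return false and skips the body
        if (isInit.compareAndSet(false, true)) {
            // one-time setup goes here
        }
    }
}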

Aggregations

GridSearch (ml.shifu.shifu.core.dtrain.gs.GridSearch) 14
IOException (java.io.IOException) 9
Path (org.apache.hadoop.fs.Path) 8
Configuration (org.apache.hadoop.conf.Configuration) 6
GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException) 4
SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType) 4
Properties (java.util.Properties) 3
List (java.util.List) 2
ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig) 2
TreeModel (ml.shifu.shifu.core.TreeModel) 2
FeatureSubsetStrategy (ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy) 2
ConvergeAndValidToleranceEarlyStop (ml.shifu.shifu.core.dtrain.earlystop.ConvergeAndValidToleranceEarlyStop) 2
WindowEarlyStop (ml.shifu.shifu.core.dtrain.earlystop.WindowEarlyStop) 2
PoissonDistribution (org.apache.commons.math3.distribution.PoissonDistribution) 2
Field (java.lang.reflect.Field) 1
Method (java.lang.reflect.Method) 1
ArrayList (java.util.ArrayList) 1
Comparator (java.util.Comparator) 1
HashMap (java.util.HashMap) 1
Map (java.util.Map) 1