Search in sources :

Example 1 with BasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.

In the class VarSelectMapper, the method setup:

/**
 * Do initialization like ModelConfig and ColumnConfig loading, model loading and others like input or output number
 * loading.
 *
 * <p>Also resolves the model's feature set (falling back to all candidate/selected features when the
 * model carries none) and validates that the resolved input vector size matches the configured
 * input node count.
 */
@Override
protected void setup(Context context) throws IOException, InterruptedException {
    loadConfigFiles(context);
    loadModel();
    // Copy model here so the cached network can be used independently of the loaded model
    cacheNetwork = copy((BasicFloatNetwork) model);
    this.filterBy = context.getConfiguration().get(Constants.SHIFU_VARSELECT_FILTEROUT_TYPE, Constants.FILTER_BY_SE);
    // inputOutputIndex: [0] = selected input count, [1] = output count, [2] = candidate input count
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), columnConfigList);
    // before variable selection ([0] == 0) fall back to the candidate count
    this.inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
    if (model instanceof BasicFloatNetwork) {
        this.inputs = new double[((BasicFloatNetwork) model).getFeatureSet().size()];
        this.featureSet = ((BasicFloatNetwork) model).getFeatureSet();
    } else {
        this.inputs = new double[this.inputNodeCount];
    }
    boolean isAfterVarSelect = (inputOutputIndex[0] != 0);
    // cache all feature list for sampling features
    if (this.featureSet == null || this.featureSet.size() == 0) {
        this.featureSet = new HashSet<Integer>(NormalUtils.getAllFeatureList(columnConfigList, isAfterVarSelect));
        this.inputs = new double[this.featureSet.size()];
    }
    if (inputs.length != this.inputNodeCount) {
        // report inputs.length here: it is the value actually tested by the condition above,
        // whereas model.getInputCount() may differ and would make the message misleading
        throw new IllegalArgumentException("Model input count " + inputs.length + " is inconsistent with input size " + this.inputNodeCount + ".");
    }
    this.outputs = new double[inputOutputIndex[1]];
    this.columnIndexes = new long[this.inputs.length];
    this.inputsMLData = new BasicMLData(this.inputs.length);
    this.outputKey = new LongWritable();
    LOG.info("Filter by is {}", filterBy);
    // create Splitter
    String delimiter = context.getConfiguration().get(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
}
Also used : BasicMLData(org.encog.ml.data.basic.BasicMLData) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) CacheBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.CacheBasicFloatNetwork) LongWritable(org.apache.hadoop.io.LongWritable)

Example 2 with BasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.

In the class Scorer, the method scoreNsData:

/**
 * Scores one record against every loaded model and assembles the per-model scores into a
 * {@link ScoreObject}.
 *
 * <p>For NN-family models ({@code BasicFloatNetwork} / {@code NNModel}) the normalized data pair
 * is assembled per feature set and cached in {@code cachedNormDataPair}, so models sharing a
 * feature set reuse the same pair. Evaluation either runs inline or, when {@code multiThread} is
 * set, is queued as {@link Callable} tasks and executed through {@code executorManager}.
 *
 * @param inputPair pre-assembled data pair; when {@code null} (and the algorithm is not NN) one is
 *                  assembled here from {@code rawNsDataMap}
 * @param rawNsDataMap raw input values keyed by namespaced column
 * @return score object holding scores, RF tree sizes and (optionally) hidden-layer outputs, or
 *         {@code null} when the number of model results does not match the number of models
 */
public ScoreObject scoreNsData(MLDataPair inputPair, Map<NSColumn, String> rawNsDataMap) {
    // non-NN algorithms share a single pair assembled once up front
    if (inputPair == null && !this.alg.equalsIgnoreCase(NNConstants.NN_ALG_NAME)) {
        inputPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, selectedColumnConfigList, rawNsDataMap, cutoff, alg);
    }
    // clear cache
    this.cachedNormDataPair.clear();
    final MLDataPair pair = inputPair;
    List<MLData> modelResults = new ArrayList<MLData>();
    List<Callable<MLData>> tasks = null;
    if (this.multiThread) {
        tasks = new ArrayList<Callable<MLData>>();
    }
    for (final BasicML model : models) {
        // TODO, check if no need 'if' condition and refactor two if for loops please
        if (model instanceof BasicFloatNetwork || model instanceof NNModel) {
            // NNModel wraps a list of basic networks; only the first one is scored here
            final BasicFloatNetwork network = (model instanceof BasicFloatNetwork) ? (BasicFloatNetwork) model : ((NNModel) model).getIndependentNNModel().getBasicNetworks().get(0);
            // one normalized pair per distinct feature set, reused across models in this loop
            String cacheKey = featureSetToString(network.getFeatureSet());
            MLDataPair dataPair = cachedNormDataPair.get(cacheKey);
            if (dataPair == null) {
                dataPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, selectedColumnConfigList, rawNsDataMap, cutoff, alg, network.getFeatureSet());
                cachedNormDataPair.put(cacheKey, dataPair);
            }
            final MLDataPair networkPair = dataPair;
            /*
                 * if(network.getFeatureSet().size() != networkPair.getInput().size()) {
                 * log.error("Network and input size mismatch: Network Size = " + network.getFeatureSet().size()
                 * + "; Input Size = " + networkPair.getInput().size());
                 * continue;
                 * }
                 */
            // crude log throttling: only logs when the clock millis land exactly on a second
            if (System.currentTimeMillis() % 1000 == 0L) {
                log.info("Network input count = {}, while input size = {}", network.getInputCount(), networkPair.getInput().size());
            }
            // effectively-final copy so the anonymous Callable below can capture it
            final int fnlOutputHiddenLayerIndex = outputHiddenLayerIndex;
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    MLData finalOutput = network.compute(networkPair.getInput());
                    if (fnlOutputHiddenLayerIndex == 0) {
                        return finalOutput;
                    }
                    // append output values in hidden layer
                    double[] hiddenOutputs = network.getLayerOutput(fnlOutputHiddenLayerIndex);
                    double[] outputs = new double[finalOutput.getData().length + hiddenOutputs.length];
                    System.arraycopy(finalOutput.getData(), 0, outputs, 0, finalOutput.getData().length);
                    System.arraycopy(hiddenOutputs, 0, outputs, finalOutput.getData().length, hiddenOutputs.length);
                    return new BasicMLData(outputs);
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof BasicNetwork) {
            final BasicNetwork network = (BasicNetwork) model;
            // plain Encog network: assembled from the full columnConfigList, no feature-set cache
            final MLDataPair networkPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, columnConfigList, rawNsDataMap, cutoff, alg, null);
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return network.compute(networkPair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof SVM) {
            final SVM svm = (SVM) model;
            // size mismatch skips this model but keeps scoring the others
            if (svm.getInputCount() != pair.getInput().size()) {
                log.error("SVM and input size mismatch: SVM Size = " + svm.getInputCount() + "; Input Size = " + pair.getInput().size());
                continue;
            }
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return svm.compute(pair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof LR) {
            final LR lr = (LR) model;
            if (lr.getInputCount() != pair.getInput().size()) {
                log.error("LR and input size mismatch: LR Size = " + lr.getInputCount() + "; Input Size = " + pair.getInput().size());
                continue;
            }
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return lr.compute(pair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof TreeModel) {
            final TreeModel tm = (TreeModel) model;
            // note: unlike SVM/LR, a size mismatch here is fatal rather than skipped
            if (tm.getInputCount() != pair.getInput().size()) {
                throw new RuntimeException("GBDT and input size mismatch: tm input Size = " + tm.getInputCount() + "; data input Size = " + pair.getInput().size());
            }
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    MLData result = tm.compute(pair.getInput());
                    return result;
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof GenericModel) {
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return ((GenericModel) model).compute(pair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else {
            throw new RuntimeException("unsupport models");
        }
    }
    List<Double> scores = new ArrayList<Double>();
    List<Integer> rfTreeSizeList = new ArrayList<Integer>();
    SortedMap<String, Double> hiddenOutputs = null;
    if (CollectionUtils.isNotEmpty(modelResults) || CollectionUtils.isNotEmpty(tasks)) {
        // every model must have produced (or queued) exactly one result; the second loop
        // below relies on modelResults being index-aligned with this.models
        int modelSize = modelResults.size() > 0 ? modelResults.size() : tasks.size();
        if (modelSize != this.models.size()) {
            log.error("Get model results size doesn't match with models size.");
            return null;
        }
        if (multiThread) {
            modelResults = this.executorManager.submitTasksAndWaitResults(tasks);
        } else {
        // not multi-thread, modelResults is directly being populated in callable.call
        }
        if (this.outputHiddenLayerIndex != 0) {
            // keys look like "model_<modelIdx>_<layerIdx>_<nodeIdx>"; compare numerically by
            // model index, then layer index, then node index (plain string order would be wrong
            // once indexes reach double digits)
            hiddenOutputs = new TreeMap<String, Double>(new Comparator<String>() {

                @Override
                public int compare(String o1, String o2) {
                    String[] split1 = o1.split("_");
                    String[] split2 = o2.split("_");
                    int model1Index = Integer.parseInt(split1[1]);
                    int model2Index = Integer.parseInt(split2[1]);
                    if (model1Index > model2Index) {
                        return 1;
                    } else if (model1Index < model2Index) {
                        return -1;
                    } else {
                        int hidden1Index = Integer.parseInt(split1[2]);
                        int hidden2Index = Integer.parseInt(split2[2]);
                        if (hidden1Index > hidden2Index) {
                            return 1;
                        } else if (hidden1Index < hidden2Index) {
                            return -1;
                        } else {
                            int hidden11Index = Integer.parseInt(split1[3]);
                            int hidden22Index = Integer.parseInt(split2[3]);
                            return Integer.valueOf(hidden11Index).compareTo(Integer.valueOf(hidden22Index));
                        }
                    }
                }
            });
        }
        for (int i = 0; i < this.models.size(); i++) {
            BasicML model = this.models.get(i);
            MLData score = modelResults.get(i);
            if (model instanceof BasicNetwork || model instanceof NNModel) {
                if (modelConfig != null && modelConfig.isRegression()) {
                    // element 0 is the model score; any further elements are hidden-layer outputs
                    // appended by the NN callable above
                    scores.add(toScore(score.getData(0)));
                    if (this.outputHiddenLayerIndex != 0) {
                        for (int j = 1; j < score.getData().length; j++) {
                            hiddenOutputs.put("model_" + i + "_" + this.outputHiddenLayerIndex + "_" + (j - 1), score.getData()[j]);
                        }
                    }
                } else if (modelConfig != null && modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()) {
                    // if one vs all classification
                    scores.add(toScore(score.getData(0)));
                } else {
                    // native multi-class: one score per output node
                    double[] outputs = score.getData();
                    for (double d : outputs) {
                        scores.add(toScore(d));
                    }
                }
            } else if (model instanceof SVM) {
                scores.add(toScore(score.getData(0)));
            } else if (model instanceof LR) {
                scores.add(toScore(score.getData(0)));
            } else if (model instanceof TreeModel) {
                if (modelConfig.isClassification() && !modelConfig.getTrain().isOneVsAll()) {
                    // tree multi-class scores are used raw, without toScore scaling
                    double[] scoreArray = score.getData();
                    for (double sc : scoreArray) {
                        scores.add(sc);
                    }
                } else {
                    // if one vs all multiple classification or regression
                    scores.add(toScore(score.getData(0)));
                }
                final TreeModel tm = (TreeModel) model;
                // regression for RF
                if (!tm.isClassfication() && !tm.isGBDT()) {
                    rfTreeSizeList.add(tm.getTrees().size());
                }
            } else if (model instanceof GenericModel) {
                scores.add(toScore(score.getData(0)));
            } else {
                throw new RuntimeException("unsupport models");
            }
        }
    }
    // NOTE(review): tag is always the default ideal value here — confirm callers do not expect
    // the record's real label to be propagated
    Integer tag = Constants.DEFAULT_IDEAL_VALUE;
    // sampled warning (fires ~1% of the time) to avoid flooding logs on empty-score records
    if (scores.size() == 0 && System.currentTimeMillis() % 100 == 0) {
        log.warn("No Scores Calculated...");
    }
    return new ScoreObject(scores, tag, rfTreeSizeList, hiddenOutputs);
}
Also used : MLDataPair(org.encog.ml.data.MLDataPair) ArrayList(java.util.ArrayList) BasicML(org.encog.ml.BasicML) SVM(org.encog.ml.svm.SVM) Callable(java.util.concurrent.Callable) Comparator(java.util.Comparator) BasicMLData(org.encog.ml.data.basic.BasicMLData) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) BasicMLData(org.encog.ml.data.basic.BasicMLData) MLData(org.encog.ml.data.MLData) ScoreObject(ml.shifu.shifu.container.ScoreObject) BasicNetwork(org.encog.neural.networks.BasicNetwork)

Example 3 with BasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.

In the class TrainModelProcessor, the method runDistributedTrain:

/**
 * Launches distributed training as one Guagua MapReduce job per trainer (bagging member, k-fold
 * split, or grid-search parameter set), optionally batching jobs in parallel, then copies the
 * produced models back to local.
 *
 * <p>Bagging count is overridden by classification mode (one-vs-all), k-fold cross validation,
 * and grid search, in that order of appearance below. After training, the post-processing branch
 * depends on the mode: k-fold averages validation errors, grid search picks the best parameter
 * set, and plain training copies models (and optionally tmp models) to local.
 *
 * @return accumulated status; non-zero means at least one training job failed
 * @throws IOException on HDFS/local file system errors
 * @throws InterruptedException if a submitted job is interrupted
 * @throws ClassNotFoundException from MapReduce job submission
 */
protected int runDistributedTrain() throws IOException, InterruptedException, ClassNotFoundException {
    LOG.info("Started {}distributed training.", isDryTrain ? "dry " : "");
    int status = 0;
    Configuration conf = new Configuration();
    SourceType sourceType = super.getModelConfig().getDataSet().getSource();
    final List<String> args = new ArrayList<String>();
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    prepareCommonParams(gs.hasHyperParam(), args, sourceType);
    String alg = super.getModelConfig().getTrain().getAlgorithm();
    // add tmp models folder to config
    FileSystem fileSystem = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    Path tmpModelsPath = fileSystem.makeQualified(new Path(super.getPathFinder().getPathBySourceType(new Path(Constants.TMP, Constants.DEFAULT_MODELS_TMP_FOLDER), sourceType)));
    args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TMP_MODELS_FOLDER, tmpModelsPath.toString()));
    // variable selection only needs a single model
    int baggingNum = isForVarSelect ? 1 : super.getModelConfig().getBaggingNum();
    if (modelConfig.isClassification()) {
        int classes = modelConfig.getTags().size();
        if (classes == 2) {
            // binary classification, only need one job
            baggingNum = 1;
        } else {
            if (modelConfig.getTrain().isOneVsAll()) {
                // one vs all multiple classification, we need multiple bagging jobs to do ONEVSALL
                baggingNum = modelConfig.getTags().size();
            } else {
            // native classification, using bagging from setting job, no need set here
            }
        }
        if (baggingNum != super.getModelConfig().getBaggingNum()) {
            LOG.warn("'train:baggingNum' is set to {} because of ONEVSALL multiple classification.", baggingNum);
        }
    }
    // k-fold cross validation overrides baggingNum with the fold count
    boolean isKFoldCV = false;
    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        baggingNum = modelConfig.getTrain().getNumKFold();
        if (baggingNum != super.getModelConfig().getBaggingNum() && gs.hasHyperParam()) {
            // if it is grid search mode, then kfold mode is disabled
            LOG.warn("'train:baggingNum' is set to {} because of k-fold cross validation is enabled by 'numKFold' not -1.", baggingNum);
        }
    }
    long start = System.currentTimeMillis();
    boolean isParallel = Boolean.valueOf(Environment.getProperty(Constants.SHIFU_DTRAIN_PARALLEL, SHIFU_DEFAULT_DTRAIN_PARALLEL)).booleanValue();
    GuaguaMapReduceClient guaguaClient;
    // inputOutputIndex: [0] = selected input count, [1] = output count, [2] = candidate count
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
    int inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
    int candidateCount = inputOutputIndex[2];
    boolean isAfterVarSelect = (inputOutputIndex[0] != 0);
    // cache all feature list for sampling features
    List<Integer> allFeatures = NormalUtils.getAllFeatureList(this.columnConfigList, isAfterVarSelect);
    if (modelConfig.getNormalize().getIsParquet()) {
        guaguaClient = new GuaguaParquetMapReduceClient();
        // set required field list to make sure we only load selected columns.
        RequiredFieldList requiredFieldList = new RequiredFieldList();
        boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
        for (ColumnConfig columnConfig : super.columnConfigList) {
            if (columnConfig.isTarget()) {
                requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
            } else {
                if (inputNodeCount == candidateCount) {
                    // no any variables are selected
                    if (!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
                        requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
                    }
                } else {
                    if (!columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
                        requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
                    }
                }
            }
        }
        // weight is added manually
        requiredFieldList.add(new RequiredField("weight", columnConfigList.size(), null, DataType.DOUBLE));
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.required.fields", serializeRequiredFieldList(requiredFieldList)));
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.column.index.access", "true"));
    } else {
        guaguaClient = new GuaguaMapReduceClient();
    }
    // jobs are submitted in batches of parallelNum; parallelGroups is the batch count
    int parallelNum = Integer.parseInt(Environment.getProperty(CommonConstants.SHIFU_TRAIN_BAGGING_INPARALLEL, "5"));
    int parallelGroups = 1;
    if (gs.hasHyperParam()) {
        parallelGroups = (gs.getFlattenParams().size() % parallelNum == 0 ? gs.getFlattenParams().size() / parallelNum : gs.getFlattenParams().size() / parallelNum + 1);
        baggingNum = gs.getFlattenParams().size();
        LOG.warn("'train:baggingNum' is set to {} because of grid search enabled by settings in 'train#params'.", gs.getFlattenParams().size());
    } else {
        parallelGroups = baggingNum % parallelNum == 0 ? baggingNum / parallelNum : baggingNum / parallelNum + 1;
    }
    LOG.info("Distributed trainning with baggingNum: {}", baggingNum);
    List<String> progressLogList = new ArrayList<String>(baggingNum);
    boolean isOneJobNotContinuous = false;
    for (int j = 0; j < parallelGroups; j++) {
        // the last batch may be smaller than parallelNum
        int currBags = baggingNum;
        if (gs.hasHyperParam()) {
            if (j == parallelGroups - 1) {
                currBags = gs.getFlattenParams().size() % parallelNum == 0 ? parallelNum : gs.getFlattenParams().size() % parallelNum;
            } else {
                currBags = parallelNum;
            }
        } else {
            if (j == parallelGroups - 1) {
                currBags = baggingNum % parallelNum == 0 ? parallelNum : baggingNum % parallelNum;
            } else {
                currBags = parallelNum;
            }
        }
        for (int k = 0; k < currBags; k++) {
            // i is the global trainer id across all batches
            int i = j * parallelNum + k;
            if (gs.hasHyperParam()) {
                LOG.info("Start the {}th grid search job with params: {}", i, gs.getParams(i));
            } else if (isKFoldCV) {
                LOG.info("Start the {}th k-fold cross validation job with params.", i);
            }
            List<String> localArgs = new ArrayList<String>(args);
            // set name for each bagging job.
            localArgs.add("-n");
            localArgs.add(String.format("Shifu Master-Workers %s Training Iteration: %s id:%s", alg, super.getModelConfig().getModelSetName(), i));
            LOG.info("Start trainer with id: {}", i);
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            Path bModelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getNNBinaryModelsPath(sourceType), modelName));
            // check if job is continuous training, this can be set multiple times and we only get last one
            boolean isContinuous = false;
            if (gs.hasHyperParam()) {
                isContinuous = false;
            } else {
                // -1 means tree count already reached treeNum, 1 means continue from existing model
                int intContinuous = checkContinuousTraining(fileSystem, localArgs, modelPath, modelConfig.getTrain().getParams());
                if (intContinuous == -1) {
                    LOG.warn("Model with index {} with size of trees is over treeNum, such training will not be started.", i);
                    continue;
                } else {
                    isContinuous = (intContinuous == 1);
                }
            }
            // training
            if (gs.hasHyperParam() || isKFoldCV) {
                isContinuous = false;
            }
            // archive old models exactly once, the first time any job is non-continuous
            if (!isContinuous && !isOneJobNotContinuous) {
                isOneJobNotContinuous = true;
                // delete all old models if not continuous
                String srcModelPath = super.getPathFinder().getModelsPath(sourceType);
                String mvModelPath = srcModelPath + "_" + System.currentTimeMillis();
                LOG.info("Old model path has been moved to {}", mvModelPath);
                fileSystem.rename(new Path(srcModelPath), new Path(mvModelPath));
                fileSystem.mkdirs(new Path(srcModelPath));
                FileSystem.getLocal(conf).delete(new Path(super.getPathFinder().getModelsPath(SourceType.LOCAL)), true);
            }
            if (NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) {
                // tree related parameters initialization
                Map<String, Object> params = gs.hasHyperParam() ? gs.getParams(i) : this.modelConfig.getTrain().getParams();
                Object fssObj = params.get("FeatureSubsetStrategy");
                FeatureSubsetStrategy featureSubsetStrategy = null;
                double featureSubsetRate = 0d;
                if (fssObj != null) {
                    try {
                        // numeric value means a sampling rate; non-numeric means a named strategy
                        featureSubsetRate = Double.parseDouble(fssObj.toString());
                        // no need validate featureSubsetRate is in (0,1], as already validated in ModelInspector
                        featureSubsetStrategy = null;
                    } catch (NumberFormatException ee) {
                        featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
                    }
                } else {
                    LOG.warn("FeatureSubsetStrategy is not set, set to ALL by default.");
                    featureSubsetStrategy = FeatureSubsetStrategy.ALL;
                    featureSubsetRate = 0;
                }
                Set<Integer> subFeatures = null;
                if (isContinuous) {
                    // continuous training must reuse the existing model's feature set
                    BasicFloatNetwork existingModel = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource())));
                    if (existingModel == null) {
                        subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
                    } else {
                        subFeatures = existingModel.getFeatureSet();
                    }
                } else {
                    subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
                }
                if (subFeatures == null || subFeatures.size() == 0) {
                    localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, ""));
                } else {
                    localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, StringUtils.join(subFeatures, ',')));
                    LOG.debug("Size: {}, list: {}.", subFeatures.size(), StringUtils.join(subFeatures, ','));
                }
            }
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GUAGUA_OUTPUT, modelPath.toString()));
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, Constants.SHIFU_NN_BINARY_MODEL_PATH, bModelPath.toString()));
            if (gs.hasHyperParam() || isKFoldCV) {
                // k-fold cv need val error
                Path valErrPath = fileSystem.makeQualified(new Path(super.getPathFinder().getValErrorPath(sourceType), "val_error_" + i));
                localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GS_VALIDATION_ERROR, valErrPath.toString()));
            }
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TRAINER_ID, String.valueOf(i)));
            final String progressLogFile = getProgressLogFile(i);
            progressLogList.add(progressLogFile);
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE, progressLogFile));
            String hdpVersion = HDPUtils.getHdpVersionForHDP224();
            if (StringUtils.isNotBlank(hdpVersion)) {
                // HDP needs site configs explicitly added to the job classpath
                localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "hdp.version", hdpVersion));
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf);
            }
            if (isParallel) {
                // queue the job; the whole batch is run together after the inner loop
                guaguaClient.addJob(localArgs.toArray(new String[0]));
            } else {
                TailThread tailThread = startTailThread(new String[] { progressLogFile });
                boolean ret = guaguaClient.createJob(localArgs.toArray(new String[0])).waitForCompletion(true);
                status += (ret ? 0 : 1);
                stopTailThread(tailThread);
            }
        }
        if (isParallel) {
            TailThread tailThread = startTailThread(progressLogList.toArray(new String[0]));
            status += guaguaClient.run();
            stopTailThread(tailThread);
        }
    }
    if (isKFoldCV) {
        // k-fold we also copy model files at last, such models can be used for evaluation
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
                status += 1;
            }
        }
        // report the average validation error over all folds
        List<Double> valErrs = readAllValidationErrors(sourceType, fileSystem, kCrossValidation);
        double sum = 0d;
        for (Double err : valErrs) {
            sum += err;
        }
        LOG.info("Average validation error for current k-fold cross validation is {}.", sum / valErrs.size());
        LOG.info("K-fold cross validation on distributed training finished in {}ms.", System.currentTimeMillis() - start);
    } else if (gs.hasHyperParam()) {
        // select the best parameter composite in grid search
        LOG.info("Original grid search params: {}", modelConfig.getParams());
        Map<String, Object> params = findBestParams(sourceType, fileSystem, gs);
        // temp copy all models for evaluation
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
            }
        }
        LOG.info("The best parameters in grid search is {}", params);
        LOG.info("Grid search on distributed training finished in {}ms.", System.currentTimeMillis() - start);
    } else {
        // copy model files at last.
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
            }
        }
        // copy temp model files, for RF/GBT, not to copy tmp models because of larger space needed, for others
        // by default copy tmp models to local
        boolean copyTmpModelsToLocal = Boolean.TRUE.toString().equalsIgnoreCase(Environment.getProperty(Constants.SHIFU_TMPMODEL_COPYTOLOCAL, "true"));
        if (copyTmpModelsToLocal) {
            copyTmpModelsToLocal(tmpModelsPath, sourceType);
        } else {
            LOG.info("Tmp models are not copied into local, please find them in hdfs path: {}", tmpModelsPath);
        }
        LOG.info("Distributed training finished in {}ms.", System.currentTimeMillis() - start);
    }
    if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(this.modelConfig, null);
        // compute feature importance and write to local file after models are trained
        Map<Integer, MutablePair<String, Double>> featureImportances = CommonUtils.computeTreeModelFeatureImportance(models);
        String localFsFolder = pathFinder.getLocalFeatureImportanceFolder();
        String localFIPath = pathFinder.getLocalFeatureImportancePath();
        processRollupForFIFiles(localFsFolder, localFIPath);
        CommonUtils.writeFeatureImportance(localFIPath, featureImportances);
    }
    if (status != 0) {
        LOG.error("Error may occurred. There is no model generated. Please check!");
    }
    return status;
}
Also used : Configuration(org.apache.hadoop.conf.Configuration) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) FeatureSubsetStrategy(ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy) BasicML(org.encog.ml.BasicML) GuaguaMapReduceClient(ml.shifu.guagua.mapreduce.GuaguaMapReduceClient) MutablePair(org.apache.commons.lang3.tuple.MutablePair) RequiredField(org.apache.pig.LoadPushDown.RequiredField) FileSystem(org.apache.hadoop.fs.FileSystem) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) Path(org.apache.hadoop.fs.Path) GuaguaParquetMapReduceClient(ml.shifu.shifu.guagua.GuaguaParquetMapReduceClient) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch) RequiredFieldList(org.apache.pig.LoadPushDown.RequiredFieldList)

Example 4 with BasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.

the class ModelStatsCreator method build.

@Override
public ModelStats build(BasicML basicML) {
    // Builds the PMML ModelStats holding one UnivariateStats per selected column.
    // If the model is a BasicFloatNetwork carrying an explicit feature set, a
    // final-selected column is only emitted when that set is empty (meaning "all")
    // or contains the column; otherwise every final-selected column is emitted.
    ModelStats modelStats = new ModelStats();
    Set<Integer> featureSet = null;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
    }
    for (ColumnConfig columnConfig : columnConfigList) {
        if (isStatsColumn(columnConfig, featureSet)) {
            modelStats.addUnivariateStats(buildUnivariateStats(columnConfig));
        }
    }
    return modelStats;
}

/**
 * Whether stats should be emitted for this column.
 *
 * @param columnConfig
 *            the column under consideration
 * @param featureSet
 *            network feature set; {@code null} when the model is not a
 *            {@link BasicFloatNetwork}. An empty or {@code null} set means
 *            "no restriction" so only final-select is required.
 * @return true if the column is final-selected and allowed by the feature set
 */
private boolean isStatsColumn(ColumnConfig columnConfig, Set<Integer> featureSet) {
    return columnConfig.isFinalSelect()
            && (CollectionUtils.isEmpty(featureSet) || featureSet.contains(columnConfig.getColumnNum()));
}

/**
 * Builds the {@link UnivariateStats} for one column: discrete stats for
 * categorical columns, numeric info (plus ContStats when not concise) for
 * numerical ones.
 *
 * @param columnConfig
 *            the column to convert
 * @return populated univariate stats for the column
 */
private UnivariateStats buildUnivariateStats(ColumnConfig columnConfig) {
    UnivariateStats univariateStats = new UnivariateStats();
    // here, no need to consider if column is in segment expansion
    // as we need to address new stats variable
    // set simple column name in PMML
    univariateStats.setField(FieldName.create(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
    if (columnConfig.isCategorical()) {
        DiscrStats discrStats = new DiscrStats();
        Array countArray = createCountArray(columnConfig);
        discrStats.addArrays(countArray);
        if (!isConcise) {
            // extensions carry extra per-category detail, skipped in concise mode
            List<Extension> extensions = createExtensions(columnConfig);
            discrStats.addExtensions(extensions.toArray(new Extension[extensions.size()]));
        }
        univariateStats.setDiscrStats(discrStats);
    } else {
        // numerical column
        univariateStats.setNumericInfo(createNumericInfo(columnConfig));
        if (!isConcise) {
            univariateStats.setContStats(createConStats(columnConfig));
        }
    }
    return univariateStats;
}
Also used : Array(org.dmg.pmml.Array) Extension(org.dmg.pmml.Extension) DiscrStats(org.dmg.pmml.DiscrStats) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) UnivariateStats(org.dmg.pmml.UnivariateStats) ModelStats(org.dmg.pmml.ModelStats) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Example 5 with BasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork in project shifu by ShifuML.

the class DataDictionaryCreator method build.

@Override
public DataDictionary build(BasicML basicML) {
    // Builds the PMML DataDictionary over the raw dataset columns (columns whose
    // index is >= datasetHeaders.length are segment-expansion copies and are
    // skipped; columnConfigList is ordered so we can break at the boundary).
    DataDictionary dict = new DataDictionary();
    List<DataField> fields = new ArrayList<DataField>();
    boolean isSegExpansionMode = columnConfigList.size() > datasetHeaders.length;
    int segSize = segmentExpansions.size();
    // Feature set only applies when the trained model is a BasicFloatNetwork;
    // null (or empty) means no feature-set restriction. Note: instanceof is
    // already null-safe, so no separate null check is needed.
    Set<Integer> featureSet = null;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
    }
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.getColumnNum() >= datasetHeaders.length) {
            // in order
            break;
        }
        if (!isConcise) {
            // non-concise mode keeps every raw column
            fields.add(convertColumnToDataField(columnConfig));
        } else if (isColumnSelected(columnConfig, featureSet) || columnConfig.isTarget()) {
            fields.add(convertColumnToDataField(columnConfig));
        } else if (isSegExpansionMode && hasSelectedSegmentColumn(columnConfig, segSize)) {
            // even current column not selected, if segment column selected, we should keep raw column
            fields.add(convertColumnToDataField(columnConfig));
        }
    }
    dict.addDataFields(fields.toArray(new DataField[fields.size()]));
    dict.setNumberOfFields(fields.size());
    return dict;
}

/**
 * Whether the column is final-selected and permitted by the model feature set.
 *
 * @param columnConfig
 *            the column under consideration
 * @param featureSet
 *            network feature set; {@code null} or empty means no restriction
 * @return true if the column should be included as a selected feature
 */
private boolean isColumnSelected(ColumnConfig columnConfig, Set<Integer> featureSet) {
    return columnConfig.isFinalSelect()
            && (CollectionUtils.isEmpty(featureSet) || featureSet.contains(columnConfig.getColumnNum()));
}

/**
 * Whether any segment-expansion copy of this raw column is final-selected.
 * If one segment feature is selected, the raw column must stay in the dictionary.
 *
 * @param columnConfig
 *            the raw (non-expanded) column
 * @param segSize
 *            number of segment expansions
 * @return true if at least one expanded copy is final-selected
 */
private boolean hasSelectedSegmentColumn(ColumnConfig columnConfig, int segSize) {
    for (int i = 0; i < segSize; i++) {
        int newIndex = datasetHeaders.length * (i + 1) + columnConfig.getColumnNum();
        if (columnConfigList.get(newIndex).isFinalSelect()) {
            return true;
        }
    }
    return false;
}
Also used : DataField(org.dmg.pmml.DataField) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) DataDictionary(org.dmg.pmml.DataDictionary)

Aggregations

BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)13 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)5 ArrayList (java.util.ArrayList)4 PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)4 BasicML (org.encog.ml.BasicML)4 List (java.util.List)3 BasicMLData (org.encog.ml.data.basic.BasicMLData)3 IOException (java.io.IOException)2 Path (org.apache.hadoop.fs.Path)2 RequiredFieldList (org.apache.pig.LoadPushDown.RequiredFieldList)2 BufferedInputStream (java.io.BufferedInputStream)1 DataInputStream (java.io.DataInputStream)1 Comparator (java.util.Comparator)1 HashMap (java.util.HashMap)1 Map (java.util.Map)1 Callable (java.util.concurrent.Callable)1 GZIPInputStream (java.util.zip.GZIPInputStream)1 GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException)1 GuaguaMapReduceClient (ml.shifu.guagua.mapreduce.GuaguaMapReduceClient)1 ScoreObject (ml.shifu.shifu.container.ScoreObject)1