Search in sources:

Example 1 with BasicML

Use of org.encog.ml.BasicML in project shifu by ShifuML.

The class Scorer, method scoreNsData.

/**
 * Scores one record against every loaded model and collects the per-model scores.
 *
 * <p>When {@code inputPair} is null and the algorithm is not NN, a shared normalized pair is assembled
 * from {@code rawNsDataMap}. {@link BasicFloatNetwork}/{@code NNModel} models assemble their own pair per
 * feature set (cached in {@code cachedNormDataPair} for the duration of this call); plain
 * {@link BasicNetwork} models assemble a pair from {@code columnConfigList}; SVM, LR, tree and generic
 * models all use the shared pair. Each model is wrapped in a {@link Callable} that is either run inline or,
 * when {@code multiThread} is set, submitted to {@code executorManager}.
 *
 * @param inputPair    pre-normalized input, or null to build one from {@code rawNsDataMap}
 * @param rawNsDataMap raw string values keyed by {@code NSColumn}
 * @return a {@link ScoreObject} holding the scores, the tree counts of random-forest regression models and
 *         (when {@code outputHiddenLayerIndex != 0}) hidden-layer outputs; {@code null} when the number of
 *         model results does not match the number of models (e.g. a model was skipped or failed)
 * @throws RuntimeException if a model type is unsupported or a tree model's input size does not match
 */
public ScoreObject scoreNsData(MLDataPair inputPair, Map<NSColumn, String> rawNsDataMap) {
    // Build the shared normalized pair for non-NN algorithms when the caller did not supply one.
    // NOTE(review): for the NN algorithm 'pair' stays null here, yet the SVM/LR/Tree/Generic branches
    // below dereference it; presumably those model types never occur with alg == NN -- confirm.
    if (inputPair == null && !this.alg.equalsIgnoreCase(NNConstants.NN_ALG_NAME)) {
        inputPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, selectedColumnConfigList, rawNsDataMap, cutoff, alg);
    }
    // clear cache
    this.cachedNormDataPair.clear();
    final MLDataPair pair = inputPair;
    List<MLData> modelResults = new ArrayList<MLData>();
    // Only allocated in multi-thread mode; stays null otherwise.
    List<Callable<MLData>> tasks = null;
    if (this.multiThread) {
        tasks = new ArrayList<Callable<MLData>>();
    }
    // Phase 1: build one Callable per model. In single-thread mode it runs immediately and its result is
    // appended to modelResults; a throwing model is logged and skipped, which later trips the size check.
    for (final BasicML model : models) {
        // TODO, check if no need 'if' condition and refactor two if for loops please
        if (model instanceof BasicFloatNetwork || model instanceof NNModel) {
            // For an NNModel only the first of its independent networks is scored.
            final BasicFloatNetwork network = (model instanceof BasicFloatNetwork) ? (BasicFloatNetwork) model : ((NNModel) model).getIndependentNNModel().getBasicNetworks().get(0);
            // Within one call all other assemble arguments are fixed, so normalized inputs are cached
            // per feature set (the cache is cleared at the start of every call).
            String cacheKey = featureSetToString(network.getFeatureSet());
            MLDataPair dataPair = cachedNormDataPair.get(cacheKey);
            if (dataPair == null) {
                dataPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, selectedColumnConfigList, rawNsDataMap, cutoff, alg, network.getFeatureSet());
                cachedNormDataPair.put(cacheKey, dataPair);
            }
            final MLDataPair networkPair = dataPair;
            /*
                 * if(network.getFeatureSet().size() != networkPair.getInput().size()) {
                 * log.error("Network and input size mismatch: Network Size = " + network.getFeatureSet().size()
                 * + "; Input Size = " + networkPair.getInput().size());
                 * continue;
                 * }
                 */
            // Sampled logging: only emitted when the current millisecond is a multiple of 1000.
            if (System.currentTimeMillis() % 1000 == 0L) {
                log.info("Network input count = {}, while input size = {}", network.getInputCount(), networkPair.getInput().size());
            }
            // Snapshot the field into a final local captured by the anonymous Callable.
            final int fnlOutputHiddenLayerIndex = outputHiddenLayerIndex;
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    MLData finalOutput = network.compute(networkPair.getInput());
                    if (fnlOutputHiddenLayerIndex == 0) {
                        return finalOutput;
                    }
                    // append output values in hidden layer
                    // Result layout: [final outputs..., hidden-layer outputs...]
                    double[] hiddenOutputs = network.getLayerOutput(fnlOutputHiddenLayerIndex);
                    double[] outputs = new double[finalOutput.getData().length + hiddenOutputs.length];
                    System.arraycopy(finalOutput.getData(), 0, outputs, 0, finalOutput.getData().length);
                    System.arraycopy(hiddenOutputs, 0, outputs, finalOutput.getData().length, hiddenOutputs.length);
                    return new BasicMLData(outputs);
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof BasicNetwork) {
            // Plain BasicNetwork: normalized against columnConfigList with no feature-set restriction
            // (null) and without caching.
            final BasicNetwork network = (BasicNetwork) model;
            final MLDataPair networkPair = NormalUtils.assembleNsDataPair(binCategoryMap, noVarSelect, modelConfig, columnConfigList, rawNsDataMap, cutoff, alg, null);
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return network.compute(networkPair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof SVM) {
            final SVM svm = (SVM) model;
            // Input-size mismatch: log and skip this model; the size check after the loop then returns null.
            if (svm.getInputCount() != pair.getInput().size()) {
                log.error("SVM and input size mismatch: SVM Size = " + svm.getInputCount() + "; Input Size = " + pair.getInput().size());
                continue;
            }
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return svm.compute(pair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof LR) {
            final LR lr = (LR) model;
            // Input-size mismatch: log and skip this model; the size check after the loop then returns null.
            if (lr.getInputCount() != pair.getInput().size()) {
                log.error("LR and input size mismatch: LR Size = " + lr.getInputCount() + "; Input Size = " + pair.getInput().size());
                continue;
            }
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return lr.compute(pair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof TreeModel) {
            final TreeModel tm = (TreeModel) model;
            // Unlike SVM/LR, an input-size mismatch for tree models is fatal.
            if (tm.getInputCount() != pair.getInput().size()) {
                throw new RuntimeException("GBDT and input size mismatch: tm input Size = " + tm.getInputCount() + "; data input Size = " + pair.getInput().size());
            }
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    MLData result = tm.compute(pair.getInput());
                    return result;
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else if (model instanceof GenericModel) {
            Callable<MLData> callable = new Callable<MLData>() {

                @Override
                public MLData call() {
                    return ((GenericModel) model).compute(pair.getInput());
                }
            };
            if (multiThread) {
                tasks.add(callable);
            } else {
                try {
                    modelResults.add(callable.call());
                } catch (Exception e) {
                    log.error("error in model evaluation", e);
                }
            }
        } else {
            throw new RuntimeException("unsupport models");
        }
    }
    List<Double> scores = new ArrayList<Double>();
    List<Integer> rfTreeSizeList = new ArrayList<Integer>();
    SortedMap<String, Double> hiddenOutputs = null;
    // Phase 2: collect results. In multi-thread mode modelResults is still empty here, so the task count
    // stands in for the result count.
    if (CollectionUtils.isNotEmpty(modelResults) || CollectionUtils.isNotEmpty(tasks)) {
        int modelSize = modelResults.size() > 0 ? modelResults.size() : tasks.size();
        // Results are matched to this.models by index below, so any skipped/failed model aborts scoring.
        if (modelSize != this.models.size()) {
            log.error("Get model results size doesn't match with models size.");
            return null;
        }
        if (multiThread) {
            // NOTE(review): the returned list's size is not re-checked; the loop below assumes exactly one
            // result per model, in model order -- confirm submitTasksAndWaitResults guarantees this.
            modelResults = this.executorManager.submitTasksAndWaitResults(tasks);
        } else {
        // not multi-thread, modelResults is directly being populated in callable.call
        }
        if (this.outputHiddenLayerIndex != 0) {
            // Keys have the form "model_<modelIndex>_<layerIndex>_<nodeIndex>"; order numerically by
            // model index, then layer index, then node index.
            hiddenOutputs = new TreeMap<String, Double>(new Comparator<String>() {

                @Override
                public int compare(String o1, String o2) {
                    String[] split1 = o1.split("_");
                    String[] split2 = o2.split("_");
                    int model1Index = Integer.parseInt(split1[1]);
                    int model2Index = Integer.parseInt(split2[1]);
                    if (model1Index > model2Index) {
                        return 1;
                    } else if (model1Index < model2Index) {
                        return -1;
                    } else {
                        int hidden1Index = Integer.parseInt(split1[2]);
                        int hidden2Index = Integer.parseInt(split2[2]);
                        if (hidden1Index > hidden2Index) {
                            return 1;
                        } else if (hidden1Index < hidden2Index) {
                            return -1;
                        } else {
                            int hidden11Index = Integer.parseInt(split1[3]);
                            int hidden22Index = Integer.parseInt(split2[3]);
                            return Integer.valueOf(hidden11Index).compareTo(Integer.valueOf(hidden22Index));
                        }
                    }
                }
            });
        }
        for (int i = 0; i < this.models.size(); i++) {
            BasicML model = this.models.get(i);
            MLData score = modelResults.get(i);
            if (model instanceof BasicNetwork || model instanceof NNModel) {
                if (modelConfig != null && modelConfig.isRegression()) {
                    // Element 0 is the score; any further elements are the appended hidden-layer outputs.
                    scores.add(toScore(score.getData(0)));
                    if (this.outputHiddenLayerIndex != 0) {
                        for (int j = 1; j < score.getData().length; j++) {
                            hiddenOutputs.put("model_" + i + "_" + this.outputHiddenLayerIndex + "_" + (j - 1), score.getData()[j]);
                        }
                    }
                } else if (modelConfig != null && modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()) {
                    // if one vs all classification
                    scores.add(toScore(score.getData(0)));
                } else {
                    // Every output element becomes a score.
                    // NOTE(review): when outputHiddenLayerIndex != 0, BasicFloatNetwork results also carry the
                    // appended hidden-layer outputs, which are added as scores here -- confirm intended.
                    double[] outputs = score.getData();
                    for (double d : outputs) {
                        scores.add(toScore(d));
                    }
                }
            } else if (model instanceof SVM) {
                scores.add(toScore(score.getData(0)));
            } else if (model instanceof LR) {
                scores.add(toScore(score.getData(0)));
            } else if (model instanceof TreeModel) {
                // NOTE(review): unlike the network branch, modelConfig is not null-checked here.
                if (modelConfig.isClassification() && !modelConfig.getTrain().isOneVsAll()) {
                    // Native multi-class trees: raw outputs are added without toScore().
                    double[] scoreArray = score.getData();
                    for (double sc : scoreArray) {
                        scores.add(sc);
                    }
                } else {
                    // if one vs all multiple classification or regression
                    scores.add(toScore(score.getData(0)));
                }
                final TreeModel tm = (TreeModel) model;
                // regression for RF
                // (neither classification nor GBDT): record the number of trees.
                if (!tm.isClassfication() && !tm.isGBDT()) {
                    rfTreeSizeList.add(tm.getTrees().size());
                }
            } else if (model instanceof GenericModel) {
                scores.add(toScore(score.getData(0)));
            } else {
                throw new RuntimeException("unsupport models");
            }
        }
    }
    // The tag is always the default ideal value in this method.
    Integer tag = Constants.DEFAULT_IDEAL_VALUE;
    // Sampled warning: only emitted when the current millisecond is a multiple of 100.
    if (scores.size() == 0 && System.currentTimeMillis() % 100 == 0) {
        log.warn("No Scores Calculated...");
    }
    return new ScoreObject(scores, tag, rfTreeSizeList, hiddenOutputs);
}
Also used: MLDataPair(org.encog.ml.data.MLDataPair) ArrayList(java.util.ArrayList) BasicML(org.encog.ml.BasicML) SVM(org.encog.ml.svm.SVM) Callable(java.util.concurrent.Callable) Comparator(java.util.Comparator) BasicMLData(org.encog.ml.data.basic.BasicMLData) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) MLData(org.encog.ml.data.MLData) ScoreObject(ml.shifu.shifu.container.ScoreObject) BasicNetwork(org.encog.neural.networks.BasicNetwork)

Example 2 with BasicML

Use of org.encog.ml.BasicML in project shifu by ShifuML.

The class EvalModelProcessor, method validateEvalColumnConfig.

/**
 * Validates that the columns required for evaluation exist in the eval data set's schema.
 *
 * <p>The eval schema comes from the configured header file when one is set. Otherwise the first line of
 * the data path is inspected: if any single field contains the target column name, that line is treated
 * as a header (names normalized with {@code CommonUtils.normColumnName}); if not, columns are named by
 * index ("0", "1", ...). For every segment filter expression, segment-expanded names
 * ({@code <name>_<segmentIndex>}) are added as well. GENERIC and TENSORFLOW algorithms skip the remaining
 * checks. Otherwise final columns are validated for the model set (when basic models load) and for each
 * sub-model, and the configured target and weight columns must be present.
 *
 * @param evalConfig the eval configuration whose data set is checked
 * @throws IOException propagated from reading the header or data files
 * @throws IllegalArgumentException if the header and first data row differ in length, or the configured
 *         target/weight column is missing from the schema
 */
@SuppressWarnings("deprecation")
private void validateEvalColumnConfig(EvalConfig evalConfig) throws IOException {
    if (this.columnConfigList == null) {
        return;
    }
    // Header delimiter falls back to the data delimiter when not configured (same for both branches below).
    String delimiter = StringUtils.isBlank(evalConfig.getDataSet().getHeaderDelimiter()) ? evalConfig.getDataSet().getDataDelimiter() : evalConfig.getDataSet().getHeaderDelimiter();
    String[] evalColumnNames = null;
    if (StringUtils.isNotBlank(evalConfig.getDataSet().getHeaderPath())) {
        evalColumnNames = CommonUtils.getHeaders(evalConfig.getDataSet().getHeaderPath(), delimiter, evalConfig.getDataSet().getSource());
    } else {
        String[] fields = CommonUtils.takeFirstLine(evalConfig.getDataSet().getDataPath(), delimiter, evalConfig.getDataSet().getSource());
        // if first line contains target column name, we guess it is csv format and first line is header.
        String evalTargetColumnName = ((StringUtils.isBlank(evalConfig.getDataSet().getTargetColumnName())) ? modelConfig.getTargetColumnName() : evalConfig.getDataSet().getTargetColumnName());
        // Match within each field: concatenating fields first could match the name across field boundaries
        // and misclassify a data row as a header.
        boolean firstLineIsHeader = false;
        for (String field : fields) {
            if (field != null && field.contains(evalTargetColumnName)) {
                firstLineIsHeader = true;
                break;
            }
        }
        if (firstLineIsHeader) {
            // first line of data meaning second line in data files excluding first header line
            String[] dataInFirstLine = CommonUtils.takeFirstTwoLines(evalConfig.getDataSet().getDataPath(), delimiter, evalConfig.getDataSet().getSource())[1];
            if (dataInFirstLine != null && fields.length != dataInFirstLine.length) {
                throw new IllegalArgumentException("Eval header length and eval data length are not consistent, please check you header setting and data set setting in eval.");
            }
            // Special characters (e.g. '/') in column names are replaced in shifu; normalize header names the same way.
            for (int i = 0; i < fields.length; i++) {
                fields[i] = CommonUtils.normColumnName(fields[i]);
            }
            evalColumnNames = fields;
            LOG.warn("No header path is provided, we will try to read first line and detect schema.");
            LOG.warn("Schema in ColumnConfig.json are named as first line of data set path.");
        } else {
            LOG.warn("No header path is provided, we will try to read first line and detect schema.");
            LOG.warn("Schema in ColumnConfig.json are named as  index 0, 1, 2, 3 ...");
            LOG.warn("Please make sure weight column and tag column are also taking index as name.");
            evalColumnNames = new String[fields.length];
            for (int i = 0; i < fields.length; i++) {
                evalColumnNames[i] = i + "";
            }
        }
    }
    Set<NSColumn> names = new HashSet<NSColumn>();
    for (String evalColumnName : evalColumnNames) {
        names.add(new NSColumn(evalColumnName));
    }
    // Segment expansion adds "<name>_<k>" for k = 1..number of segment filter expressions.
    String filterExpressions = super.modelConfig.getSegmentFilterExpressionsAsString();
    if (StringUtils.isNotBlank(filterExpressions)) {
        int segFilterSize = CommonUtils.split(filterExpressions, Constants.SHIFU_STATS_FILTER_EXPRESSIONS_DELIMETER).length;
        for (int i = 0; i < segFilterSize; i++) {
            for (String evalColumnName : evalColumnNames) {
                names.add(new NSColumn(evalColumnName + "_" + (i + 1)));
            }
        }
    }
    if (Constants.GENERIC.equalsIgnoreCase(modelConfig.getAlgorithm()) || Constants.TENSORFLOW.equalsIgnoreCase(modelConfig.getAlgorithm())) {
        // TODO correct this logic
        return;
    }
    List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelConfig, evalConfig, SourceType.LOCAL, evalConfig.getGbtConvertToProb(), evalConfig.getGbtScoreConvertStrategy());
    if (CollectionUtils.isNotEmpty(models)) {
        validateFinalColumns(evalConfig, this.modelConfig.getModelSetName(), false, this.columnConfigList, names);
    }
    // Target/weight columns may match either by full (namespaced) name or by simple name; the NSColumn is
    // only built when a name is actually configured.
    String targetColumnName = evalConfig.getDataSet().getTargetColumnName();
    if (StringUtils.isNotBlank(targetColumnName)) {
        NSColumn targetColumn = new NSColumn(targetColumnName);
        if (!names.contains(targetColumn) && !names.contains(new NSColumn(targetColumn.getSimpleName()))) {
            throw new IllegalArgumentException("Target column " + targetColumnName + " does not exist in - " + evalConfig.getDataSet().getHeaderPath());
        }
    }
    String weightColumnName = evalConfig.getDataSet().getWeightColumnName();
    if (StringUtils.isNotBlank(weightColumnName)) {
        NSColumn weightColumn = new NSColumn(weightColumnName);
        if (!names.contains(weightColumn) && !names.contains(new NSColumn(weightColumn.getSimpleName()))) {
            throw new IllegalArgumentException("Weight column " + weightColumnName + " does not exist in - " + evalConfig.getDataSet().getHeaderPath());
        }
    }
    List<ModelSpec> subModels = ModelSpecLoaderUtils.loadSubModels(modelConfig, this.columnConfigList, evalConfig, SourceType.LOCAL, evalConfig.getGbtConvertToProb(), evalConfig.getGbtScoreConvertStrategy());
    if (CollectionUtils.isNotEmpty(subModels)) {
        for (ModelSpec modelSpec : subModels) {
            validateFinalColumns(evalConfig, modelSpec.getModelName(), true, modelSpec.getColumnConfigList(), names);
        }
    }
}
Also used: BasicML(org.encog.ml.BasicML) ModelSpec(ml.shifu.shifu.core.model.ModelSpec) NSColumn(ml.shifu.shifu.column.NSColumn) HashSet(java.util.HashSet)

Example 3 with BasicML

Use of org.encog.ml.BasicML in project shifu by ShifuML.

The class TrainModelProcessor, method runDistributedTrain.

protected int runDistributedTrain() throws IOException, InterruptedException, ClassNotFoundException {
    LOG.info("Started {}distributed training.", isDryTrain ? "dry " : "");
    int status = 0;
    Configuration conf = new Configuration();
    SourceType sourceType = super.getModelConfig().getDataSet().getSource();
    final List<String> args = new ArrayList<String>();
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    prepareCommonParams(gs.hasHyperParam(), args, sourceType);
    String alg = super.getModelConfig().getTrain().getAlgorithm();
    // add tmp models folder to config
    FileSystem fileSystem = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    Path tmpModelsPath = fileSystem.makeQualified(new Path(super.getPathFinder().getPathBySourceType(new Path(Constants.TMP, Constants.DEFAULT_MODELS_TMP_FOLDER), sourceType)));
    args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TMP_MODELS_FOLDER, tmpModelsPath.toString()));
    int baggingNum = isForVarSelect ? 1 : super.getModelConfig().getBaggingNum();
    if (modelConfig.isClassification()) {
        int classes = modelConfig.getTags().size();
        if (classes == 2) {
            // binary classification, only need one job
            baggingNum = 1;
        } else {
            if (modelConfig.getTrain().isOneVsAll()) {
                // one vs all multiple classification, we need multiple bagging jobs to do ONEVSALL
                baggingNum = modelConfig.getTags().size();
            } else {
            // native classification, using bagging from setting job, no need set here
            }
        }
        if (baggingNum != super.getModelConfig().getBaggingNum()) {
            LOG.warn("'train:baggingNum' is set to {} because of ONEVSALL multiple classification.", baggingNum);
        }
    }
    boolean isKFoldCV = false;
    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        baggingNum = modelConfig.getTrain().getNumKFold();
        if (baggingNum != super.getModelConfig().getBaggingNum() && gs.hasHyperParam()) {
            // if it is grid search mode, then kfold mode is disabled
            LOG.warn("'train:baggingNum' is set to {} because of k-fold cross validation is enabled by 'numKFold' not -1.", baggingNum);
        }
    }
    long start = System.currentTimeMillis();
    boolean isParallel = Boolean.valueOf(Environment.getProperty(Constants.SHIFU_DTRAIN_PARALLEL, SHIFU_DEFAULT_DTRAIN_PARALLEL)).booleanValue();
    GuaguaMapReduceClient guaguaClient;
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
    int inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
    int candidateCount = inputOutputIndex[2];
    boolean isAfterVarSelect = (inputOutputIndex[0] != 0);
    // cache all feature list for sampling features
    List<Integer> allFeatures = NormalUtils.getAllFeatureList(this.columnConfigList, isAfterVarSelect);
    if (modelConfig.getNormalize().getIsParquet()) {
        guaguaClient = new GuaguaParquetMapReduceClient();
        // set required field list to make sure we only load selected columns.
        RequiredFieldList requiredFieldList = new RequiredFieldList();
        boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
        for (ColumnConfig columnConfig : super.columnConfigList) {
            if (columnConfig.isTarget()) {
                requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
            } else {
                if (inputNodeCount == candidateCount) {
                    // no any variables are selected
                    if (!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
                        requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
                    }
                } else {
                    if (!columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
                        requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
                    }
                }
            }
        }
        // weight is added manually
        requiredFieldList.add(new RequiredField("weight", columnConfigList.size(), null, DataType.DOUBLE));
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.required.fields", serializeRequiredFieldList(requiredFieldList)));
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.column.index.access", "true"));
    } else {
        guaguaClient = new GuaguaMapReduceClient();
    }
    int parallelNum = Integer.parseInt(Environment.getProperty(CommonConstants.SHIFU_TRAIN_BAGGING_INPARALLEL, "5"));
    int parallelGroups = 1;
    if (gs.hasHyperParam()) {
        parallelGroups = (gs.getFlattenParams().size() % parallelNum == 0 ? gs.getFlattenParams().size() / parallelNum : gs.getFlattenParams().size() / parallelNum + 1);
        baggingNum = gs.getFlattenParams().size();
        LOG.warn("'train:baggingNum' is set to {} because of grid search enabled by settings in 'train#params'.", gs.getFlattenParams().size());
    } else {
        parallelGroups = baggingNum % parallelNum == 0 ? baggingNum / parallelNum : baggingNum / parallelNum + 1;
    }
    LOG.info("Distributed trainning with baggingNum: {}", baggingNum);
    List<String> progressLogList = new ArrayList<String>(baggingNum);
    boolean isOneJobNotContinuous = false;
    for (int j = 0; j < parallelGroups; j++) {
        int currBags = baggingNum;
        if (gs.hasHyperParam()) {
            if (j == parallelGroups - 1) {
                currBags = gs.getFlattenParams().size() % parallelNum == 0 ? parallelNum : gs.getFlattenParams().size() % parallelNum;
            } else {
                currBags = parallelNum;
            }
        } else {
            if (j == parallelGroups - 1) {
                currBags = baggingNum % parallelNum == 0 ? parallelNum : baggingNum % parallelNum;
            } else {
                currBags = parallelNum;
            }
        }
        for (int k = 0; k < currBags; k++) {
            int i = j * parallelNum + k;
            if (gs.hasHyperParam()) {
                LOG.info("Start the {}th grid search job with params: {}", i, gs.getParams(i));
            } else if (isKFoldCV) {
                LOG.info("Start the {}th k-fold cross validation job with params.", i);
            }
            List<String> localArgs = new ArrayList<String>(args);
            // set name for each bagging job.
            localArgs.add("-n");
            localArgs.add(String.format("Shifu Master-Workers %s Training Iteration: %s id:%s", alg, super.getModelConfig().getModelSetName(), i));
            LOG.info("Start trainer with id: {}", i);
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            Path bModelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getNNBinaryModelsPath(sourceType), modelName));
            // check if job is continuous training, this can be set multiple times and we only get last one
            boolean isContinuous = false;
            if (gs.hasHyperParam()) {
                isContinuous = false;
            } else {
                int intContinuous = checkContinuousTraining(fileSystem, localArgs, modelPath, modelConfig.getTrain().getParams());
                if (intContinuous == -1) {
                    LOG.warn("Model with index {} with size of trees is over treeNum, such training will not be started.", i);
                    continue;
                } else {
                    isContinuous = (intContinuous == 1);
                }
            }
            // training
            if (gs.hasHyperParam() || isKFoldCV) {
                isContinuous = false;
            }
            if (!isContinuous && !isOneJobNotContinuous) {
                isOneJobNotContinuous = true;
                // delete all old models if not continuous
                String srcModelPath = super.getPathFinder().getModelsPath(sourceType);
                String mvModelPath = srcModelPath + "_" + System.currentTimeMillis();
                LOG.info("Old model path has been moved to {}", mvModelPath);
                fileSystem.rename(new Path(srcModelPath), new Path(mvModelPath));
                fileSystem.mkdirs(new Path(srcModelPath));
                FileSystem.getLocal(conf).delete(new Path(super.getPathFinder().getModelsPath(SourceType.LOCAL)), true);
            }
            if (NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) {
                // tree related parameters initialization
                Map<String, Object> params = gs.hasHyperParam() ? gs.getParams(i) : this.modelConfig.getTrain().getParams();
                Object fssObj = params.get("FeatureSubsetStrategy");
                FeatureSubsetStrategy featureSubsetStrategy = null;
                double featureSubsetRate = 0d;
                if (fssObj != null) {
                    try {
                        featureSubsetRate = Double.parseDouble(fssObj.toString());
                        // no need validate featureSubsetRate is in (0,1], as already validated in ModelInspector
                        featureSubsetStrategy = null;
                    } catch (NumberFormatException ee) {
                        featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
                    }
                } else {
                    LOG.warn("FeatureSubsetStrategy is not set, set to ALL by default.");
                    featureSubsetStrategy = FeatureSubsetStrategy.ALL;
                    featureSubsetRate = 0;
                }
                Set<Integer> subFeatures = null;
                if (isContinuous) {
                    BasicFloatNetwork existingModel = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource())));
                    if (existingModel == null) {
                        subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
                    } else {
                        subFeatures = existingModel.getFeatureSet();
                    }
                } else {
                    subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
                }
                if (subFeatures == null || subFeatures.size() == 0) {
                    localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, ""));
                } else {
                    localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, StringUtils.join(subFeatures, ',')));
                    LOG.debug("Size: {}, list: {}.", subFeatures.size(), StringUtils.join(subFeatures, ','));
                }
            }
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GUAGUA_OUTPUT, modelPath.toString()));
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, Constants.SHIFU_NN_BINARY_MODEL_PATH, bModelPath.toString()));
            if (gs.hasHyperParam() || isKFoldCV) {
                // k-fold cv need val error
                Path valErrPath = fileSystem.makeQualified(new Path(super.getPathFinder().getValErrorPath(sourceType), "val_error_" + i));
                localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GS_VALIDATION_ERROR, valErrPath.toString()));
            }
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TRAINER_ID, String.valueOf(i)));
            final String progressLogFile = getProgressLogFile(i);
            progressLogList.add(progressLogFile);
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE, progressLogFile));
            String hdpVersion = HDPUtils.getHdpVersionForHDP224();
            if (StringUtils.isNotBlank(hdpVersion)) {
                localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "hdp.version", hdpVersion));
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf);
            }
            if (isParallel) {
                guaguaClient.addJob(localArgs.toArray(new String[0]));
            } else {
                TailThread tailThread = startTailThread(new String[] { progressLogFile });
                boolean ret = guaguaClient.createJob(localArgs.toArray(new String[0])).waitForCompletion(true);
                status += (ret ? 0 : 1);
                stopTailThread(tailThread);
            }
        }
        if (isParallel) {
            TailThread tailThread = startTailThread(progressLogList.toArray(new String[0]));
            status += guaguaClient.run();
            stopTailThread(tailThread);
        }
    }
    if (isKFoldCV) {
        // k-fold we also copy model files at last, such models can be used for evaluation
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
                status += 1;
            }
        }
        List<Double> valErrs = readAllValidationErrors(sourceType, fileSystem, kCrossValidation);
        double sum = 0d;
        for (Double err : valErrs) {
            sum += err;
        }
        LOG.info("Average validation error for current k-fold cross validation is {}.", sum / valErrs.size());
        LOG.info("K-fold cross validation on distributed training finished in {}ms.", System.currentTimeMillis() - start);
    } else if (gs.hasHyperParam()) {
        // select the best parameter composite in grid search
        LOG.info("Original grid search params: {}", modelConfig.getParams());
        Map<String, Object> params = findBestParams(sourceType, fileSystem, gs);
        // temp copy all models for evaluation
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
            }
        }
        LOG.info("The best parameters in grid search is {}", params);
        LOG.info("Grid search on distributed training finished in {}ms.", System.currentTimeMillis() - start);
    } else {
        // copy model files at last.
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
            }
        }
        // copy temp model files, for RF/GBT, not to copy tmp models because of larger space needed, for others
        // by default copy tmp models to local
        boolean copyTmpModelsToLocal = Boolean.TRUE.toString().equalsIgnoreCase(Environment.getProperty(Constants.SHIFU_TMPMODEL_COPYTOLOCAL, "true"));
        if (copyTmpModelsToLocal) {
            copyTmpModelsToLocal(tmpModelsPath, sourceType);
        } else {
            LOG.info("Tmp models are not copied into local, please find them in hdfs path: {}", tmpModelsPath);
        }
        LOG.info("Distributed training finished in {}ms.", System.currentTimeMillis() - start);
    }
    if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(this.modelConfig, null);
        // compute feature importance and write to local file after models are trained
        Map<Integer, MutablePair<String, Double>> featureImportances = CommonUtils.computeTreeModelFeatureImportance(models);
        String localFsFolder = pathFinder.getLocalFeatureImportanceFolder();
        String localFIPath = pathFinder.getLocalFeatureImportancePath();
        processRollupForFIFiles(localFsFolder, localFIPath);
        CommonUtils.writeFeatureImportance(localFIPath, featureImportances);
    }
    if (status != 0) {
        LOG.error("Error may occurred. There is no model generated. Please check!");
    }
    return status;
}
Also used : Configuration(org.apache.hadoop.conf.Configuration) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) FeatureSubsetStrategy(ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy) BasicML(org.encog.ml.BasicML) GuaguaMapReduceClient(ml.shifu.guagua.mapreduce.GuaguaMapReduceClient) MutablePair(org.apache.commons.lang3.tuple.MutablePair) RequiredField(org.apache.pig.LoadPushDown.RequiredField) FileSystem(org.apache.hadoop.fs.FileSystem) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) Path(org.apache.hadoop.fs.Path) GuaguaParquetMapReduceClient(ml.shifu.shifu.guagua.GuaguaParquetMapReduceClient) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch) RequiredFieldList(org.apache.pig.LoadPushDown.RequiredFieldList)

Example 4 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

the class PMMLTranslator method build.

/**
 * Translates trained models into a single PMML document.
 * <p>
 * When {@code isOutBaggingToOne} is set, every model in {@code basicMLs} becomes one segment of a regression
 * {@code MiningModel} whose segments are combined with {@code MultipleModelMethod.AVERAGE}. Otherwise only the
 * first model of the list is translated.
 *
 * @param basicMLs
 *            trained models to translate; must be non-null and non-empty
 * @return the assembled PMML document (header, data dictionary and model element(s))
 * @throws IllegalArgumentException
 *             if {@code basicMLs} is null or empty
 */
public PMML build(List<BasicML> basicMLs) {
    if (basicMLs == null || basicMLs.isEmpty()) {
        throw new IllegalArgumentException("Input ml model list is empty.");
    }
    PMML pmml = new PMML();
    // create and add header
    Header header = new Header();
    pmml.setHeader(header);
    header.setCopyright(" Copyright [2013-2018] PayPal Software Foundation\n" + "\n"
            + " Licensed under the Apache License, Version 2.0 (the \"License\");\n"
            + " you may not use this file except in compliance with the License.\n"
            + " You may obtain a copy of the License at\n" + "\n"
            + "    http://www.apache.org/licenses/LICENSE-2.0\n" + "\n"
            + " Unless required by applicable law or agreed to in writing, software\n"
            + " distributed under the License is distributed on an \"AS IS\" BASIS,\n"
            + " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
            + " See the License for the specific language governing permissions and\n"
            + " limitations under the License.\n");
    Application application = new Application();
    header.setApplication(application);
    application.setName("shifu");
    // version is optional metadata: leave it unset when it cannot be read
    String version = readApplicationVersionFromJar();
    if (version != null) {
        application.setVersion(version);
    }
    // create and set data dictionary for all bagging models
    pmml.setDataDictionary(this.dataDictionaryCreator.build(null));
    if (isOutBaggingToOne) {
        MiningModel miningModel = new MiningModel();
        miningModel.setMiningSchema(this.miningSchemaCreator.build(null));
        miningModel.setMiningFunction(MiningFunction.fromValue("regression"));
        miningModel.setTargets(((NNPmmlModelCreator) this.modelCreator).createTargets());
        AbstractSpecifCreator miningModelCreator = new MiningModelPmmlCreator(this.specifCreator.getModelConfig(),
                this.specifCreator.getColumnConfigList());
        miningModelCreator.build(null, miningModel);
        Segmentation seg = new Segmentation();
        miningModel.setSegmentation(seg);
        // all bagging models are averaged into one score
        seg.setMultipleModelMethod(MultipleModelMethod.AVERAGE);
        List<Segment> segments = seg.getSegments();
        int idCount = 0;
        for (BasicML basicML : basicMLs) {
            Model segmentModel = buildCommonModelParts(basicML);
            this.specifCreator.build(basicML, segmentModel, idCount);
            Segment segment = new Segment();
            // NOTE(review): "Segement" is misspelled but is kept as-is because it is emitted into the PMML
            // output and downstream consumers may match on these ids.
            segment.setId("Segement" + idCount);
            segment.setPredicate(new True());
            segment.setModel(segmentModel);
            segments.add(segment);
            idCount++;
        }
        pmml.addModels(miningModel);
    } else {
        // without bagging-to-one, only the first model is exported
        BasicML basicML = basicMLs.get(0);
        Model model = buildCommonModelParts(basicML);
        this.specifCreator.build(basicML, model);
        pmml.addModels(model);
    }
    return pmml;
}

/**
 * Creates the model element for one trained model and fills in the parts shared by both export modes:
 * mining schema, variable statistics and local transformations.
 */
private Model buildCommonModelParts(BasicML basicML) {
    // create model element
    Model model = this.modelCreator.build(basicML);
    // create mining schema
    model.setMiningSchema(this.miningSchemaCreator.build(basicML));
    // create variable statistical info
    model.setModelStats(this.modelStatsCreator.build(basicML));
    // create variable transform
    model.setLocalTransformations(this.localTransformationsCreator.build(basicML));
    return model;
}

/**
 * Reads the {@code version} main attribute from the manifest of the jar that contains the PMML translator
 * classes. This is a best-effort lookup: every failure is logged and {@code null} is returned so that PMML
 * generation can continue.
 */
private String readApplicationVersionFromJar() {
    String jarPath = JarManager.findContainingJar(TreeEnsemblePMMLTranslator.class);
    if (jarPath == null) {
        // e.g. classes loaded from a directory instead of a jar
        LOG.warn("Cannot locate the jar containing shifu classes, application version is not set.");
        return null;
    }
    try (JarFile jar = new JarFile(jarPath)) {
        Manifest manifest = jar.getManifest();
        if (manifest == null) {
            LOG.warn("No manifest found in jar " + jarPath + ", application version is not set.");
            return null;
        }
        return manifest.getMainAttributes().getValue("version");
    } catch (Exception e) {
        // deliberate best-effort boundary: a missing version must not fail PMML export
        LOG.warn("Failed to read application version from jar " + jarPath + ": " + e.getMessage(), e);
        return null;
    }
}
Also used : Segmentation(org.dmg.pmml.mining.Segmentation) AbstractSpecifCreator(ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator) True(org.dmg.pmml.True) BasicML(org.encog.ml.BasicML) IOException(java.io.IOException) JarFile(java.util.jar.JarFile) Manifest(java.util.jar.Manifest) Segment(org.dmg.pmml.mining.Segment) MiningModelPmmlCreator(ml.shifu.shifu.core.pmml.builder.impl.MiningModelPmmlCreator) Header(org.dmg.pmml.Header) MiningModel(org.dmg.pmml.mining.MiningModel) Model(org.dmg.pmml.Model) PMML(org.dmg.pmml.PMML) Application(org.dmg.pmml.Application)

Example 5 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

the class NNMaster method initOrRecoverParams.

/**
 * Builds the initial master parameters. If a model already exists at the Guagua output path, its weights
 * are reused.
 * <p>
 * The existing network is compared with {@code this.flatNetwork} by {@code NNStructureComparator}:
 * <ul>
 * <li>0 (same structure): the existing weights are copied into the returned params.</li>
 * <li>1 (current network larger): the existing network is fitted into {@code this.flatNetwork}.</li>
 * <li>anything else: training cannot continue and a {@code GuaguaRuntimeException} is thrown.</li>
 * </ul>
 * In both reuse cases {@code this.fixedWeightIndexSet} is updated.
 *
 * @param context
 *            master context whose properties provide the model output path
 * @return freshly initialized params, possibly overwritten with the existing model's weights
 * @throws GuaguaRuntimeException
 *             if loading the model fails with an {@code IOException}, or if the existing network does not fit
 */
private NNParams initOrRecoverParams(MasterContext<NNParams, NNParams> context) {
    try {
        Path existingModelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        BasicML loadedModel = ModelSpecLoaderUtils.loadModel(modelConfig, existingModelPath,
                ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        NNParams initialParams = initWeights();
        BasicFloatNetwork previousNetwork = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(loadedModel);

        // nothing to recover: keep the freshly initialized weights
        if (previousNetwork == null) {
            LOG.info("Starting to train model from scratch.");
            return initialParams;
        }

        LOG.info("Starting to train model from existing model {}.", existingModelPath);
        switch (new NNStructureComparator().compare(this.flatNetwork, previousNetwork.getFlat())) {
            case 0:
                // identical structure: reuse weights directly
                initialParams.setWeights(previousNetwork.getFlat().getWeights());
                this.fixedWeightIndexSet = getFixedWights(fixedLayers);
                break;
            case 1:
                // current structure is larger: embed the existing network into it
                this.fixedWeightIndexSet = fitExistingModelIn(previousNetwork.getFlat(), this.flatNetwork,
                        this.fixedLayers, this.fixedBias);
                break;
            default:
                // current structure is smaller and cannot hold the existing network
                throw new GuaguaRuntimeException("Network changed for recover or continuous training. "
                        + "New network couldn't hold existing network!");
        }
        return initialParams;
    } catch (IOException e) {
        throw new GuaguaRuntimeException(e);
    }
}
Also used : Path(org.apache.hadoop.fs.Path) BasicML(org.encog.ml.BasicML) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) IOException(java.io.IOException)

Aggregations

BasicML (org.encog.ml.BasicML)23 File (java.io.File)6 BasicNetwork (org.encog.neural.networks.BasicNetwork)5 IOException (java.io.IOException)4 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)4 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)4 PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)4 FileSystem (org.apache.hadoop.fs.FileSystem)4 FlatNetwork (org.encog.neural.flat.FlatNetwork)4 ArrayList (java.util.ArrayList)3 NSColumn (ml.shifu.shifu.column.NSColumn)3 ModelRunner (ml.shifu.shifu.core.ModelRunner)3 ModelSpec (ml.shifu.shifu.core.model.ModelSpec)3 MutablePair (org.apache.commons.lang3.tuple.MutablePair)3 Configuration (org.apache.hadoop.conf.Configuration)3 FileStatus (org.apache.hadoop.fs.FileStatus)3 Path (org.apache.hadoop.fs.Path)3 JarFile (java.util.jar.JarFile)2 CaseScoreResult (ml.shifu.shifu.container.CaseScoreResult)2 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)2