Search in sources :

Example 31 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

In the class VariableSelector, the method sortByParetoCC:

/**
 * Converts column configs into (columnNum, ks, negated iv) tuples and delegates
 * the actual Pareto sorting to {@link #sortByPareto(List)}.
 *
 * @param list column configs to sort; entries that are null or have no column stats are skipped
 * @return tuples ordered by the Pareto sort
 */
public List<Tuple> sortByParetoCC(List<ColumnConfig> list) {
    // Lazily initialize the default epsilon values used by the Pareto sort.
    if (this.epsilonArray == null) {
        this.epsilonArray = new double[] { 0.01d, 0.05d };
    }
    List<Tuple> candidates = new ArrayList<VariableSelector.Tuple>();
    for (ColumnConfig cc : list) {
        // Only columns with computed stats participate in the sort.
        if (cc == null || cc.getColumnStats() == null) {
            continue;
        }
        Double ks = cc.getKs();
        Double iv = cc.getIv();
        double ksValue = (ks == null) ? 0d : ks.doubleValue();
        // IV is negated (0 - iv) before sorting; missing values default to 0.
        double negatedIv = (iv == null) ? 0d : 0 - iv.doubleValue();
        candidates.add(new Tuple(cc.getColumnNum(), ksValue, negatedIv));
    }
    return sortByPareto(candidates);
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 32 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

In the class ConfusionMatrix, the method computeConfusionMatixForMultipleClassification:

/**
 * Computes and logs the confusion matrix for multiple classification evaluation.
 *
 * <p>Reads the eval score output files, derives a predicted class index per record
 * according to the training algorithm (RF native voting, OneVsAll thresholding with
 * class priors, or averaged per-class model scores), and accumulates actual-vs-predicted
 * counts into a classes x classes matrix which is then written and logged.
 *
 * @param records record count hint (not referenced in the method body)
 * @throws IOException if the eval score files cannot be read or the matrix cannot be written
 */
public void computeConfusionMatixForMultipleClassification(long records) throws IOException {
    SourceType sourceType = evalConfig.getDataSet().getSource();
    List<Scanner> scanners = ShifuFileUtils.getDataScanners(pathFinder.getEvalScorePath(evalConfig, sourceType), sourceType);
    boolean isDir = ShifuFileUtils.isDir(pathFinder.getEvalScorePath(evalConfig, sourceType), sourceType);
    // All valid tag values (positive and negative, flattened) used to filter out bad records.
    Set<String> tagSet = new HashSet<String>(modelConfig.getFlattenTags(modelConfig.getPosTags(evalConfig), modelConfig.getNegTags(evalConfig)));
    // One tag set per class; the position in this list is the class index.
    List<Set<String>> tags = modelConfig.getSetTags(modelConfig.getPosTags(evalConfig), modelConfig.getNegTags(evalConfig));
    int classes = tags.size();
    long cnt = 0, invalidTargetCnt = 0;
    ColumnConfig targetColumn = CommonUtils.findTargetColumn(columnConfigList);
    // Per-class counts from the target column stats; used below to compute class ratios (priors).
    List<Integer> binCountNeg = targetColumn.getBinCountNeg();
    List<Integer> binCountPos = targetColumn.getBinCountPos();
    long[] binCount = new long[classes];
    double[] binRatio = new double[classes];
    long sumCnt = 0L;
    for (int i = 0; i < binCount.length; i++) {
        binCount[i] = binCountNeg.get(i) + binCountPos.get(i);
        sumCnt += binCount[i];
    }
    // binRatio[i] is the fraction of records belonging to class i.
    for (int i = 0; i < binCount.length; i++) {
        binRatio[i] = (binCount[i] * 1d) / sumCnt;
    }
    // confusionMatrix[actual][predicted]
    long[][] confusionMatrix = new long[classes][classes];
    for (Scanner scanner : scanners) {
        while (scanner.hasNext()) {
            if ((++cnt) % 100000 == 0) {
                LOG.info("Loaded " + cnt + " records.");
            }
            if (!isDir && cnt == 1) {
                // if the evaluation score file is a local file, skip the first line since a header line was added
                continue;
            }
            // score is separated by default delimiter in our pig output format
            String[] raw = scanner.nextLine().split(Constants.DEFAULT_ESCAPE_DELIMITER);
            String tag = raw[targetColumnIndex];
            // skip records whose target is blank or not a known tag value
            if (StringUtils.isBlank(tag) || !tagSet.contains(tag)) {
                invalidTargetCnt += 1;
                continue;
            }
            double[] scores = new double[classes];
            int predictIndex = -1;
            double maxScore = Double.NEGATIVE_INFINITY;
            if (CommonUtils.isTreeModel(modelConfig.getAlgorithm()) && !modelConfig.getTrain().isOneVsAll()) {
                // for RF native classification: each model column votes a class index; majority wins
                double[] tagCounts = new double[tags.size()];
                for (int i = this.multiClassScore1Index; i < (raw.length - this.metaColumns); i++) {
                    double dd = NumberFormatUtils.getDouble(raw[i], 0d);
                    tagCounts[(int) dd] += 1d;
                }
                double maxVotes = -1d;
                for (int i = 0; i < tagCounts.length; i++) {
                    if (tagCounts[i] > maxVotes) {
                        predictIndex = i;
                        maxScore = maxVotes = tagCounts[i];
                    }
                }
            } else if ((CommonUtils.isTreeModel(modelConfig.getAlgorithm()) || NNConstants.NN_ALG_NAME.equalsIgnoreCase(modelConfig.getAlgorithm())) && modelConfig.getTrain().isOneVsAll()) {
                // for RF, GBT & NN OneVsAll classification
                if (classes == 2) {
                    // for binary classification, only one model is needed; compare the single
                    // score against a prior-adjusted threshold
                    for (int i = this.multiClassScore1Index; i < (1 + this.multiClassScore1Index); i++) {
                        double dd = NumberFormatUtils.getDouble(raw[i], 0d);
                        if (dd > ((1d - binRatio[i - this.multiClassScore1Index]) * scoreScale)) {
                            predictIndex = 0;
                        } else {
                            predictIndex = 1;
                        }
                    }
                } else {
                    // logic is here: each one-vs-rest model may be imbalanced. For example with classes
                    // a, b, c, the first model is a(1) vs b and c(0); if the ratio is 10:1, then a score
                    // > 1/11 counts as positive. Check all models for positives and take the positive
                    // class with the largest ratio as the final prediction.
                    int[] predClasses = new int[classes];
                    double[] scoress = new double[classes];
                    double[] threhs = new double[classes];
                    for (int i = this.multiClassScore1Index; i < (classes + this.multiClassScore1Index); i++) {
                        double dd = NumberFormatUtils.getDouble(raw[i], 0d);
                        scoress[i - this.multiClassScore1Index] = dd;
                        threhs[i - this.multiClassScore1Index] = (1d - binRatio[i - this.multiClassScore1Index]) * scoreScale;
                        if (dd > ((1d - binRatio[i - this.multiClassScore1Index]) * scoreScale)) {
                            predClasses[i - this.multiClassScore1Index] = 1;
                        }
                    }
                    double maxRatio = -1d;
                    double maxPositiveRatio = -1d;
                    int maxRatioIndex = -1;
                    for (int i = 0; i < binCount.length; i++) {
                        if (binRatio[i] > maxRatio) {
                            maxRatio = binRatio[i];
                            maxRatioIndex = i;
                        }
                        // if has positive, choose one with highest ratio
                        if (predClasses[i] == 1) {
                            if (binRatio[i] > maxPositiveRatio) {
                                maxPositiveRatio = binRatio[i];
                                predictIndex = i;
                            }
                        }
                    }
                    // no any positive, take the class with the largest prior
                    if (maxPositiveRatio < 0d) {
                        predictIndex = maxRatioIndex;
                    }
                }
            } else {
                if (classes == 2) {
                    // for binary classification, only one model is needed.
                    for (int i = this.multiClassScore1Index; i < (1 + this.multiClassScore1Index); i++) {
                        double dd = NumberFormatUtils.getDouble(raw[i], 0d);
                        if (dd > ((1d - binRatio[i - this.multiClassScore1Index]) * scoreScale)) {
                            predictIndex = 0;
                        } else {
                            predictIndex = 1;
                        }
                    }
                } else {
                    // score layout 1,2,3 4,5,6: columns 1,2,3 are model 0, columns 4,5,6 are model 1;
                    // average each class's score across models and pick the arg-max class
                    for (int i = 0; i < classes; i++) {
                        for (int j = 0; j < multiClassModelCnt; j++) {
                            double dd = NumberFormatUtils.getDouble(raw[this.multiClassScore1Index + j * classes + i], 0d);
                            scores[i] += dd;
                        }
                        scores[i] /= multiClassModelCnt;
                        if (scores[i] > maxScore) {
                            predictIndex = i;
                            maxScore = scores[i];
                        }
                    }
                }
            }
            // map the actual tag back to its class index
            int tagIndex = -1;
            for (int i = 0; i < tags.size(); i++) {
                if (tags.get(i).contains(tag)) {
                    tagIndex = i;
                    break;
                }
            }
            confusionMatrix[tagIndex][predictIndex] += 1L;
        }
        scanner.close();
    }
    LOG.info("Totally loading {} records with invalid target records {} in eval {}.", cnt, invalidTargetCnt, evalConfig.getName());
    writeToConfMatrixFile(tags, confusionMatrix);
    // print conf matrix
    LOG.info("Multiple classification confustion matrix:");
    LOG.info(String.format("%15s: %20s", "     ", tags.toString()));
    for (int i = 0; i < confusionMatrix.length; i++) {
        LOG.info(String.format("%15s: %20s", tags.get(i), Arrays.toString(confusionMatrix[i])));
    }
}
Also used : Scanner(java.util.Scanner) HashSet(java.util.HashSet) Set(java.util.Set) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) HashSet(java.util.HashSet)

Example 33 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

In the class UpdateBinningInfoMapper, the method populateStats:

/**
 * Accumulates binning statistics for one column of one record.
 *
 * <p>Locates the record value's bin (hybrid, categorical, or numerical handling depends on
 * the column type), then updates bin counts, weighted bin counts, missing counts, and — for
 * numerical values — sum/squared/triple/quartic sums plus min/max in the column's
 * {@link BinningInfoWritable}.
 *
 * @param units       parsed field values of the current record, indexed by column
 * @param tag         target value of the record; matched against posTags/negTags
 * @param weight      record weight folded into the weighted bin counters
 * @param columnIndex index into columnConfigList for this column's config
 * @param newCCIndex  key for the per-column accumulators (binning info and category maps)
 */
private void populateStats(String[] units, String tag, Double weight, int columnIndex, int newCCIndex) {
    ColumnConfig columnConfig = this.columnConfigList.get(columnIndex);
    // track approximate distinct counts / frequent items per column
    CountAndFrequentItems countAndFrequentItems = this.variableCountMap.get(newCCIndex);
    if (countAndFrequentItems == null) {
        countAndFrequentItems = new CountAndFrequentItems();
        this.variableCountMap.put(newCCIndex, countAndFrequentItems);
    }
    countAndFrequentItems.offer(this.missingOrInvalidValues, units[columnIndex]);
    boolean isMissingValue = false;
    boolean isInvalidValue = false;
    BinningInfoWritable binningInfoWritable = this.columnBinningInfo.get(newCCIndex);
    // no binning info loaded for this column — nothing to accumulate
    if (binningInfoWritable == null) {
        return;
    }
    binningInfoWritable.setTotalCount(binningInfoWritable.getTotalCount() + 1L);
    if (columnConfig.isHybrid()) {
        int binNum = 0;
        if (units[columnIndex] == null || missingOrInvalidValues.contains(units[columnIndex].toLowerCase())) {
            isMissingValue = true;
        }
        String str = units[columnIndex];
        double douVal = BinUtils.parseNumber(str);
        Double hybridThreshold = columnConfig.getHybridThreshold();
        if (hybridThreshold == null) {
            hybridThreshold = Double.NEGATIVE_INFINITY;
        }
        // douVal < hybridThreshould which will also be set to category
        boolean isCategory = Double.isNaN(douVal) || douVal < hybridThreshold;
        boolean isNumber = !Double.isNaN(douVal);
        if (isMissingValue) {
            binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
            // missing values go to the extra last bin (after all numeric and categorical bins)
            binNum = binningInfoWritable.getBinCategories().size() + binningInfoWritable.getBinBoundaries().size();
        } else if (isCategory) {
            // get categorical bin number in category list
            binNum = quickLocateCategoricalBin(this.categoricalBinMap.get(newCCIndex), str);
            if (binNum < 0) {
                isInvalidValue = true;
            }
            if (isInvalidValue) {
                // the same as missing count
                binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
                binNum = binningInfoWritable.getBinCategories().size() + binningInfoWritable.getBinBoundaries().size();
            } else {
                // if real category value, binNum should + binBoundaries.size
                binNum += binningInfoWritable.getBinBoundaries().size();
                ;
            }
        } else if (isNumber) {
            binNum = getBinNum(binningInfoWritable.getBinBoundaries(), douVal);
            if (binNum == -1) {
                throw new RuntimeException("binNum should not be -1 to this step.");
            }
            // other stats are treated as numerical features
            binningInfoWritable.setSum(binningInfoWritable.getSum() + douVal);
            double squaredVal = douVal * douVal;
            binningInfoWritable.setSquaredSum(binningInfoWritable.getSquaredSum() + squaredVal);
            binningInfoWritable.setTripleSum(binningInfoWritable.getTripleSum() + squaredVal * douVal);
            binningInfoWritable.setQuarticSum(binningInfoWritable.getQuarticSum() + squaredVal * squaredVal);
            if (Double.compare(binningInfoWritable.getMax(), douVal) < 0) {
                binningInfoWritable.setMax(douVal);
            }
            if (Double.compare(binningInfoWritable.getMin(), douVal) > 0) {
                binningInfoWritable.setMin(douVal);
            }
        }
        // accumulate (weighted) bin counts by positive/negative tag
        if (posTags.contains(tag)) {
            binningInfoWritable.getBinCountPos()[binNum] += 1L;
            binningInfoWritable.getBinWeightPos()[binNum] += weight;
        } else if (negTags.contains(tag)) {
            binningInfoWritable.getBinCountNeg()[binNum] += 1L;
            binningInfoWritable.getBinWeightNeg()[binNum] += weight;
        }
    } else if (columnConfig.isCategorical()) {
        // the extra last bin collects missing/invalid categorical values
        int lastBinIndex = binningInfoWritable.getBinCategories().size();
        int binNum = 0;
        if (units[columnIndex] == null || missingOrInvalidValues.contains(units[columnIndex].toLowerCase())) {
            isMissingValue = true;
        } else {
            String str = units[columnIndex];
            binNum = quickLocateCategoricalBin(this.categoricalBinMap.get(newCCIndex), str);
            if (binNum < 0) {
                isInvalidValue = true;
            }
        }
        if (isInvalidValue || isMissingValue) {
            binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
            binNum = lastBinIndex;
        }
        if (modelConfig.isRegression()) {
            if (posTags.contains(tag)) {
                binningInfoWritable.getBinCountPos()[binNum] += 1L;
                binningInfoWritable.getBinWeightPos()[binNum] += weight;
            } else if (negTags.contains(tag)) {
                binningInfoWritable.getBinCountNeg()[binNum] += 1L;
                binningInfoWritable.getBinWeightNeg()[binNum] += weight;
            }
        } else {
            // for multiple classification, set bin count to BinCountPos and leave BinCountNeg empty
            binningInfoWritable.getBinCountPos()[binNum] += 1L;
            binningInfoWritable.getBinWeightPos()[binNum] += weight;
        }
    } else if (columnConfig.isNumerical()) {
        // the extra last bin collects missing/invalid numerical values
        int lastBinIndex = binningInfoWritable.getBinBoundaries().size();
        double douVal = 0.0;
        if (units[columnIndex] == null || units[columnIndex].length() == 0 || missingOrInvalidValues.contains(units[columnIndex].toLowerCase())) {
            isMissingValue = true;
        } else {
            try {
                douVal = Double.parseDouble(units[columnIndex].trim());
            } catch (Exception e) {
                // non-parsable numeric value is treated as invalid, not a failure
                isInvalidValue = true;
            }
        }
        // add logic the same as CalculateNewStatsUDF
        if (Double.compare(douVal, modelConfig.getNumericalValueThreshold()) > 0) {
            isInvalidValue = true;
        }
        if (isInvalidValue || isMissingValue) {
            binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
            if (modelConfig.isRegression()) {
                if (posTags.contains(tag)) {
                    binningInfoWritable.getBinCountPos()[lastBinIndex] += 1L;
                    binningInfoWritable.getBinWeightPos()[lastBinIndex] += weight;
                } else if (negTags.contains(tag)) {
                    binningInfoWritable.getBinCountNeg()[lastBinIndex] += 1L;
                    binningInfoWritable.getBinWeightNeg()[lastBinIndex] += weight;
                }
            }
        } else {
            // For invalid or missing values, no need update sum, squaredSum, max, min ...
            int binNum = getBinNum(binningInfoWritable.getBinBoundaries(), units[columnIndex]);
            if (binNum == -1) {
                throw new RuntimeException("binNum should not be -1 to this step.");
            }
            if (modelConfig.isRegression()) {
                if (posTags.contains(tag)) {
                    binningInfoWritable.getBinCountPos()[binNum] += 1L;
                    binningInfoWritable.getBinWeightPos()[binNum] += weight;
                } else if (negTags.contains(tag)) {
                    binningInfoWritable.getBinCountNeg()[binNum] += 1L;
                    binningInfoWritable.getBinWeightNeg()[binNum] += weight;
                }
            }
            binningInfoWritable.setSum(binningInfoWritable.getSum() + douVal);
            double squaredVal = douVal * douVal;
            binningInfoWritable.setSquaredSum(binningInfoWritable.getSquaredSum() + squaredVal);
            binningInfoWritable.setTripleSum(binningInfoWritable.getTripleSum() + squaredVal * douVal);
            binningInfoWritable.setQuarticSum(binningInfoWritable.getQuarticSum() + squaredVal * squaredVal);
            if (Double.compare(binningInfoWritable.getMax(), douVal) < 0) {
                binningInfoWritable.setMax(douVal);
            }
            if (Double.compare(binningInfoWritable.getMin(), douVal) > 0) {
                binningInfoWritable.setMin(douVal);
            }
        }
    }
}
Also used : CountAndFrequentItems(ml.shifu.shifu.core.autotype.AutoTypeDistinctCountMapper.CountAndFrequentItems) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) IOException(java.io.IOException) FileNotFoundException(java.io.FileNotFoundException)

Example 34 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

In the class UpdateBinningInfoMapper, the method loadColumnBinningInfo:

/**
 * Load and initialize column binning info object.
 */
/**
 * Load and initialize column binning info objects.
 *
 * <p>Reads {@code Constants.BINNING_INFO_FILE_NAME} line by line; each line carries a raw
 * column number and a serialized bin definition. The definition is parsed according to the
 * column type (hybrid, numerical, or categorical) into a {@link BinningInfoWritable}, the
 * per-bin counter arrays are sized (one extra slot for missing/invalid values), and the
 * result is cached in {@code this.columnBinningInfo}.
 */
private void loadColumnBinningInfo() throws FileNotFoundException, IOException {
    BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(Constants.BINNING_INFO_FILE_NAME), Charset.forName("UTF-8")));
    try {
        // Stop at EOF or at the first empty line.
        for (String line = reader.readLine(); line != null && line.length() != 0; line = reader.readLine()) {
            LOG.debug("line is {}", line);
            // here just use String.split for just two columns
            String[] cols = Lists.newArrayList(this.splitter.split(line)).toArray(new String[0]);
            if (cols == null || cols.length < 2) {
                continue;
            }
            Integer rawColumnNum = Integer.parseInt(cols[0]);
            BinningInfoWritable binningInfo = new BinningInfoWritable();
            // raw column numbers beyond the config list wrap around to the effective column
            int corrColumnNum = (rawColumnNum >= this.columnConfigList.size())
                    ? rawColumnNum % this.columnConfigList.size() : rawColumnNum;
            binningInfo.setColumnNum(rawColumnNum);
            ColumnConfig columnConfig = this.columnConfigList.get(corrColumnNum);
            int binSize;
            if (columnConfig.isHybrid()) {
                // hybrid columns carry numeric boundaries and a category list, separated by
                // the hybrid delimiter
                binningInfo.setNumeric(true);
                String[] parts = CommonUtils.split(cols[1], Constants.HYBRID_BIN_STR_DILIMETER);
                List<Double> boundaries = new ArrayList<Double>();
                for (String element : BIN_BOUNDARY_SPLITTER.split(parts[0])) {
                    boundaries.add(Double.valueOf(element));
                }
                binningInfo.setBinBoundaries(boundaries);
                List<String> categories = new ArrayList<String>();
                Map<String, Integer> categoryIndexMap = this.categoricalBinMap.get(rawColumnNum);
                if (categoryIndexMap == null) {
                    categoryIndexMap = new HashMap<String, Integer>();
                    this.categoricalBinMap.put(rawColumnNum, categoryIndexMap);
                }
                int index = 0;
                if (!StringUtils.isBlank(parts[1])) {
                    for (String element : BIN_BOUNDARY_SPLITTER.split(parts[1])) {
                        categories.add(element);
                        categoryIndexMap.put(element, index++);
                    }
                }
                binningInfo.setBinCategories(categories);
                binSize = boundaries.size() + categories.size();
            } else if (columnConfig.isNumerical()) {
                binningInfo.setNumeric(true);
                List<Double> boundaries = new ArrayList<Double>();
                for (String element : BIN_BOUNDARY_SPLITTER.split(cols[1])) {
                    boundaries.add(Double.valueOf(element));
                }
                binningInfo.setBinBoundaries(boundaries);
                binSize = boundaries.size();
            } else {
                binningInfo.setNumeric(false);
                List<String> categories = new ArrayList<String>();
                Map<String, Integer> categoryIndexMap = this.categoricalBinMap.get(rawColumnNum);
                if (categoryIndexMap == null) {
                    categoryIndexMap = new HashMap<String, Integer>();
                    this.categoricalBinMap.put(rawColumnNum, categoryIndexMap);
                }
                int index = 0;
                if (!StringUtils.isBlank(cols[1])) {
                    for (String element : BIN_BOUNDARY_SPLITTER.split(cols[1])) {
                        categories.add(element);
                        categoryIndexMap.put(element, index++);
                    }
                }
                binningInfo.setBinCategories(categories);
                binSize = categories.size();
            }
            // one extra bin slot for missing/invalid values
            binningInfo.setBinCountPos(new long[binSize + 1]);
            binningInfo.setBinCountNeg(new long[binSize + 1]);
            binningInfo.setBinWeightPos(new double[binSize + 1]);
            binningInfo.setBinWeightNeg(new double[binSize + 1]);
            LOG.debug("column num {}  and info {}", rawColumnNum, binningInfo);
            this.columnBinningInfo.put(rawColumnNum, binningInfo);
        }
    } finally {
        reader.close();
    }
}
Also used : InputStreamReader(java.io.InputStreamReader) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) FileInputStream(java.io.FileInputStream) BufferedReader(java.io.BufferedReader) ArrayList(java.util.ArrayList) List(java.util.List) HashMap(java.util.HashMap) Map(java.util.Map)

Example 35 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

In the class TrainModelProcessor, the method runDistributedTrain:

protected int runDistributedTrain() throws IOException, InterruptedException, ClassNotFoundException {
    LOG.info("Started {}distributed training.", isDryTrain ? "dry " : "");
    int status = 0;
    Configuration conf = new Configuration();
    SourceType sourceType = super.getModelConfig().getDataSet().getSource();
    final List<String> args = new ArrayList<String>();
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    prepareCommonParams(gs.hasHyperParam(), args, sourceType);
    String alg = super.getModelConfig().getTrain().getAlgorithm();
    // add tmp models folder to config
    FileSystem fileSystem = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    Path tmpModelsPath = fileSystem.makeQualified(new Path(super.getPathFinder().getPathBySourceType(new Path(Constants.TMP, Constants.DEFAULT_MODELS_TMP_FOLDER), sourceType)));
    args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TMP_MODELS_FOLDER, tmpModelsPath.toString()));
    int baggingNum = isForVarSelect ? 1 : super.getModelConfig().getBaggingNum();
    if (modelConfig.isClassification()) {
        int classes = modelConfig.getTags().size();
        if (classes == 2) {
            // binary classification, only need one job
            baggingNum = 1;
        } else {
            if (modelConfig.getTrain().isOneVsAll()) {
                // one vs all multiple classification, we need multiple bagging jobs to do ONEVSALL
                baggingNum = modelConfig.getTags().size();
            } else {
            // native classification, using bagging from setting job, no need set here
            }
        }
        if (baggingNum != super.getModelConfig().getBaggingNum()) {
            LOG.warn("'train:baggingNum' is set to {} because of ONEVSALL multiple classification.", baggingNum);
        }
    }
    boolean isKFoldCV = false;
    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        baggingNum = modelConfig.getTrain().getNumKFold();
        if (baggingNum != super.getModelConfig().getBaggingNum() && gs.hasHyperParam()) {
            // if it is grid search mode, then kfold mode is disabled
            LOG.warn("'train:baggingNum' is set to {} because of k-fold cross validation is enabled by 'numKFold' not -1.", baggingNum);
        }
    }
    long start = System.currentTimeMillis();
    boolean isParallel = Boolean.valueOf(Environment.getProperty(Constants.SHIFU_DTRAIN_PARALLEL, SHIFU_DEFAULT_DTRAIN_PARALLEL)).booleanValue();
    GuaguaMapReduceClient guaguaClient;
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
    int inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
    int candidateCount = inputOutputIndex[2];
    boolean isAfterVarSelect = (inputOutputIndex[0] != 0);
    // cache all feature list for sampling features
    List<Integer> allFeatures = NormalUtils.getAllFeatureList(this.columnConfigList, isAfterVarSelect);
    if (modelConfig.getNormalize().getIsParquet()) {
        guaguaClient = new GuaguaParquetMapReduceClient();
        // set required field list to make sure we only load selected columns.
        RequiredFieldList requiredFieldList = new RequiredFieldList();
        boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
        for (ColumnConfig columnConfig : super.columnConfigList) {
            if (columnConfig.isTarget()) {
                requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
            } else {
                if (inputNodeCount == candidateCount) {
                    // no any variables are selected
                    if (!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
                        requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
                    }
                } else {
                    if (!columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
                        requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
                    }
                }
            }
        }
        // weight is added manually
        requiredFieldList.add(new RequiredField("weight", columnConfigList.size(), null, DataType.DOUBLE));
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.required.fields", serializeRequiredFieldList(requiredFieldList)));
        args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.column.index.access", "true"));
    } else {
        guaguaClient = new GuaguaMapReduceClient();
    }
    int parallelNum = Integer.parseInt(Environment.getProperty(CommonConstants.SHIFU_TRAIN_BAGGING_INPARALLEL, "5"));
    int parallelGroups = 1;
    if (gs.hasHyperParam()) {
        parallelGroups = (gs.getFlattenParams().size() % parallelNum == 0 ? gs.getFlattenParams().size() / parallelNum : gs.getFlattenParams().size() / parallelNum + 1);
        baggingNum = gs.getFlattenParams().size();
        LOG.warn("'train:baggingNum' is set to {} because of grid search enabled by settings in 'train#params'.", gs.getFlattenParams().size());
    } else {
        parallelGroups = baggingNum % parallelNum == 0 ? baggingNum / parallelNum : baggingNum / parallelNum + 1;
    }
    LOG.info("Distributed trainning with baggingNum: {}", baggingNum);
    List<String> progressLogList = new ArrayList<String>(baggingNum);
    boolean isOneJobNotContinuous = false;
    for (int j = 0; j < parallelGroups; j++) {
        int currBags = baggingNum;
        if (gs.hasHyperParam()) {
            if (j == parallelGroups - 1) {
                currBags = gs.getFlattenParams().size() % parallelNum == 0 ? parallelNum : gs.getFlattenParams().size() % parallelNum;
            } else {
                currBags = parallelNum;
            }
        } else {
            if (j == parallelGroups - 1) {
                currBags = baggingNum % parallelNum == 0 ? parallelNum : baggingNum % parallelNum;
            } else {
                currBags = parallelNum;
            }
        }
        for (int k = 0; k < currBags; k++) {
            int i = j * parallelNum + k;
            if (gs.hasHyperParam()) {
                LOG.info("Start the {}th grid search job with params: {}", i, gs.getParams(i));
            } else if (isKFoldCV) {
                LOG.info("Start the {}th k-fold cross validation job with params.", i);
            }
            List<String> localArgs = new ArrayList<String>(args);
            // set name for each bagging job.
            localArgs.add("-n");
            localArgs.add(String.format("Shifu Master-Workers %s Training Iteration: %s id:%s", alg, super.getModelConfig().getModelSetName(), i));
            LOG.info("Start trainer with id: {}", i);
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            Path bModelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getNNBinaryModelsPath(sourceType), modelName));
            // check if job is continuous training, this can be set multiple times and we only get last one
            boolean isContinuous = false;
            if (gs.hasHyperParam()) {
                isContinuous = false;
            } else {
                int intContinuous = checkContinuousTraining(fileSystem, localArgs, modelPath, modelConfig.getTrain().getParams());
                if (intContinuous == -1) {
                    LOG.warn("Model with index {} with size of trees is over treeNum, such training will not be started.", i);
                    continue;
                } else {
                    isContinuous = (intContinuous == 1);
                }
            }
            // training
            if (gs.hasHyperParam() || isKFoldCV) {
                isContinuous = false;
            }
            if (!isContinuous && !isOneJobNotContinuous) {
                isOneJobNotContinuous = true;
                // delete all old models if not continuous
                String srcModelPath = super.getPathFinder().getModelsPath(sourceType);
                String mvModelPath = srcModelPath + "_" + System.currentTimeMillis();
                LOG.info("Old model path has been moved to {}", mvModelPath);
                fileSystem.rename(new Path(srcModelPath), new Path(mvModelPath));
                fileSystem.mkdirs(new Path(srcModelPath));
                FileSystem.getLocal(conf).delete(new Path(super.getPathFinder().getModelsPath(SourceType.LOCAL)), true);
            }
            if (NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) {
                // tree related parameters initialization
                Map<String, Object> params = gs.hasHyperParam() ? gs.getParams(i) : this.modelConfig.getTrain().getParams();
                Object fssObj = params.get("FeatureSubsetStrategy");
                FeatureSubsetStrategy featureSubsetStrategy = null;
                double featureSubsetRate = 0d;
                if (fssObj != null) {
                    try {
                        featureSubsetRate = Double.parseDouble(fssObj.toString());
                        // no need validate featureSubsetRate is in (0,1], as already validated in ModelInspector
                        featureSubsetStrategy = null;
                    } catch (NumberFormatException ee) {
                        featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
                    }
                } else {
                    LOG.warn("FeatureSubsetStrategy is not set, set to ALL by default.");
                    featureSubsetStrategy = FeatureSubsetStrategy.ALL;
                    featureSubsetRate = 0;
                }
                Set<Integer> subFeatures = null;
                if (isContinuous) {
                    BasicFloatNetwork existingModel = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource())));
                    if (existingModel == null) {
                        subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
                    } else {
                        subFeatures = existingModel.getFeatureSet();
                    }
                } else {
                    subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
                }
                if (subFeatures == null || subFeatures.size() == 0) {
                    localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, ""));
                } else {
                    localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, StringUtils.join(subFeatures, ',')));
                    LOG.debug("Size: {}, list: {}.", subFeatures.size(), StringUtils.join(subFeatures, ','));
                }
            }
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GUAGUA_OUTPUT, modelPath.toString()));
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, Constants.SHIFU_NN_BINARY_MODEL_PATH, bModelPath.toString()));
            if (gs.hasHyperParam() || isKFoldCV) {
                // k-fold cv need val error
                Path valErrPath = fileSystem.makeQualified(new Path(super.getPathFinder().getValErrorPath(sourceType), "val_error_" + i));
                localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GS_VALIDATION_ERROR, valErrPath.toString()));
            }
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TRAINER_ID, String.valueOf(i)));
            final String progressLogFile = getProgressLogFile(i);
            progressLogList.add(progressLogFile);
            localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE, progressLogFile));
            String hdpVersion = HDPUtils.getHdpVersionForHDP224();
            if (StringUtils.isNotBlank(hdpVersion)) {
                localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "hdp.version", hdpVersion));
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf);
                HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf);
            }
            if (isParallel) {
                guaguaClient.addJob(localArgs.toArray(new String[0]));
            } else {
                TailThread tailThread = startTailThread(new String[] { progressLogFile });
                boolean ret = guaguaClient.createJob(localArgs.toArray(new String[0])).waitForCompletion(true);
                status += (ret ? 0 : 1);
                stopTailThread(tailThread);
            }
        }
        if (isParallel) {
            TailThread tailThread = startTailThread(progressLogList.toArray(new String[0]));
            status += guaguaClient.run();
            stopTailThread(tailThread);
        }
    }
    if (isKFoldCV) {
        // k-fold we also copy model files at last, such models can be used for evaluation
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
                status += 1;
            }
        }
        List<Double> valErrs = readAllValidationErrors(sourceType, fileSystem, kCrossValidation);
        double sum = 0d;
        for (Double err : valErrs) {
            sum += err;
        }
        LOG.info("Average validation error for current k-fold cross validation is {}.", sum / valErrs.size());
        LOG.info("K-fold cross validation on distributed training finished in {}ms.", System.currentTimeMillis() - start);
    } else if (gs.hasHyperParam()) {
        // select the best parameter composite in grid search
        LOG.info("Original grid search params: {}", modelConfig.getParams());
        Map<String, Object> params = findBestParams(sourceType, fileSystem, gs);
        // temp copy all models for evaluation
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
            }
        }
        LOG.info("The best parameters in grid search is {}", params);
        LOG.info("Grid search on distributed training finished in {}ms.", System.currentTimeMillis() - start);
    } else {
        // copy model files at last.
        for (int i = 0; i < baggingNum; i++) {
            String modelName = getModelName(i);
            Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
            if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
                copyModelToLocal(modelName, modelPath, sourceType);
            } else {
                LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
            }
        }
        // copy temp model files, for RF/GBT, not to copy tmp models because of larger space needed, for others
        // by default copy tmp models to local
        boolean copyTmpModelsToLocal = Boolean.TRUE.toString().equalsIgnoreCase(Environment.getProperty(Constants.SHIFU_TMPMODEL_COPYTOLOCAL, "true"));
        if (copyTmpModelsToLocal) {
            copyTmpModelsToLocal(tmpModelsPath, sourceType);
        } else {
            LOG.info("Tmp models are not copied into local, please find them in hdfs path: {}", tmpModelsPath);
        }
        LOG.info("Distributed training finished in {}ms.", System.currentTimeMillis() - start);
    }
    if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(this.modelConfig, null);
        // compute feature importance and write to local file after models are trained
        Map<Integer, MutablePair<String, Double>> featureImportances = CommonUtils.computeTreeModelFeatureImportance(models);
        String localFsFolder = pathFinder.getLocalFeatureImportanceFolder();
        String localFIPath = pathFinder.getLocalFeatureImportancePath();
        processRollupForFIFiles(localFsFolder, localFIPath);
        CommonUtils.writeFeatureImportance(localFIPath, featureImportances);
    }
    if (status != 0) {
        LOG.error("Error may occurred. There is no model generated. Please check!");
    }
    return status;
}
Also used : Configuration(org.apache.hadoop.conf.Configuration) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) FeatureSubsetStrategy(ml.shifu.shifu.core.dtrain.FeatureSubsetStrategy) BasicML(org.encog.ml.BasicML) GuaguaMapReduceClient(ml.shifu.guagua.mapreduce.GuaguaMapReduceClient) MutablePair(org.apache.commons.lang3.tuple.MutablePair) RequiredField(org.apache.pig.LoadPushDown.RequiredField) FileSystem(org.apache.hadoop.fs.FileSystem) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) Path(org.apache.hadoop.fs.Path) GuaguaParquetMapReduceClient(ml.shifu.shifu.guagua.GuaguaParquetMapReduceClient) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch) RequiredFieldList(org.apache.pig.LoadPushDown.RequiredFieldList)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)131 ArrayList (java.util.ArrayList)36 Test (org.testng.annotations.Test)17 IOException (java.io.IOException)16 HashMap (java.util.HashMap)12 Tuple (org.apache.pig.data.Tuple)10 File (java.io.File)8 NSColumn (ml.shifu.shifu.column.NSColumn)8 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)8 ShifuException (ml.shifu.shifu.exception.ShifuException)8 Path (org.apache.hadoop.fs.Path)8 List (java.util.List)7 Scanner (java.util.Scanner)7 DataBag (org.apache.pig.data.DataBag)7 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)5 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)5 TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet)5 BasicMLData (org.encog.ml.data.basic.BasicMLData)5 BufferedWriter (java.io.BufferedWriter)3 FileInputStream (java.io.FileInputStream)3