Usage of ml.shifu.shifu.container.obj.ColumnConfig in the Shifu project (ShifuML): the sortByParetoCC method of the VariableSelector class.
/**
 * Builds (columnNum, KS, -IV) tuples from the given column configs and hands them to the
 * epsilon-dominance Pareto sort. IV is negated so both objectives are optimized in the same
 * direction; null KS/IV values fall back to 0.
 */
public List<Tuple> sortByParetoCC(List<ColumnConfig> list) {
    // Lazily initialize the epsilon thresholds used by the Pareto dominance comparison.
    if (this.epsilonArray == null) {
        this.epsilonArray = new double[] { 0.01d, 0.05d };
    }
    List<Tuple> candidates = new ArrayList<VariableSelector.Tuple>();
    for (ColumnConfig config : list) {
        // Skip columns without computed stats; they have no KS/IV to rank on.
        if (config == null || config.getColumnStats() == null) {
            continue;
        }
        Double ks = config.getKs();
        Double iv = config.getIv();
        double ksValue = (ks == null) ? 0d : ks.doubleValue();
        double negatedIv = (iv == null) ? 0d : 0d - iv.doubleValue();
        candidates.add(new Tuple(config.getColumnNum(), ksValue, negatedIv));
    }
    return sortByPareto(candidates);
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the Shifu project (ShifuML): the computeConfusionMatixForMultipleClassification method of the ConfusionMatrix class.
/**
 * Computes and logs the confusion matrix for a multiple-classification evaluation run.
 *
 * Reads every eval score output file, determines the predicted class of each record according
 * to the training algorithm (RF native voting, one-vs-all thresholding, or averaged multi-model
 * scores), maps the actual tag to its class index, and accumulates counts into a
 * classes-by-classes matrix (row = actual, column = predicted) which is written to the
 * confusion matrix file and logged.
 *
 * @param records total record count hint; not referenced in the method body
 * @throws IOException if eval score files cannot be read or the matrix file cannot be written
 */
public void computeConfusionMatixForMultipleClassification(long records) throws IOException {
SourceType sourceType = evalConfig.getDataSet().getSource();
List<Scanner> scanners = ShifuFileUtils.getDataScanners(pathFinder.getEvalScorePath(evalConfig, sourceType), sourceType);
// a directory output comes from the distributed run; a single local file carries a header line
boolean isDir = ShifuFileUtils.isDir(pathFinder.getEvalScorePath(evalConfig, sourceType), sourceType);
// flat set of all valid tags, for fast validity checking of each record's target value
Set<String> tagSet = new HashSet<String>(modelConfig.getFlattenTags(modelConfig.getPosTags(evalConfig), modelConfig.getNegTags(evalConfig)));
// per-class tag sets; the index in this list is the class index used in the matrix
List<Set<String>> tags = modelConfig.getSetTags(modelConfig.getPosTags(evalConfig), modelConfig.getNegTags(evalConfig));
int classes = tags.size();
long cnt = 0, invalidTargetCnt = 0;
ColumnConfig targetColumn = CommonUtils.findTargetColumn(columnConfigList);
// per-class counts from training-time binning stats of the target column, used as class priors
List<Integer> binCountNeg = targetColumn.getBinCountNeg();
List<Integer> binCountPos = targetColumn.getBinCountPos();
long[] binCount = new long[classes];
double[] binRatio = new double[classes];
long sumCnt = 0L;
for (int i = 0; i < binCount.length; i++) {
binCount[i] = binCountNeg.get(i) + binCountPos.get(i);
sumCnt += binCount[i];
}
// binRatio[i] is the prior proportion of class i over all training records
for (int i = 0; i < binCount.length; i++) {
binRatio[i] = (binCount[i] * 1d) / sumCnt;
}
long[][] confusionMatrix = new long[classes][classes];
for (Scanner scanner : scanners) {
while (scanner.hasNext()) {
if ((++cnt) % 100000 == 0) {
LOG.info("Loaded " + cnt + " records.");
}
if (!isDir && cnt == 1) {
// if the evaluation score file is the local file, skip the first line since we add header in
continue;
}
// score is separated by default delimiter in our pig output format
String[] raw = scanner.nextLine().split(Constants.DEFAULT_ESCAPE_DELIMITER);
String tag = raw[targetColumnIndex];
if (StringUtils.isBlank(tag) || !tagSet.contains(tag)) {
// skip records whose target is blank or not a known tag
invalidTargetCnt += 1;
continue;
}
double[] scores = new double[classes];
int predictIndex = -1;
double maxScore = Double.NEGATIVE_INFINITY;
if (CommonUtils.isTreeModel(modelConfig.getAlgorithm()) && !modelConfig.getTrain().isOneVsAll()) {
// for RF native classification
// each score column holds a predicted class index; tally one vote per tree
double[] tagCounts = new double[tags.size()];
for (int i = this.multiClassScore1Index; i < (raw.length - this.metaColumns); i++) {
double dd = NumberFormatUtils.getDouble(raw[i], 0d);
tagCounts[(int) dd] += 1d;
}
// majority vote wins; ties resolved in favor of the lowest class index
double maxVotes = -1d;
for (int i = 0; i < tagCounts.length; i++) {
if (tagCounts[i] > maxVotes) {
predictIndex = i;
maxScore = maxVotes = tagCounts[i];
}
}
} else if ((CommonUtils.isTreeModel(modelConfig.getAlgorithm()) || NNConstants.NN_ALG_NAME.equalsIgnoreCase(modelConfig.getAlgorithm())) && modelConfig.getTrain().isOneVsAll()) {
// for RF, GBT & NN OneVsAll classification
if (classes == 2) {
// for binary classification, only one model is needed.
for (int i = this.multiClassScore1Index; i < (1 + this.multiClassScore1Index); i++) {
double dd = NumberFormatUtils.getDouble(raw[i], 0d);
// threshold at the scaled negative-prior: above it predict class 0, else class 1
if (dd > ((1d - binRatio[i - this.multiClassScore1Index]) * scoreScale)) {
predictIndex = 0;
} else {
predictIndex = 1;
}
}
} else {
// logic is here, per each onevsrest, it may be im-banlanced. for example, class a, b, c, first
// is a(1) vs b and c(0), ratio is 10:1, then to compare score, if score > 1/11 it is positive,
// check other models to see if still positive in b or c, take the largest one with ratio for
// final prediction
int[] predClasses = new int[classes];
double[] scoress = new double[classes];
double[] threhs = new double[classes];
for (int i = this.multiClassScore1Index; i < (classes + this.multiClassScore1Index); i++) {
double dd = NumberFormatUtils.getDouble(raw[i], 0d);
scoress[i - this.multiClassScore1Index] = dd;
threhs[i - this.multiClassScore1Index] = (1d - binRatio[i - this.multiClassScore1Index]) * scoreScale;
if (dd > ((1d - binRatio[i - this.multiClassScore1Index]) * scoreScale)) {
predClasses[i - this.multiClassScore1Index] = 1;
}
}
double maxRatio = -1d;
double maxPositiveRatio = -1d;
int maxRatioIndex = -1;
for (int i = 0; i < binCount.length; i++) {
if (binRatio[i] > maxRatio) {
maxRatio = binRatio[i];
maxRatioIndex = i;
}
// if has positive, choose one with highest ratio
if (predClasses[i] == 1) {
if (binRatio[i] > maxPositiveRatio) {
maxPositiveRatio = binRatio[i];
predictIndex = i;
}
}
}
// no any positive, take the largest one
if (maxPositiveRatio < 0d) {
predictIndex = maxRatioIndex;
}
}
} else {
if (classes == 2) {
// for binary classification, only one model is needed.
for (int i = this.multiClassScore1Index; i < (1 + this.multiClassScore1Index); i++) {
double dd = NumberFormatUtils.getDouble(raw[i], 0d);
if (dd > ((1d - binRatio[i - this.multiClassScore1Index]) * scoreScale)) {
predictIndex = 0;
} else {
predictIndex = 1;
}
}
} else {
// native multi-class with bagged models: average per-class scores across all models
// 1,2,3 4,5,6: 1,2,3 is model 0, 4,5,6 is model 1
for (int i = 0; i < classes; i++) {
for (int j = 0; j < multiClassModelCnt; j++) {
double dd = NumberFormatUtils.getDouble(raw[this.multiClassScore1Index + j * classes + i], 0d);
scores[i] += dd;
}
scores[i] /= multiClassModelCnt;
if (scores[i] > maxScore) {
predictIndex = i;
maxScore = scores[i];
}
}
}
}
// map the actual tag back to its class index via the per-class tag sets
int tagIndex = -1;
for (int i = 0; i < tags.size(); i++) {
if (tags.get(i).contains(tag)) {
tagIndex = i;
break;
}
}
// row = actual class, column = predicted class
confusionMatrix[tagIndex][predictIndex] += 1L;
}
scanner.close();
}
LOG.info("Totally loading {} records with invalid target records {} in eval {}.", cnt, invalidTargetCnt, evalConfig.getName());
writeToConfMatrixFile(tags, confusionMatrix);
// print conf matrix
LOG.info("Multiple classification confustion matrix:");
LOG.info(String.format("%15s: %20s", " ", tags.toString()));
for (int i = 0; i < confusionMatrix.length; i++) {
LOG.info(String.format("%15s: %20s", tags.get(i), Arrays.toString(confusionMatrix[i])));
}
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the Shifu project (ShifuML): the populateStats method of the UpdateBinningInfoMapper class.
/**
 * Populates binning statistics for one column of one input record: bin counts and weighted
 * counts, numeric sum/squared/triple/quartic sums, min/max, and the missing/invalid counter.
 *
 * The column is dispatched by type: hybrid (numeric bins plus a categorical tail), categorical,
 * or numerical. Missing and invalid values are accumulated into the extra last bin that
 * loadColumnBinningInfo pre-allocated.
 *
 * Fix: removed a stray empty statement (';') left after the hybrid category bin adjustment.
 *
 * @param units       the split raw record fields
 * @param tag         the record's target tag, matched against posTags/negTags
 * @param weight      record weight added to the weighted bin counters
 * @param columnIndex index of the column in both units and columnConfigList
 * @param newCCIndex  key used for the per-column accumulators (may differ from columnIndex)
 */
private void populateStats(String[] units, String tag, Double weight, int columnIndex, int newCCIndex) {
ColumnConfig columnConfig = this.columnConfigList.get(columnIndex);
// lazily create the distinct-count / frequent-item tracker for this column
CountAndFrequentItems countAndFrequentItems = this.variableCountMap.get(newCCIndex);
if (countAndFrequentItems == null) {
countAndFrequentItems = new CountAndFrequentItems();
this.variableCountMap.put(newCCIndex, countAndFrequentItems);
}
countAndFrequentItems.offer(this.missingOrInvalidValues, units[columnIndex]);
boolean isMissingValue = false;
boolean isInvalidValue = false;
BinningInfoWritable binningInfoWritable = this.columnBinningInfo.get(newCCIndex);
if (binningInfoWritable == null) {
// no binning info loaded for this column (e.g. meta/target column): nothing to update
return;
}
binningInfoWritable.setTotalCount(binningInfoWritable.getTotalCount() + 1L);
if (columnConfig.isHybrid()) {
int binNum = 0;
if (units[columnIndex] == null || missingOrInvalidValues.contains(units[columnIndex].toLowerCase())) {
isMissingValue = true;
}
String str = units[columnIndex];
double douVal = BinUtils.parseNumber(str);
Double hybridThreshold = columnConfig.getHybridThreshold();
if (hybridThreshold == null) {
hybridThreshold = Double.NEGATIVE_INFINITY;
}
// douVal < hybridThreshould which will also be set to category
boolean isCategory = Double.isNaN(douVal) || douVal < hybridThreshold;
boolean isNumber = !Double.isNaN(douVal);
if (isMissingValue) {
binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
// missing values go into the extra last bin (after all numeric and category bins)
binNum = binningInfoWritable.getBinCategories().size() + binningInfoWritable.getBinBoundaries().size();
} else if (isCategory) {
// get categorical bin number in category list
binNum = quickLocateCategoricalBin(this.categoricalBinMap.get(newCCIndex), str);
if (binNum < 0) {
isInvalidValue = true;
}
if (isInvalidValue) {
// the same as missing count
binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
binNum = binningInfoWritable.getBinCategories().size() + binningInfoWritable.getBinBoundaries().size();
} else {
// if real category value, binNum should + binBoundaries.size
binNum += binningInfoWritable.getBinBoundaries().size();
}
} else if (isNumber) {
binNum = getBinNum(binningInfoWritable.getBinBoundaries(), douVal);
if (binNum == -1) {
throw new RuntimeException("binNum should not be -1 to this step.");
}
// other stats are treated as numerical features
binningInfoWritable.setSum(binningInfoWritable.getSum() + douVal);
double squaredVal = douVal * douVal;
binningInfoWritable.setSquaredSum(binningInfoWritable.getSquaredSum() + squaredVal);
binningInfoWritable.setTripleSum(binningInfoWritable.getTripleSum() + squaredVal * douVal);
binningInfoWritable.setQuarticSum(binningInfoWritable.getQuarticSum() + squaredVal * squaredVal);
if (Double.compare(binningInfoWritable.getMax(), douVal) < 0) {
binningInfoWritable.setMax(douVal);
}
if (Double.compare(binningInfoWritable.getMin(), douVal) > 0) {
binningInfoWritable.setMin(douVal);
}
}
// NOTE(review): unlike the categorical branch below, the pos/neg counting here is not guarded
// by modelConfig.isRegression() - confirm hybrid columns are only used in regression mode
if (posTags.contains(tag)) {
binningInfoWritable.getBinCountPos()[binNum] += 1L;
binningInfoWritable.getBinWeightPos()[binNum] += weight;
} else if (negTags.contains(tag)) {
binningInfoWritable.getBinCountNeg()[binNum] += 1L;
binningInfoWritable.getBinWeightNeg()[binNum] += weight;
}
} else if (columnConfig.isCategorical()) {
// the extra last bin is reserved for missing/invalid categories
int lastBinIndex = binningInfoWritable.getBinCategories().size();
int binNum = 0;
if (units[columnIndex] == null || missingOrInvalidValues.contains(units[columnIndex].toLowerCase())) {
isMissingValue = true;
} else {
String str = units[columnIndex];
binNum = quickLocateCategoricalBin(this.categoricalBinMap.get(newCCIndex), str);
if (binNum < 0) {
isInvalidValue = true;
}
}
if (isInvalidValue || isMissingValue) {
binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
binNum = lastBinIndex;
}
if (modelConfig.isRegression()) {
if (posTags.contains(tag)) {
binningInfoWritable.getBinCountPos()[binNum] += 1L;
binningInfoWritable.getBinWeightPos()[binNum] += weight;
} else if (negTags.contains(tag)) {
binningInfoWritable.getBinCountNeg()[binNum] += 1L;
binningInfoWritable.getBinWeightNeg()[binNum] += weight;
}
} else {
// for multiple classification, set bin count to BinCountPos and leave BinCountNeg empty
binningInfoWritable.getBinCountPos()[binNum] += 1L;
binningInfoWritable.getBinWeightPos()[binNum] += weight;
}
} else if (columnConfig.isNumerical()) {
// the extra last bin is reserved for missing/invalid numbers
int lastBinIndex = binningInfoWritable.getBinBoundaries().size();
double douVal = 0.0;
if (units[columnIndex] == null || units[columnIndex].length() == 0 || missingOrInvalidValues.contains(units[columnIndex].toLowerCase())) {
isMissingValue = true;
} else {
try {
douVal = Double.parseDouble(units[columnIndex].trim());
} catch (Exception e) {
isInvalidValue = true;
}
}
// add logic the same as CalculateNewStatsUDF
if (Double.compare(douVal, modelConfig.getNumericalValueThreshold()) > 0) {
isInvalidValue = true;
}
if (isInvalidValue || isMissingValue) {
binningInfoWritable.setMissingCount(binningInfoWritable.getMissingCount() + 1L);
if (modelConfig.isRegression()) {
if (posTags.contains(tag)) {
binningInfoWritable.getBinCountPos()[lastBinIndex] += 1L;
binningInfoWritable.getBinWeightPos()[lastBinIndex] += weight;
} else if (negTags.contains(tag)) {
binningInfoWritable.getBinCountNeg()[lastBinIndex] += 1L;
binningInfoWritable.getBinWeightNeg()[lastBinIndex] += weight;
}
}
} else {
// For invalid or missing values, no need update sum, squaredSum, max, min ...
int binNum = getBinNum(binningInfoWritable.getBinBoundaries(), units[columnIndex]);
if (binNum == -1) {
throw new RuntimeException("binNum should not be -1 to this step.");
}
if (modelConfig.isRegression()) {
if (posTags.contains(tag)) {
binningInfoWritable.getBinCountPos()[binNum] += 1L;
binningInfoWritable.getBinWeightPos()[binNum] += weight;
} else if (negTags.contains(tag)) {
binningInfoWritable.getBinCountNeg()[binNum] += 1L;
binningInfoWritable.getBinWeightNeg()[binNum] += weight;
}
}
binningInfoWritable.setSum(binningInfoWritable.getSum() + douVal);
double squaredVal = douVal * douVal;
binningInfoWritable.setSquaredSum(binningInfoWritable.getSquaredSum() + squaredVal);
binningInfoWritable.setTripleSum(binningInfoWritable.getTripleSum() + squaredVal * douVal);
binningInfoWritable.setQuarticSum(binningInfoWritable.getQuarticSum() + squaredVal * squaredVal);
if (Double.compare(binningInfoWritable.getMax(), douVal) < 0) {
binningInfoWritable.setMax(douVal);
}
if (Double.compare(binningInfoWritable.getMin(), douVal) > 0) {
binningInfoWritable.setMin(douVal);
}
}
}
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the Shifu project (ShifuML): the loadColumnBinningInfo method of the UpdateBinningInfoMapper class.
/**
 * Load and initialize column binning info object.
 *
 * Reads the local binning info file (one line per column: column number plus a bin
 * specification), builds a BinningInfoWritable per column, and pre-sizes its count/weight
 * arrays with one extra slot for the missing/invalid bin. Also fills categoricalBinMap with
 * category-to-index lookups for categorical and hybrid columns.
 *
 * @throws FileNotFoundException if the binning info file does not exist
 * @throws IOException if reading the file fails
 */
private void loadColumnBinningInfo() throws FileNotFoundException, IOException {
BufferedReader reader = null;
try {
reader = new BufferedReader(new InputStreamReader(new FileInputStream(Constants.BINNING_INFO_FILE_NAME), Charset.forName("UTF-8")));
String line = reader.readLine();
// an empty line terminates the loop, as does end of file
while (line != null && line.length() != 0) {
LOG.debug("line is {}", line);
// split into two fields: raw column number and the bin specification
String[] cols = Lists.newArrayList(this.splitter.split(line)).toArray(new String[0]);
if (cols != null && cols.length >= 2) {
Integer rawColumnNum = Integer.parseInt(cols[0]);
BinningInfoWritable binningInfo = new BinningInfoWritable();
// column numbers beyond the config list size wrap around to their base column config
int corrColumnNum = rawColumnNum;
if (rawColumnNum >= this.columnConfigList.size()) {
corrColumnNum = rawColumnNum % this.columnConfigList.size();
}
binningInfo.setColumnNum(rawColumnNum);
ColumnConfig columnConfig = this.columnConfigList.get(corrColumnNum);
int binSize = 0;
if (columnConfig.isHybrid()) {
// hybrid: spec is numeric boundaries and category list joined by HYBRID_BIN_STR_DILIMETER
binningInfo.setNumeric(true);
String[] splits = CommonUtils.split(cols[1], Constants.HYBRID_BIN_STR_DILIMETER);
List<Double> list = new ArrayList<Double>();
for (String startElement : BIN_BOUNDARY_SPLITTER.split(splits[0])) {
list.add(Double.valueOf(startElement));
}
binningInfo.setBinBoundaries(list);
List<String> cateList = new ArrayList<String>();
Map<String, Integer> map = this.categoricalBinMap.get(rawColumnNum);
if (map == null) {
map = new HashMap<String, Integer>();
this.categoricalBinMap.put(rawColumnNum, map);
}
int index = 0;
if (!StringUtils.isBlank(splits[1])) {
for (String startElement : BIN_BOUNDARY_SPLITTER.split(splits[1])) {
cateList.add(startElement);
map.put(startElement, index++);
}
}
binningInfo.setBinCategories(cateList);
binSize = list.size() + cateList.size();
} else if (columnConfig.isNumerical()) {
// numerical: spec is a delimiter-separated list of bin boundaries
binningInfo.setNumeric(true);
List<Double> list = new ArrayList<Double>();
for (String startElement : BIN_BOUNDARY_SPLITTER.split(cols[1])) {
list.add(Double.valueOf(startElement));
}
binningInfo.setBinBoundaries(list);
binSize = list.size();
} else {
// categorical: spec is a delimiter-separated list of category values
binningInfo.setNumeric(false);
List<String> list = new ArrayList<String>();
Map<String, Integer> map = this.categoricalBinMap.get(rawColumnNum);
if (map == null) {
map = new HashMap<String, Integer>();
this.categoricalBinMap.put(rawColumnNum, map);
}
int index = 0;
if (!StringUtils.isBlank(cols[1])) {
for (String startElement : BIN_BOUNDARY_SPLITTER.split(cols[1])) {
list.add(startElement);
map.put(startElement, index++);
}
}
binningInfo.setBinCategories(list);
binSize = list.size();
}
// one extra slot at the end of each array for missing/invalid values
long[] binCountPos = new long[binSize + 1];
binningInfo.setBinCountPos(binCountPos);
long[] binCountNeg = new long[binSize + 1];
binningInfo.setBinCountNeg(binCountNeg);
double[] binWeightPos = new double[binSize + 1];
binningInfo.setBinWeightPos(binWeightPos);
double[] binWeightNeg = new double[binSize + 1];
binningInfo.setBinWeightNeg(binWeightNeg);
LOG.debug("column num {} and info {}", rawColumnNum, binningInfo);
this.columnBinningInfo.put(rawColumnNum, binningInfo);
}
line = reader.readLine();
}
} finally {
if (reader != null) {
reader.close();
}
}
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in the Shifu project (ShifuML): the runDistributedTrain method of the TrainModelProcessor class.
/**
 * Runs distributed training via Guagua MapReduce jobs.
 *
 * Handles bagging, k-fold cross validation and grid search modes: determines the number of
 * bagging jobs, prepares per-job arguments (model paths, feature subsets for NN, progress
 * files), submits jobs in parallel groups, and finally copies trained models to local and, for
 * tree models, computes feature importance.
 *
 * @return accumulated non-zero status if any job failed, 0 on full success
 * @throws IOException on HDFS/local file errors
 * @throws InterruptedException if job waiting is interrupted
 * @throws ClassNotFoundException from the underlying MapReduce job submission
 */
protected int runDistributedTrain() throws IOException, InterruptedException, ClassNotFoundException {
LOG.info("Started {}distributed training.", isDryTrain ? "dry " : "");
int status = 0;
Configuration conf = new Configuration();
SourceType sourceType = super.getModelConfig().getDataSet().getSource();
final List<String> args = new ArrayList<String>();
GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
prepareCommonParams(gs.hasHyperParam(), args, sourceType);
String alg = super.getModelConfig().getTrain().getAlgorithm();
// add tmp models folder to config
FileSystem fileSystem = ShifuFileUtils.getFileSystemBySourceType(sourceType);
Path tmpModelsPath = fileSystem.makeQualified(new Path(super.getPathFinder().getPathBySourceType(new Path(Constants.TMP, Constants.DEFAULT_MODELS_TMP_FOLDER), sourceType)));
args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TMP_MODELS_FOLDER, tmpModelsPath.toString()));
int baggingNum = isForVarSelect ? 1 : super.getModelConfig().getBaggingNum();
// classification overrides the bagging number depending on the strategy
if (modelConfig.isClassification()) {
int classes = modelConfig.getTags().size();
if (classes == 2) {
// binary classification, only need one job
baggingNum = 1;
} else {
if (modelConfig.getTrain().isOneVsAll()) {
// one vs all multiple classification, we need multiple bagging jobs to do ONEVSALL
baggingNum = modelConfig.getTags().size();
} else {
// native classification, using bagging from setting job, no need set here
}
}
if (baggingNum != super.getModelConfig().getBaggingNum()) {
LOG.warn("'train:baggingNum' is set to {} because of ONEVSALL multiple classification.", baggingNum);
}
}
// k-fold cross validation overrides bagging number with the fold count
boolean isKFoldCV = false;
Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
if (kCrossValidation != null && kCrossValidation > 0) {
isKFoldCV = true;
baggingNum = modelConfig.getTrain().getNumKFold();
if (baggingNum != super.getModelConfig().getBaggingNum() && gs.hasHyperParam()) {
// if it is grid search mode, then kfold mode is disabled
LOG.warn("'train:baggingNum' is set to {} because of k-fold cross validation is enabled by 'numKFold' not -1.", baggingNum);
}
}
long start = System.currentTimeMillis();
boolean isParallel = Boolean.valueOf(Environment.getProperty(Constants.SHIFU_DTRAIN_PARALLEL, SHIFU_DEFAULT_DTRAIN_PARALLEL)).booleanValue();
GuaguaMapReduceClient guaguaClient;
// inputOutputIndex: [0]=selected input count, [1]=output count, [2]=candidate count
int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
int inputNodeCount = inputOutputIndex[0] == 0 ? inputOutputIndex[2] : inputOutputIndex[0];
int candidateCount = inputOutputIndex[2];
boolean isAfterVarSelect = (inputOutputIndex[0] != 0);
// cache all feature list for sampling features
List<Integer> allFeatures = NormalUtils.getAllFeatureList(this.columnConfigList, isAfterVarSelect);
if (modelConfig.getNormalize().getIsParquet()) {
guaguaClient = new GuaguaParquetMapReduceClient();
// set required field list to make sure we only load selected columns.
RequiredFieldList requiredFieldList = new RequiredFieldList();
boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
for (ColumnConfig columnConfig : super.columnConfigList) {
if (columnConfig.isTarget()) {
requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
} else {
if (inputNodeCount == candidateCount) {
// no any variables are selected
if (!columnConfig.isMeta() && !columnConfig.isTarget() && CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
}
} else {
if (!columnConfig.isMeta() && !columnConfig.isTarget() && columnConfig.isFinalSelect()) {
requiredFieldList.add(new RequiredField(columnConfig.getColumnName(), columnConfig.getColumnNum(), null, DataType.FLOAT));
}
}
}
}
// weight is added manually
requiredFieldList.add(new RequiredField("weight", columnConfigList.size(), null, DataType.DOUBLE));
args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.required.fields", serializeRequiredFieldList(requiredFieldList)));
args.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "parquet.private.pig.column.index.access", "true"));
} else {
guaguaClient = new GuaguaMapReduceClient();
}
// jobs are submitted in groups of parallelNum to bound concurrent cluster usage
int parallelNum = Integer.parseInt(Environment.getProperty(CommonConstants.SHIFU_TRAIN_BAGGING_INPARALLEL, "5"));
int parallelGroups = 1;
if (gs.hasHyperParam()) {
parallelGroups = (gs.getFlattenParams().size() % parallelNum == 0 ? gs.getFlattenParams().size() / parallelNum : gs.getFlattenParams().size() / parallelNum + 1);
baggingNum = gs.getFlattenParams().size();
LOG.warn("'train:baggingNum' is set to {} because of grid search enabled by settings in 'train#params'.", gs.getFlattenParams().size());
} else {
parallelGroups = baggingNum % parallelNum == 0 ? baggingNum / parallelNum : baggingNum / parallelNum + 1;
}
LOG.info("Distributed trainning with baggingNum: {}", baggingNum);
List<String> progressLogList = new ArrayList<String>(baggingNum);
boolean isOneJobNotContinuous = false;
for (int j = 0; j < parallelGroups; j++) {
// the last group may hold fewer jobs than parallelNum
int currBags = baggingNum;
if (gs.hasHyperParam()) {
if (j == parallelGroups - 1) {
currBags = gs.getFlattenParams().size() % parallelNum == 0 ? parallelNum : gs.getFlattenParams().size() % parallelNum;
} else {
currBags = parallelNum;
}
} else {
if (j == parallelGroups - 1) {
currBags = baggingNum % parallelNum == 0 ? parallelNum : baggingNum % parallelNum;
} else {
currBags = parallelNum;
}
}
for (int k = 0; k < currBags; k++) {
// i is the global trainer id across all groups
int i = j * parallelNum + k;
if (gs.hasHyperParam()) {
LOG.info("Start the {}th grid search job with params: {}", i, gs.getParams(i));
} else if (isKFoldCV) {
LOG.info("Start the {}th k-fold cross validation job with params.", i);
}
List<String> localArgs = new ArrayList<String>(args);
// set name for each bagging job.
localArgs.add("-n");
localArgs.add(String.format("Shifu Master-Workers %s Training Iteration: %s id:%s", alg, super.getModelConfig().getModelSetName(), i));
LOG.info("Start trainer with id: {}", i);
String modelName = getModelName(i);
Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
Path bModelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getNNBinaryModelsPath(sourceType), modelName));
// check if job is continuous training, this can be set multiple times and we only get last one
boolean isContinuous = false;
if (gs.hasHyperParam()) {
isContinuous = false;
} else {
int intContinuous = checkContinuousTraining(fileSystem, localArgs, modelPath, modelConfig.getTrain().getParams());
if (intContinuous == -1) {
LOG.warn("Model with index {} with size of trees is over treeNum, such training will not be started.", i);
continue;
} else {
isContinuous = (intContinuous == 1);
}
}
// training
// grid search and k-fold always start from scratch
if (gs.hasHyperParam() || isKFoldCV) {
isContinuous = false;
}
// the first non-continuous job archives old models once, so reruns do not mix outputs
if (!isContinuous && !isOneJobNotContinuous) {
isOneJobNotContinuous = true;
// delete all old models if not continuous
String srcModelPath = super.getPathFinder().getModelsPath(sourceType);
String mvModelPath = srcModelPath + "_" + System.currentTimeMillis();
LOG.info("Old model path has been moved to {}", mvModelPath);
fileSystem.rename(new Path(srcModelPath), new Path(mvModelPath));
fileSystem.mkdirs(new Path(srcModelPath));
FileSystem.getLocal(conf).delete(new Path(super.getPathFinder().getModelsPath(SourceType.LOCAL)), true);
}
if (NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) {
// NN feature-subset initialization; FeatureSubsetStrategy may be a named strategy or a rate
Map<String, Object> params = gs.hasHyperParam() ? gs.getParams(i) : this.modelConfig.getTrain().getParams();
Object fssObj = params.get("FeatureSubsetStrategy");
FeatureSubsetStrategy featureSubsetStrategy = null;
double featureSubsetRate = 0d;
if (fssObj != null) {
try {
featureSubsetRate = Double.parseDouble(fssObj.toString());
// no need validate featureSubsetRate is in (0,1], as already validated in ModelInspector
featureSubsetStrategy = null;
} catch (NumberFormatException ee) {
featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
}
} else {
LOG.warn("FeatureSubsetStrategy is not set, set to ALL by default.");
featureSubsetStrategy = FeatureSubsetStrategy.ALL;
featureSubsetRate = 0;
}
Set<Integer> subFeatures = null;
if (isContinuous) {
// continuous training reuses the feature subset of the existing model when available
BasicFloatNetwork existingModel = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource())));
if (existingModel == null) {
subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
} else {
subFeatures = existingModel.getFeatureSet();
}
} else {
subFeatures = new HashSet<Integer>(getSubsamplingFeatures(allFeatures, featureSubsetStrategy, featureSubsetRate, inputNodeCount));
}
if (subFeatures == null || subFeatures.size() == 0) {
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, ""));
} else {
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_NN_FEATURE_SUBSET, StringUtils.join(subFeatures, ',')));
LOG.debug("Size: {}, list: {}.", subFeatures.size(), StringUtils.join(subFeatures, ','));
}
}
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GUAGUA_OUTPUT, modelPath.toString()));
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, Constants.SHIFU_NN_BINARY_MODEL_PATH, bModelPath.toString()));
if (gs.hasHyperParam() || isKFoldCV) {
// k-fold cv need val error
Path valErrPath = fileSystem.makeQualified(new Path(super.getPathFinder().getValErrorPath(sourceType), "val_error_" + i));
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.GS_VALIDATION_ERROR, valErrPath.toString()));
}
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_TRAINER_ID, String.valueOf(i)));
final String progressLogFile = getProgressLogFile(i);
progressLogList.add(progressLogFile);
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE, progressLogFile));
// HDP clusters need explicit hdp.version and site files on the classpath
String hdpVersion = HDPUtils.getHdpVersionForHDP224();
if (StringUtils.isNotBlank(hdpVersion)) {
localArgs.add(String.format(CommonConstants.MAPREDUCE_PARAM_FORMAT, "hdp.version", hdpVersion));
HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("hdfs-site.xml"), conf);
HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("core-site.xml"), conf);
HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("mapred-site.xml"), conf);
HDPUtils.addFileToClassPath(HDPUtils.findContainingFile("yarn-site.xml"), conf);
}
if (isParallel) {
// queue the job; the whole group runs together below
guaguaClient.addJob(localArgs.toArray(new String[0]));
} else {
// sequential mode: run the job now and tail its progress log
TailThread tailThread = startTailThread(new String[] { progressLogFile });
boolean ret = guaguaClient.createJob(localArgs.toArray(new String[0])).waitForCompletion(true);
status += (ret ? 0 : 1);
stopTailThread(tailThread);
}
}
if (isParallel) {
TailThread tailThread = startTailThread(progressLogList.toArray(new String[0]));
status += guaguaClient.run();
stopTailThread(tailThread);
}
}
if (isKFoldCV) {
// k-fold we also copy model files at last, such models can be used for evaluation
for (int i = 0; i < baggingNum; i++) {
String modelName = getModelName(i);
Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath)) {
copyModelToLocal(modelName, modelPath, sourceType);
} else {
LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
status += 1;
}
}
// average the per-fold validation errors for reporting
List<Double> valErrs = readAllValidationErrors(sourceType, fileSystem, kCrossValidation);
double sum = 0d;
for (Double err : valErrs) {
sum += err;
}
LOG.info("Average validation error for current k-fold cross validation is {}.", sum / valErrs.size());
LOG.info("K-fold cross validation on distributed training finished in {}ms.", System.currentTimeMillis() - start);
} else if (gs.hasHyperParam()) {
// select the best parameter composite in grid search
LOG.info("Original grid search params: {}", modelConfig.getParams());
Map<String, Object> params = findBestParams(sourceType, fileSystem, gs);
// temp copy all models for evaluation
for (int i = 0; i < baggingNum; i++) {
String modelName = getModelName(i);
Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
copyModelToLocal(modelName, modelPath, sourceType);
} else {
LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
}
}
LOG.info("The best parameters in grid search is {}", params);
LOG.info("Grid search on distributed training finished in {}ms.", System.currentTimeMillis() - start);
} else {
// copy model files at last.
for (int i = 0; i < baggingNum; i++) {
String modelName = getModelName(i);
Path modelPath = fileSystem.makeQualified(new Path(super.getPathFinder().getModelsPath(sourceType), modelName));
if (ShifuFileUtils.getFileSystemBySourceType(sourceType).exists(modelPath) && (status == 0)) {
copyModelToLocal(modelName, modelPath, sourceType);
} else {
LOG.warn("Model {} isn't there, maybe job is failed, for bagging it can be ignored.", modelPath.toString());
}
}
// copy temp model files, for RF/GBT, not to copy tmp models because of larger space needed, for others
// by default copy tmp models to local
boolean copyTmpModelsToLocal = Boolean.TRUE.toString().equalsIgnoreCase(Environment.getProperty(Constants.SHIFU_TMPMODEL_COPYTOLOCAL, "true"));
if (copyTmpModelsToLocal) {
copyTmpModelsToLocal(tmpModelsPath, sourceType);
} else {
LOG.info("Tmp models are not copied into local, please find them in hdfs path: {}", tmpModelsPath);
}
LOG.info("Distributed training finished in {}ms.", System.currentTimeMillis() - start);
}
if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(this.modelConfig, null);
// compute feature importance and write to local file after models are trained
Map<Integer, MutablePair<String, Double>> featureImportances = CommonUtils.computeTreeModelFeatureImportance(models);
String localFsFolder = pathFinder.getLocalFeatureImportanceFolder();
String localFIPath = pathFinder.getLocalFeatureImportancePath();
processRollupForFIFiles(localFsFolder, localFIPath);
CommonUtils.writeFeatureImportance(localFIPath, featureImportances);
}
if (status != 0) {
LOG.error("Error may occurred. There is no model generated. Please check!");
}
return status;
}
End of aggregated ColumnConfig usage examples.