Use of ml.shifu.shifu.core.dtrain.gs.GridSearch in project shifu by ShifuML:
the class NNOutput, method init.
/**
 * One-time master-output initialization: loads configuration, resolves grid-search
 * parameters for this trainer, reads dropout / weight-initializer settings, and
 * opens the progress log stream used to refresh the client console.
 */
private void init(MasterContext<NNParams, NNParams> context) {
    // Dry-run training produces no real output, so skip all setup.
    this.isDry = Boolean.TRUE.toString().equals(context.getProps().getProperty(CommonConstants.SHIFU_DRY_DTRAIN));
    if (this.isDry) {
        return;
    }

    // Guard so the heavyweight setup below runs exactly once.
    if (isInit.compareAndSet(false, true)) {
        loadConfigFiles(context.getProps());
        this.tmpModelsFolder = context.getProps().getProperty(CommonConstants.SHIFU_TMP_MODELS_FOLDER);
        this.trainerId = context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID);

        gridSearch = new GridSearch(modelConfig.getTrain().getParams(),
                modelConfig.getTrain().getGridConfigFileContent());
        validParams = this.modelConfig.getTrain().getParams();
        if (gridSearch.hasHyperParam()) {
            // In grid-search mode each trainer id maps to its own flattened parameter set.
            validParams = gridSearch.getParams(Integer.parseInt(trainerId));
            LOG.info("Start grid search in nn output with params: {}", validParams);
        }

        Integer kFold = this.modelConfig.getTrain().getNumKFold();
        if (kFold != null && kFold > 0) {
            isKFoldCV = true;
        }

        Object dropoutObj = validParams.get(CommonConstants.DROPOUT_RATE);
        if (dropoutObj != null) {
            this.dropoutRate = Double.valueOf(dropoutObj.toString());
        }
        LOG.info("'dropoutRate' in master output is :{}", this.dropoutRate);

        this.wgtInit = "default";
        Object initObj = validParams.get(CommonConstants.WEIGHT_INITIALIZER);
        if (initObj != null) {
            this.wgtInit = initObj.toString();
        }

        this.bModel = new Path(context.getProps().getProperty(Constants.SHIFU_NN_BINARY_MODEL_PATH));
        initNetwork(context);
    }

    try {
        Path progressLogFile = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE));
        // Append when the progress log already exists so the client console keeps
        // refreshing; otherwise create a fresh file.
        if (ShifuFileUtils.isFileExists(progressLogFile, SourceType.HDFS)) {
            this.progressOutput = FileSystem.get(new Configuration()).append(progressLogFile);
        } else {
            this.progressOutput = FileSystem.get(new Configuration()).create(progressLogFile);
        }
    } catch (IOException e) {
        LOG.error("Error in create progress log:", e);
    }
}
Use of ml.shifu.shifu.core.dtrain.gs.GridSearch in project shifu by ShifuML:
the class ModelSpecLoaderUtils, method locateBasicModels.
/**
 * Find model spec files.
 *
 * @param modelConfig
 *            model config
 * @param evalConfig
 *            eval configuration
 * @param sourceType
 *            {@link SourceType} LOCAL or HDFS?
 * @return The basic model file list, sliced to the expected model count of the last
 *         training; an empty list when neither basic nor generic models are found
 * @throws IOException
 *             Exception when fail to load basic models
 */
public static List<FileStatus> locateBasicModels(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType) throws IOException {
    // we have to register PersistBasicFloatNetwork for loading such models
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    List<FileStatus> listStatus = findModels(modelConfig, evalConfig, sourceType);
    if (CollectionUtils.isEmpty(listStatus)) {
        // throw new ShifuException(ShifuErrorCode.ERROR_MODEL_FILE_NOT_FOUND);
        // disable exception, since there maybe sub-models
        listStatus = findGenericModels(modelConfig, evalConfig, sourceType);
        // if models not found, continue which makes eval work when training is in progress.
        if (CollectionUtils.isNotEmpty(listStatus)) {
            return listStatus;
        }
        // FIX: both lookups came back empty (possibly null); return an empty list here
        // instead of falling through to Collections.sort, which would NPE on null.
        return Collections.emptyList();
    }
    // to avoid the *unix and windows file list order
    Collections.sort(listStatus, new Comparator<FileStatus>() {
        @Override
        public int compare(FileStatus f1, FileStatus f2) {
            return f1.getPath().getName().compareToIgnoreCase(f2.getPath().getName());
        }
    });
    // added in shifu 0.2.5 to slice models not belonging to last training
    int baggingModelSize = modelConfig.getTrain().getBaggingNum();
    if (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()) {
        // one-vs-all classification trains one model per tag
        baggingModelSize = modelConfig.getTags().size();
    }
    Integer kCrossValidation = modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        // if kfold is enabled, the model count equals the number of folds
        baggingModelSize = kCrossValidation;
    }
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    if (gs.hasHyperParam()) {
        // if it is grid search, set model size to all flatten params
        baggingModelSize = gs.getFlattenParams().size();
    }
    return listStatus.size() <= baggingModelSize ? listStatus : listStatus.subList(0, baggingModelSize);
}
Use of ml.shifu.shifu.core.dtrain.gs.GridSearch in project shifu by ShifuML:
the class DTOutput, method init.
/**
 * One-time tree-output initialization: loads configuration, resolves grid-search
 * parameters for this trainer, detects RF/GBT algorithm mode, computes the input
 * column count, and opens the progress log stream used to refresh the client console.
 */
private void init(MasterContext<DTMasterParams, DTWorkerParams> context) {
    // Guard so the heavyweight setup below runs exactly once.
    if (isInit.compareAndSet(false, true)) {
        this.conf = new Configuration();
        loadConfigFiles(context.getProps());
        this.trainerId = context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID);
        GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(),
                modelConfig.getTrain().getGridConfigFileContent());
        this.isGsMode = gs.hasHyperParam();
        this.validParams = modelConfig.getParams();
        if (isGsMode) {
            // in grid-search mode each trainer id maps to its own flattened parameter set
            this.validParams = gs.getParams(Integer.parseInt(trainerId));
        }
        Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
        if (kCrossValidation != null && kCrossValidation > 0) {
            isKFoldCV = true;
        }
        this.tmpModelsFolder = context.getProps().getProperty(CommonConstants.SHIFU_TMP_MODELS_FOLDER);
        this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
        this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
        int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
        // numerical + categorical = # of all input
        this.inputCount = inputOutputIndex[0] + inputOutputIndex[1];
        try {
            Path progressLog = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DTRAIN_PROGRESS_FILE));
            // we need to append the log, so that client console can get refreshed. Or console will appear stuck.
            if (ShifuFileUtils.isFileExists(progressLog, SourceType.HDFS)) {
                this.progressOutput = FileSystem.get(new Configuration()).append(progressLog);
            } else {
                this.progressOutput = FileSystem.get(new Configuration()).create(progressLog);
            }
        } catch (IOException e) {
            LOG.error("Error in create progress log:", e);
        }
        // FIX: removed stray empty statement (';') that followed this assignment;
        // parseInt avoids the needless Integer boxing of Integer.valueOf.
        this.treeNum = Integer.parseInt(validParams.get("TreeNum").toString());
    }
}
Use of ml.shifu.shifu.core.dtrain.gs.GridSearch in project shifu by ShifuML:
the class DTWorker, method init.
/**
 * Worker-side initialization for distributed decision-tree training (RF / GBT).
 * Loads model and column configurations, builds the categorical-value-to-bin-index
 * mapping, sizes the in-memory training/validation lists from the heap budget, and
 * resolves every hyper-parameter driven strategy object (impurity, loss, tree
 * number, learning rate). Also arranges fail-over and continuous-training recovery
 * of previously built trees.
 *
 * @param context
 *            guagua worker context carrying job properties and master results
 */
@Override
public void init(WorkerContext<DTMasterParams, DTWorkerParams> context) {
    Properties props = context.getProps();
    try {
        SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        // config files are mandatory; fail the worker fast if they cannot be read
        throw new RuntimeException(e);
    }
    // map each categorical column number to {category value -> bin index} for fast lookup
    this.columnCategoryIndexMapping = new HashMap<Integer, Map<String, Integer>>();
    for (ColumnConfig config : this.columnConfigList) {
        if (config.isCategorical()) {
            if (config.getBinCategory() != null) {
                Map<String, Integer> tmpMap = new HashMap<String, Integer>();
                for (int i = 0; i < config.getBinCategory().size(); i++) {
                    List<String> catVals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i));
                    for (String cval : catVals) {
                        // grouped category values all share the same bin index i
                        tmpMap.put(cval, i);
                    }
                }
                this.columnCategoryIndexMapping.put(config.getColumnNum(), tmpMap);
            }
        }
    }
    this.hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    // create Splitter
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
    }
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    if (Double.compare(upSampleWeight, 1d) != 0 && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
        // set mean to upSampleWeight -1 and get sample + 1 to make sure no zero sample value
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }
    this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
    this.workerThreadCount = modelConfig.getTrain().getWorkerThreadCount();
    this.threadPool = Executors.newFixedThreadPool(this.workerThreadCount);
    // enable shut down logic
    context.addCompletionCallBack(new WorkerCompletionCallBack<DTMasterParams, DTWorkerParams>() {
        @Override
        public void callback(WorkerContext<DTMasterParams, DTWorkerParams> context) {
            DTWorker.this.threadPool.shutdownNow();
            try {
                DTWorker.this.threadPool.awaitTermination(2, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                // restore the interrupt status so callers can observe it
                Thread.currentThread().interrupt();
            }
        }
    });
    this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    this.isOneVsAll = modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll();
    // in grid-search mode each trainer id maps to its own flattened hyper-parameter set
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        validParams = gs.getParams(this.trainerId);
        LOG.info("Start grid search worker with params: {}", validParams);
    }
    this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());
    double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
    LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);
    double validationRate = this.modelConfig.getValidSetRate();
    if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
        // fixed 0.6 and 0.4 of max memory for trainingData and validationData
        this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.6), new ArrayList<Data>());
        this.validationData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.4), new ArrayList<Data>());
    } else {
        if (Double.compare(validationRate, 0d) != 0) {
            // no explicit validation set: split the memory budget by the configured rate
            this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * (1 - validationRate)), new ArrayList<Data>());
            this.validationData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * validationRate), new ArrayList<Data>());
        } else {
            // no validation at all: the whole budget goes to training data
            this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction), new ArrayList<Data>());
        }
    }
    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    // numerical + categorical = # of all input
    this.inputCount = inputOutputIndex[0] + inputOutputIndex[1];
    // regression outputNodeCount is 1, binaryClassfication, it is 1, OneVsAll it is 1, Native classification it is
    // 1, with index of 0,1,2,3 denotes different classes
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    this.isManualValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig.getValidationDataSetRawPath()));
    int numClasses = this.modelConfig.isClassification() ? this.modelConfig.getTags().size() : 2;
    String imStr = validParams.get("Impurity").toString();
    int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
    double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
    // resolve the impurity strategy from the train params; unmatched names fall back to Variance
    if (imStr.equalsIgnoreCase("entropy")) {
        impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("gini")) {
        impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("friedmanmse")) {
        impurity = new FriedmanMSE(minInstancesPerNode, minInfoGain);
    } else {
        impurity = new Variance(minInstancesPerNode, minInfoGain);
    }
    this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    // resolve the loss function; unknown names are treated as a custom Loss class name
    String lossStr = validParams.get("Loss").toString();
    if (lossStr.equalsIgnoreCase("log")) {
        this.loss = new LogLoss();
    } else if (lossStr.equalsIgnoreCase("absolute")) {
        this.loss = new AbsoluteLoss();
    } else if (lossStr.equalsIgnoreCase("halfgradsquared")) {
        this.loss = new HalfGradSquaredLoss();
    } else if (lossStr.equalsIgnoreCase("squared")) {
        this.loss = new SquaredLoss();
    } else {
        try {
            this.loss = (Loss) ClassUtils.newInstance(Class.forName(lossStr));
        } catch (ClassNotFoundException e) {
            LOG.warn("Class not found for {}, using default SquaredLoss", lossStr);
            this.loss = new SquaredLoss();
        }
    }
    if (this.isGBDT) {
        // GBDT-only knobs: learning rate, replacement sampling, tree dropout
        this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
        Object swrObj = validParams.get("GBTSampleWithReplacement");
        if (swrObj != null) {
            this.gbdtSampleWithReplacement = Boolean.TRUE.toString().equalsIgnoreCase(swrObj.toString());
        }
        Object dropoutObj = validParams.get(CommonConstants.DROPOUT_RATE);
        if (dropoutObj != null) {
            this.dropOutRate = Double.valueOf(dropoutObj.toString());
        }
    }
    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();
    this.checkpointOutput = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));
    LOG.info("Worker init params:isAfterVarSel={}, treeNum={}, impurity={}, loss={}, learningRate={}, gbdtSampleWithReplacement={}, isRF={}, isGBDT={}, isStratifiedSampling={}, isKFoldCV={}, kCrossValidation={}, dropOutRate={}", isAfterVarSelect, treeNum, impurity.getClass().getName(), loss.getClass().getName(), this.learningRate, this.gbdtSampleWithReplacement, this.isRF, this.isGBDT, this.isStratifiedSampling, this.isKFoldCV, kCrossValidation, this.dropOutRate);
    // for fail over, load existing trees
    if (!context.isFirstIteration()) {
        if (this.isGBDT) {
            // set flag here and recover later in doComputing, this is to make sure recover after load part which
            // can load latest trees in #doCompute
            isNeedRecoverGBDTPredict = true;
        } else {
            // RF , trees are recovered from last master results
            recoverTrees = context.getLastMasterResult().getTrees();
        }
    }
    if (context.isFirstIteration() && this.isContinuousEnabled && this.isGBDT) {
        // continuous GBDT training: seed this run with the trees of the existing model spec
        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        TreeModel existingModel = null;
        try {
            existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        } catch (IOException e) {
            LOG.error("Error in get existing model, will ignore and start from scratch", e);
        }
        if (existingModel == null) {
            LOG.warn("No model is found even set to continuous model training.");
            return;
        } else {
            recoverTrees = existingModel.getTrees();
            LOG.info("Loading existing {} trees", recoverTrees.size());
        }
    }
}
Use of ml.shifu.shifu.core.dtrain.gs.GridSearch in project shifu by ShifuML:
the class TrainModelProcessor, method validateDistributedTrain.
/**
 * Validates that the current model configuration is eligible for distributed training:
 * supported algorithm, supported multi-classification mode, HDFS data source, parquet
 * consistency of the normalized output, and grid-search constraints.
 *
 * @throws IOException
 *             if checking files on the configured data source fails
 * @throws IllegalArgumentException
 *             if any configuration is not supported in distributed mode
 */
private void validateDistributedTrain() throws IOException {
    String alg = super.getModelConfig().getTrain().getAlgorithm();
    if (Constants.TENSORFLOW.equalsIgnoreCase(alg)) {
        // we do not train tensorflow in dist mode currently
        return;
    }
    boolean isSupportedAlg = NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg) // NN algorithm
            || LogisticRegressionContants.LR_ALG_NAME.equalsIgnoreCase(alg) // LR algorithm
            || CommonUtils.isTreeModel(alg) // RF or GBT algorithm
            || Constants.TF_ALG_NAME.equalsIgnoreCase(alg) || Constants.WDL.equalsIgnoreCase(alg);
    if (!isSupportedAlg) {
        throw new IllegalArgumentException("Currently we only support NN, LR, RF(RandomForest), WDL and GBDT(Gradient Boost Desicion Tree) distributed training.");
    }
    if ((LogisticRegressionContants.LR_ALG_NAME.equalsIgnoreCase(alg) || CommonConstants.GBT_ALG_NAME.equalsIgnoreCase(alg)) && modelConfig.isClassification() && modelConfig.getTrain().getMultiClassifyMethod() == MultipleClassification.NATIVE) {
        throw new IllegalArgumentException("Distributed LR, GBDT(Gradient Boost Desicion Tree) only support binary classification, native multiple classification is not supported.");
    }
    if (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll() && !CommonUtils.isTreeModel(alg) && !NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg)) {
        throw new IllegalArgumentException("Only GBT and RF and NN support OneVsAll multiple classification.");
    }
    if (super.getModelConfig().getDataSet().getSource() != SourceType.HDFS) {
        throw new IllegalArgumentException("Currently we only support distributed training on HDFS source type.");
    }
    if (isDebug()) {
        LOG.warn("Currently we haven't debug logic. It's the same as you don't set it.");
    }
    // check if parquet format norm output is consistent with current isParquet setting.
    boolean isParquetMetaFileExist = false;
    try {
        isParquetMetaFileExist = ShifuFileUtils.getFileSystemBySourceType(super.getModelConfig().getDataSet().getSource()).exists(new Path(super.getPathFinder().getNormalizedDataPath(), "_common_metadata"));
    } catch (Exception e) {
        // a missing or unreadable meta file simply means the output is not parquet
        isParquetMetaFileExist = false;
    }
    if (super.modelConfig.getNormalize().getIsParquet() && !isParquetMetaFileExist) {
        throw new IllegalArgumentException("Your normlized input in " + super.getPathFinder().getNormalizedDataPath() + " is not parquet format. Please keep isParquet and re-run norm again and then run training step or change isParquet to false.");
    } else if (!super.modelConfig.getNormalize().getIsParquet() && isParquetMetaFileExist) {
        throw new IllegalArgumentException("Your normlized input in " + super.getPathFinder().getNormalizedDataPath() + " is parquet format. Please keep isParquet and re-run norm again or change isParquet directly to true.");
    }
    GridSearch gridSearch = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    if (!LogisticRegressionContants.LR_ALG_NAME.equalsIgnoreCase(alg) && !NNConstants.NN_ALG_NAME.equalsIgnoreCase(alg) && !CommonUtils.isTreeModel(alg) && gridSearch.hasHyperParam()) {
        // if grid search but not NN, not RF, not GBT, not LR
        throw new IllegalArgumentException("Grid search only supports NN, GBT and RF algorithms");
    }
    if (gridSearch.hasHyperParam() && super.getModelConfig().getDataSet().getSource() != SourceType.HDFS && modelConfig.isDistributedRunMode()) {
        // FIX: this branch guards the run mode / data source combination, not the
        // algorithm; the previous message was a copy-paste of the algorithm check above.
        throw new IllegalArgumentException("Grid search only supports distributed run mode with HDFS source type.");
    }
}
Aggregations