Usage of ml.shifu.guagua.GuaguaRuntimeException in project shifu (by ShifuML).
Class DTMaster, method recoverMasterStatus:
/**
 * Restores the master's training state from the checkpoint file at {@code checkpointOutput}.
 *
 * <p>
 * Records are consumed in this order:
 * <ol>
 * <li>the tree count, then that many trees (these replace {@code trees});</li>
 * <li>the to-do queue size, then its nodes;</li>
 * <li>only when leaf-wise building is on and a to-split queue exists, the to-split queue size, then its nodes;</li>
 * <li>finally the cached {@link DTMasterParams}.</li>
 * </ol>
 * Queue nodes are appended to the queues this master already holds. The stream is always closed, and errors
 * raised while closing it are ignored.
 *
 * @param sourceType
 *            source type used to resolve the file system that holds the checkpoint
 * @throws GuaguaRuntimeException
 *             if the checkpoint cannot be opened or read
 */
private void recoverMasterStatus(SourceType sourceType) {
    FileSystem fileSystem = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    FSDataInputStream in = null;
    try {
        in = fileSystem.open(this.checkpointOutput);

        // section 1: all trees; the count is read before the tree list is replaced
        int pendingTrees = in.readInt();
        this.trees = new CopyOnWriteArrayList<TreeNode>();
        while (pendingTrees-- > 0) {
            TreeNode tree = new TreeNode();
            tree.readFields(in);
            this.trees.add(tree);
        }

        // section 2: nodes still waiting in the to-do queue
        for (int pendingTodo = in.readInt(); pendingTodo > 0; pendingTodo--) {
            TreeNode todo = new TreeNode();
            todo.readFields(in);
            this.toDoQueue.offer(todo);
        }

        // section 3: read only in leaf-wise mode (presumably mirrors the checkpoint writer - keep them in sync)
        if (this.isLeafWise && this.toSplitQueue != null) {
            for (int pendingSplit = in.readInt(); pendingSplit > 0; pendingSplit--) {
                TreeNode candidate = new TreeNode();
                candidate.readFields(in);
                this.toSplitQueue.offer(candidate);
            }
        }

        // section 4: master params cached at checkpoint time
        this.cpMasterParams = new DTMasterParams();
        this.cpMasterParams.readFields(in);
    } catch (IOException ioe) {
        throw new GuaguaRuntimeException(ioe);
    } finally {
        org.apache.commons.io.IOUtils.closeQuietly(in);
    }
}
Usage of ml.shifu.guagua.GuaguaRuntimeException in project shifu (by ShifuML).
Class DTWorker, method recoverCurrentTrees:
/**
 * Reads the trees stored at the front of the master checkpoint file {@code checkpointOutput}.
 *
 * <p>
 * The file begins with an int tree count, followed by that many serialized {@link TreeNode}s. Only that leading
 * section is read. The stream is always closed, and errors raised while closing it are ignored.
 *
 * @return the recovered trees, or {@code null} when no checkpoint file exists yet
 * @throws GuaguaRuntimeException
 *             if checking for, opening or reading the checkpoint fails with an I/O error
 */
private List<TreeNode> recoverCurrentTrees() {
    FSDataInputStream in = null;
    List<TreeNode> recovered = null;
    try {
        boolean hasCheckpoint = ShifuFileUtils.isFileExists(this.checkpointOutput.toString(),
                this.modelConfig.getDataSet().getSource());
        if (hasCheckpoint) {
            FileSystem fileSystem = ShifuFileUtils
                    .getFileSystemBySourceType(this.modelConfig.getDataSet().getSource());
            in = fileSystem.open(this.checkpointOutput);
            int count = in.readInt();
            recovered = new ArrayList<TreeNode>(count);
            for (int k = 0; k < count; k++) {
                TreeNode node = new TreeNode();
                node.readFields(in);
                recovered.add(node);
            }
        }
    } catch (IOException ioe) {
        throw new GuaguaRuntimeException(ioe);
    } finally {
        org.apache.commons.io.IOUtils.closeQuietly(in);
    }
    // stays null when the checkpoint file is absent
    return recovered;
}
Usage of ml.shifu.guagua.GuaguaRuntimeException in project shifu (by ShifuML).
Class DTMaster, method init:
/**
 * Initializes the decision-tree master. It:
 * <ul>
 * <li>loads the model and column configs;</li>
 * <li>resolves the training parameters, which come from grid search when hyper-parameters are present;</li>
 * <li>sets up impurity, checkpoint settings and the work queues;</li>
 * <li>creates the initial trees on the first iteration, or otherwise restores state from the checkpoint file.</li>
 * </ul>
 *
 * @param context
 *            master context supplying job properties, the application id and the iteration state
 * @throws GuaguaRuntimeException
 *             if the model/column configs, an existing model (continuous GBDT training) or the checkpoint
 *             cannot be read
 */
@Override
public void init(MasterContext<DTMasterParams, DTWorkerParams> context) {
    Properties props = context.getProps();
    // init model config and column config list at first
    SourceType sourceType;
    try {
        sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        // wrap the same way as every other I/O failure in this master (model loading, checkpoint recovery)
        throw new GuaguaRuntimeException(e);
    }
    // worker number is used to estimate nodes per iteration for stats
    this.workerNumber = NumberFormatUtils.getInt(props.getProperty(GuaguaConstants.GUAGUA_WORKER_NUMBER), true);
    // check if variables are set final selected
    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    this.inputNum = inputOutputIndex[0] + inputOutputIndex[1];
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    // cache all feature list for sampling features
    this.allFeatures = this.getAllFeatureList(columnConfigList, isAfterVarSelect);
    int trainerId = Integer.parseInt(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    // if grid search, select the valid parameters for this trainer; otherwise use the ModelConfig.json params
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        validParams = gs.getParams(trainerId);
        LOG.info("Start grid search master with params: {}", validParams);
    }
    Object vtObj = validParams.get("ValidationTolerance");
    if (vtObj != null) {
        try {
            validationTolerance = Double.parseDouble(vtObj.toString());
            LOG.warn("Validation by tolerance is enabled with value {}.", validationTolerance);
        } catch (NumberFormatException ee) {
            validationTolerance = 0d;
            LOG.warn("Validation by tolerance isn't enabled because of non numerical value of ValidationTolerance: {}.", vtObj);
        }
    } else {
        LOG.warn("Validation by tolerance isn't enabled.");
    }
    // tree related parameters initialization: a numeric value is a rate, anything else names a strategy
    Object fssObj = validParams.get("FeatureSubsetStrategy");
    if (fssObj != null) {
        try {
            this.featureSubsetRate = Double.parseDouble(fssObj.toString());
            // no need validate featureSubsetRate is in (0,1], as already validated in ModelInspector
            this.featureSubsetStrategy = null;
        } catch (NumberFormatException ee) {
            this.featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
        }
    } else {
        LOG.warn("FeatureSubsetStrategy is not set, set to TWOTHIRDS by default in DTMaster.");
        this.featureSubsetStrategy = FeatureSubsetStrategy.TWOTHIRDS;
        this.featureSubsetRate = 0;
    }
    // max depth
    Object maxDepthObj = validParams.get("MaxDepth");
    if (maxDepthObj != null) {
        this.maxDepth = Integer.parseInt(maxDepthObj.toString());
    } else {
        this.maxDepth = 10;
    }
    // max leaves which is used for leaf-wised tree building, TODO add more benchmarks
    Object maxLeavesObj = validParams.get("MaxLeaves");
    if (maxLeavesObj != null) {
        this.maxLeaves = Integer.parseInt(maxLeavesObj.toString());
    } else {
        this.maxLeaves = -1;
    }
    // enable leaf wise tree building once maxLeaves is configured
    if (this.maxLeaves > 0) {
        this.isLeafWise = true;
    }
    // maxBatchSplitSize means each time split # of batch nodes
    Object maxBatchSplitSizeObj = validParams.get("MaxBatchSplitSize");
    if (maxBatchSplitSizeObj != null) {
        this.maxBatchSplitSize = Integer.parseInt(maxBatchSplitSizeObj.toString());
    } else {
        // by default split 32 at most in a batch
        this.maxBatchSplitSize = 32;
    }
    assert this.maxDepth > 0 && this.maxDepth <= 20;
    // hidden parameter, used to avoid OOM issues in each iteration
    long maxHeap = Runtime.getRuntime().maxMemory();
    Object maxStatsMemoryMB = validParams.get("MaxStatsMemoryMB");
    if (maxStatsMemoryMB != null) {
        this.maxStatsMemory = Long.parseLong(maxStatsMemoryMB.toString()) * 1024L * 1024L;
        // Cap at 2/3 of max heap to avoid OOM. Divide before multiplying: maxMemory() returns Long.MAX_VALUE
        // when the heap has no limit, and 2L * Long.MAX_VALUE would overflow to a negative value. That would
        // turn the cap into 0 and force maxStatsMemory to 0.
        long statsMemoryCap = (maxHeap / 3L) * 2L;
        if (this.maxStatsMemory > statsMemoryCap) {
            this.maxStatsMemory = statsMemoryCap;
        }
    } else {
        // by default it is 1/2 of heap, about 1.5G setting in current Shifu
        this.maxStatsMemory = maxHeap / 2L;
    }
    this.treeNum = Integer.parseInt(validParams.get("TreeNum").toString());
    this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    if (this.isGBDT) {
        // learning rate only effective in gbdt
        this.learningRate = Double.parseDouble(validParams.get(CommonConstants.LEARNING_RATE).toString());
    }
    // initialize impurity type according to regression or classification
    String imStr = validParams.get("Impurity").toString();
    int numClasses = 2;
    if (this.modelConfig.isClassification()) {
        numClasses = this.modelConfig.getTags().size();
    }
    // these two parameters are used to stop tree growth
    int minInstancesPerNode = Integer.parseInt(validParams.get("MinInstancesPerNode").toString());
    double minInfoGain = Double.parseDouble(validParams.get("MinInfoGain").toString());
    if (imStr.equalsIgnoreCase("entropy")) {
        impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("gini")) {
        impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
    } else {
        impurity = new Variance(minInstancesPerNode, minInfoGain);
    }
    // checkpoint folder and interval (every # iterations to do checkpoint)
    this.checkpointInterval = NumberFormatUtils.getInt(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_INTERVAL, "20"));
    this.checkpointOutput = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));
    // cache conf to avoid new
    this.conf = new Configuration();
    // if continuous model training is enabled
    this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
    this.dtEarlyStopDecider = new DTEarlyStopDecider(this.maxDepth);
    // Boolean.parseBoolean is already case-insensitive, so no lower-casing is needed
    if (validParams.containsKey("EnableEarlyStop") && Boolean.parseBoolean(validParams.get("EnableEarlyStop").toString())) {
        this.enableEarlyStop = true;
    }
    LOG.info("Master init params: isAfterVarSel={}, featureSubsetStrategy={}, featureSubsetRate={} maxDepth={}, maxStatsMemory={}, " + "treeNum={}, impurity={}, workerNumber={}, minInstancesPerNode={}, minInfoGain={}, isRF={}, " + "isGBDT={}, isContinuousEnabled={}, enableEarlyStop={}.", isAfterVarSelect, featureSubsetStrategy, this.featureSubsetRate, maxDepth, maxStatsMemory, treeNum, imStr, this.workerNumber, minInstancesPerNode, minInfoGain, this.isRF, this.isGBDT, this.isContinuousEnabled, this.enableEarlyStop);
    this.toDoQueue = new LinkedList<TreeNode>();
    if (this.isLeafWise) {
        // highest (weighted count ratio * gain) first
        this.toSplitQueue = new PriorityQueue<TreeNode>(64, new Comparator<TreeNode>() {
            @Override
            public int compare(TreeNode o1, TreeNode o2) {
                return Double.compare(o2.getNode().getWgtCntRatio() * o2.getNode().getGain(), o1.getNode().getWgtCntRatio() * o1.getNode().getGain());
            }
        });
    }
    // initialize trees
    if (context.isFirstIteration()) {
        if (this.isRF) {
            // for random forest, trees are trained in parallel
            this.trees = new CopyOnWriteArrayList<TreeNode>();
            for (int i = 0; i < treeNum; i++) {
                this.trees.add(new TreeNode(i, new Node(Node.ROOT_INDEX), 1d));
            }
        }
        if (this.isGBDT) {
            if (isContinuousEnabled) {
                TreeModel existingModel;
                try {
                    Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
                    existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
                    if (existingModel == null) {
                        // null means no existing model file or model file is in wrong format
                        this.trees = new CopyOnWriteArrayList<TreeNode>();
                        // learning rate is 1 for 1st
                        this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1d));
                        LOG.info("Starting to train model from scratch and existing model is empty.");
                    } else {
                        this.trees = existingModel.getTrees();
                        this.existingTreeSize = this.trees.size();
                        // starting from existing models, first tree learning rate is current learning rate
                        this.trees.add(new TreeNode(this.existingTreeSize, new Node(Node.ROOT_INDEX), this.existingTreeSize == 0 ? 1d : this.learningRate));
                        LOG.info("Starting to train model from existing model {} with existing trees {}.", modelPath, existingTreeSize);
                    }
                } catch (IOException e) {
                    throw new GuaguaRuntimeException(e);
                }
            } else {
                this.trees = new CopyOnWriteArrayList<TreeNode>();
                // for GBDT, initialize the first tree. trees are trained sequentially, first tree learning rate is 1
                this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1.0d));
            }
        }
    } else {
        // recover all states once master is fail over
        LOG.info("Recover master status from checkpoint file {}", this.checkpointOutput);
        recoverMasterStatus(sourceType);
    }
}
Usage of ml.shifu.guagua.GuaguaRuntimeException in project shifu (by ShifuML).
Class GuaguaParquetRecordReader, method buildContext:
/**
 * Builds a mock {@link TaskAttemptContext} purely through reflection, so the same code works on Hadoop 1 and
 * Hadoop 2.
 *
 * <p>
 * The mock attempt id is {@code ("mock", -1, <map marker>, -1, -1)}. On Hadoop 2 the map marker is the
 * {@code TaskType.MAP} enum constant and the context is an
 * {@code org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl}. Otherwise the marker is a {@code false}
 * boolean and {@code org.apache.hadoop.mapreduce.TaskAttemptContext} itself is instantiated.
 *
 * @return a context bound to {@code this.conf}
 * @throws GuaguaRuntimeException
 *             wrapping anything thrown while resolving or invoking the reflective constructors
 */
private TaskAttemptContext buildContext() {
    try {
        TaskAttemptID attemptId;
        String contextClassName;
        if (!isHadoop2()) {
            Constructor<TaskAttemptID> idCtor = TaskAttemptID.class.getDeclaredConstructor(String.class,
                    Integer.TYPE, Boolean.TYPE, Integer.TYPE, Integer.TYPE);
            idCtor.setAccessible(true);
            attemptId = idCtor.newInstance("mock", -1, false, -1, -1);
            contextClassName = "org.apache.hadoop.mapreduce.TaskAttemptContext";
        } else {
            Class<?> taskTypeEnum = Class.forName("org.apache.hadoop.mapreduce.TaskType");
            Constructor<TaskAttemptID> idCtor = TaskAttemptID.class.getDeclaredConstructor(String.class,
                    Integer.TYPE, taskTypeEnum, Integer.TYPE, Integer.TYPE);
            attemptId = idCtor.newInstance("mock", -1, fromEnumConstantName(taskTypeEnum, "MAP"), -1, -1);
            contextClassName = "org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl";
        }
        // both Hadoop lines expose a (Configuration, TaskAttemptID) constructor on the chosen class
        Constructor<?> contextCtor = Class.forName(contextClassName).getDeclaredConstructor(Configuration.class,
                TaskAttemptID.class);
        return (TaskAttemptContext) contextCtor.newInstance(this.conf, attemptId);
    } catch (Throwable t) {
        throw new GuaguaRuntimeException(t);
    }
}
Usage of ml.shifu.guagua.GuaguaRuntimeException in project shifu (by ShifuML).
Class NNMaster, method initOrRecoverParams:
/**
 * Creates the initial network parameters via {@code initWeights()}. If a model can be loaded from the Guagua
 * output path, it is used to seed those parameters and the fixed-weight index set.
 *
 * <p>
 * The comparator result decides what happens:
 * <ul>
 * <li>{@code 0} (same structure): the existing weights are copied into the returned params.</li>
 * <li>{@code 1} (new network larger): the existing model is fitted into {@code flatNetwork} via
 * {@code fitExistingModelIn}.</li>
 * <li>Anything else: the new network cannot hold the old one, and the method fails.</li>
 * </ul>
 *
 * @param context
 *            master context whose properties provide the Guagua output path
 * @return the initial parameters
 * @throws GuaguaRuntimeException
 *             if the existing model cannot be read or the new network cannot hold it
 */
private NNParams initOrRecoverParams(MasterContext<NNParams, NNParams> context) {
    NNParams result = null;
    try {
        Path existingModelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        BasicML loaded = ModelSpecLoaderUtils.loadModel(modelConfig, existingModelPath,
                ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        result = initWeights();
        BasicFloatNetwork previous = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(loaded);
        if (previous == null) {
            LOG.info("Starting to train model from scratch.");
            return result;
        }

        LOG.info("Starting to train model from existing model {}.", existingModelPath);
        switch (new NNStructureComparator().compare(this.flatNetwork, previous.getFlat())) {
            case 0:
                // identical structure: reuse the trained weights as-is
                result.setWeights(previous.getFlat().getWeights());
                this.fixedWeightIndexSet = getFixedWights(this.fixedLayers);
                break;
            case 1:
                // New structure is larger. NOTE(review): `result` keeps the initWeights() values here.
                // Confirm that the effect of fitExistingModelIn on flatNetwork reaches the weights that are used.
                this.fixedWeightIndexSet = fitExistingModelIn(previous.getFlat(), this.flatNetwork,
                        this.fixedLayers, this.fixedBias);
                break;
            default:
                throw new GuaguaRuntimeException("Network changed for recover or continuous training. " + "New network couldn't hold existing network!");
        }
    } catch (IOException ioe) {
        throw new GuaguaRuntimeException(ioe);
    }
    return result;
}
Aggregations