Example 1 with GuaguaRuntimeException

use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.

the class DTMaster method recoverMasterStatus.

private void recoverMasterStatus(SourceType sourceType) {
    FSDataInputStream stream = null;
    FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    try {
        stream = fs.open(this.checkpointOutput);
        int treeSize = stream.readInt();
        this.trees = new CopyOnWriteArrayList<TreeNode>();
        for (int i = 0; i < treeSize; i++) {
            TreeNode treeNode = new TreeNode();
            treeNode.readFields(stream);
            this.trees.add(treeNode);
        }
        int queueSize = stream.readInt();
        for (int i = 0; i < queueSize; i++) {
            TreeNode treeNode = new TreeNode();
            treeNode.readFields(stream);
            this.toDoQueue.offer(treeNode);
        }
        if (this.isLeafWise && this.toSplitQueue != null) {
            queueSize = stream.readInt();
            for (int i = 0; i < queueSize; i++) {
                TreeNode treeNode = new TreeNode();
                treeNode.readFields(stream);
                this.toSplitQueue.offer(treeNode);
            }
        }
        this.cpMasterParams = new DTMasterParams();
        this.cpMasterParams.readFields(stream);
    } catch (IOException e) {
        throw new GuaguaRuntimeException(e);
    } finally {
        org.apache.commons.io.IOUtils.closeQuietly(stream);
    }
}
Also used : FileSystem(org.apache.hadoop.fs.FileSystem) FSDataInputStream(org.apache.hadoop.fs.FSDataInputStream) IOException(java.io.IOException) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException)
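
The read side above implies a matching write side: the checkpoint must store the tree count, each tree, the to-do queue, the optional to-split queue, and finally the DTMasterParams, in exactly that order. A minimal sketch of such a writer follows; the method name writeMasterStatus is hypothetical, and it assumes TreeNode and DTMasterParams expose a write(DataOutput) mirroring the readFields calls above, in the style of Hadoop's Writable convention.

private void writeMasterStatus(SourceType sourceType) {
    // Hypothetical counterpart to recoverMasterStatus (not the actual Shifu code):
    // writes the same fields, in the same order, that the recovery path reads back.
    FSDataOutputStream stream = null;
    try {
        FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);
        stream = fs.create(this.checkpointOutput, true);
        stream.writeInt(this.trees.size());
        for (TreeNode treeNode : this.trees) {
            treeNode.write(stream);
        }
        stream.writeInt(this.toDoQueue.size());
        for (TreeNode treeNode : this.toDoQueue) {
            treeNode.write(stream);
        }
        if (this.isLeafWise && this.toSplitQueue != null) {
            stream.writeInt(this.toSplitQueue.size());
            for (TreeNode treeNode : this.toSplitQueue) {
                treeNode.write(stream);
            }
        }
        this.cpMasterParams.write(stream);
    } catch (IOException e) {
        // same pattern as the read path: wrap the checked IOException
        throw new GuaguaRuntimeException(e);
    } finally {
        org.apache.commons.io.IOUtils.closeQuietly(stream);
    }
}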

Example 2 with GuaguaRuntimeException

use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.

the class DTWorker method recoverCurrentTrees.

private List<TreeNode> recoverCurrentTrees() {
    FSDataInputStream stream = null;
    List<TreeNode> trees = null;
    try {
        if (!ShifuFileUtils.isFileExists(this.checkpointOutput.toString(), this.modelConfig.getDataSet().getSource())) {
            return null;
        }
        FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource());
        stream = fs.open(this.checkpointOutput);
        int treeSize = stream.readInt();
        trees = new ArrayList<TreeNode>(treeSize);
        for (int i = 0; i < treeSize; i++) {
            TreeNode treeNode = new TreeNode();
            treeNode.readFields(stream);
            trees.add(treeNode);
        }
    } catch (IOException e) {
        throw new GuaguaRuntimeException(e);
    } finally {
        org.apache.commons.io.IOUtils.closeQuietly(stream);
    }
    return trees;
}
Also used : FileSystem(org.apache.hadoop.fs.FileSystem) FSDataInputStream(org.apache.hadoop.fs.FSDataInputStream) IOException(java.io.IOException) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException)
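
Note the contract here: a missing checkpoint file yields null, while a genuine read failure surfaces as GuaguaRuntimeException. A hedged sketch of a call site honoring that contract (the surrounding worker code is assumed, not taken from Shifu):

// Hypothetical call site (assumed, not verified against Shifu's DTWorker):
// null distinguishes "no checkpoint yet" from a failed read, which throws.
List<TreeNode> recovered = recoverCurrentTrees();
if (recovered == null) {
    LOG.info("No checkpoint at {}, worker starts from scratch.", this.checkpointOutput);
} else {
    LOG.info("Recovered {} trees from checkpoint {}.", recovered.size(), this.checkpointOutput);
}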

Example 3 with GuaguaRuntimeException

use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.

the class DTMaster method init.

@Override
public void init(MasterContext<DTMasterParams, DTWorkerParams> context) {
    Properties props = context.getProps();
    // load model config and column config list first
    SourceType sourceType;
    try {
        sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    // worker number is used to estimate nodes per iteration for stats
    this.workerNumber = NumberFormatUtils.getInt(props.getProperty(GuaguaConstants.GUAGUA_WORKER_NUMBER), true);
    // check whether variables have been final-selected
    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    this.inputNum = inputOutputIndex[0] + inputOutputIndex[1];
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    // cache all feature list for sampling features
    this.allFeatures = this.getAllFeatureList(columnConfigList, isAfterVarSelect);
    int trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    // if grid search is enabled, pick the valid params for this trainer; otherwise use the params from ModelConfig.json
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        validParams = gs.getParams(trainerId);
        LOG.info("Start grid search master with params: {}", validParams);
    }
    Object vtObj = validParams.get("ValidationTolerance");
    if (vtObj != null) {
        try {
            validationTolerance = Double.parseDouble(vtObj.toString());
            LOG.warn("Validation by tolerance is enabled with value {}.", validationTolerance);
        } catch (NumberFormatException ee) {
            validationTolerance = 0d;
            LOG.warn("Validation by tolerance isn't enabled because of non numerical value of ValidationTolerance: {}.", vtObj);
        }
    } else {
        LOG.warn("Validation by tolerance isn't enabled.");
    }
    // tree related parameters initialization
    Object fssObj = validParams.get("FeatureSubsetStrategy");
    if (fssObj != null) {
        try {
            this.featureSubsetRate = Double.parseDouble(fssObj.toString());
            // no need to validate that featureSubsetRate is in (0, 1]; ModelInspector already checks it
            this.featureSubsetStrategy = null;
        } catch (NumberFormatException ee) {
            this.featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
        }
    } else {
        LOG.warn("FeatureSubsetStrategy is not set, set to TWOTHRIDS by default in DTMaster.");
        this.featureSubsetStrategy = FeatureSubsetStrategy.TWOTHIRDS;
        this.featureSubsetRate = 0;
    }
    // max depth
    Object maxDepthObj = validParams.get("MaxDepth");
    if (maxDepthObj != null) {
        this.maxDepth = Integer.valueOf(maxDepthObj.toString());
    } else {
        this.maxDepth = 10;
    }
    // max leaves, used for leaf-wise tree building; TODO add more benchmarks
    Object maxLeavesObj = validParams.get("MaxLeaves");
    if (maxLeavesObj != null) {
        this.maxLeaves = Integer.valueOf(maxLeavesObj.toString());
    } else {
        this.maxLeaves = -1;
    }
    // enable leaf wise tree building once maxLeaves is configured
    if (this.maxLeaves > 0) {
        this.isLeafWise = true;
    }
    // maxBatchSplitSize limits how many nodes are split in each batch
    Object maxBatchSplitSizeObj = validParams.get("MaxBatchSplitSize");
    if (maxBatchSplitSizeObj != null) {
        this.maxBatchSplitSize = Integer.valueOf(maxBatchSplitSizeObj.toString());
    } else {
        // by default split at most 32 nodes in a batch
        this.maxBatchSplitSize = 32;
    }
    assert this.maxDepth > 0 && this.maxDepth <= 20;
    // hidden parameter, used to avoid OOM issues in each iteration
    Object maxStatsMemoryMB = validParams.get("MaxStatsMemoryMB");
    if (maxStatsMemoryMB != null) {
        this.maxStatsMemory = Long.valueOf(maxStatsMemoryMB.toString()) * 1024 * 1024;
        if (this.maxStatsMemory > ((2L * Runtime.getRuntime().maxMemory()) / 3)) {
            // cap at 2/3 of max memory to avoid OOM
            this.maxStatsMemory = ((2L * Runtime.getRuntime().maxMemory()) / 3);
        }
    } else {
        // by default use 1/2 of the heap, about 1.5G with current Shifu settings
        this.maxStatsMemory = Runtime.getRuntime().maxMemory() / 2L;
    }
    // assert this.maxStatsMemory <= Math.min(Runtime.getRuntime().maxMemory() * 0.6, 800 * 1024 * 1024L);
    this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());
    this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    if (this.isGBDT) {
        // learning rate is only effective in GBDT
        this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
    }
    // initialize impurity type according to regression or classification
    String imStr = validParams.get("Impurity").toString();
    int numClasses = 2;
    if (this.modelConfig.isClassification()) {
        numClasses = this.modelConfig.getTags().size();
    }
    // these two parameters control when tree growth stops
    int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
    double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
    if (imStr.equalsIgnoreCase("entropy")) {
        impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("gini")) {
        impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
    } else {
        impurity = new Variance(minInstancesPerNode, minInfoGain);
    }
    // checkpoint folder and interval (every # iterations to do checkpoint)
    this.checkpointInterval = NumberFormatUtils.getInt(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_INTERVAL, "20"));
    this.checkpointOutput = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));
    // cache conf to avoid creating a new Configuration each time
    this.conf = new Configuration();
    // if continuous model training is enabled
    this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
    this.dtEarlyStopDecider = new DTEarlyStopDecider(this.maxDepth);
    if (validParams.containsKey("EnableEarlyStop") && Boolean.valueOf(validParams.get("EnableEarlyStop").toString().toLowerCase())) {
        this.enableEarlyStop = true;
    }
    LOG.info("Master init params: isAfterVarSel={}, featureSubsetStrategy={}, featureSubsetRate={} maxDepth={}, maxStatsMemory={}, " + "treeNum={}, impurity={}, workerNumber={}, minInstancesPerNode={}, minInfoGain={}, isRF={}, " + "isGBDT={}, isContinuousEnabled={}, enableEarlyStop={}.", isAfterVarSelect, featureSubsetStrategy, this.featureSubsetRate, maxDepth, maxStatsMemory, treeNum, imStr, this.workerNumber, minInstancesPerNode, minInfoGain, this.isRF, this.isGBDT, this.isContinuousEnabled, this.enableEarlyStop);
    this.toDoQueue = new LinkedList<TreeNode>();
    if (this.isLeafWise) {
        this.toSplitQueue = new PriorityQueue<TreeNode>(64, new Comparator<TreeNode>() {

            @Override
            public int compare(TreeNode o1, TreeNode o2) {
                return Double.compare(o2.getNode().getWgtCntRatio() * o2.getNode().getGain(), o1.getNode().getWgtCntRatio() * o1.getNode().getGain());
            }
        });
    }
    // initialize trees
    if (context.isFirstIteration()) {
        if (this.isRF) {
            // for random forest, trees are trained in parallel
            this.trees = new CopyOnWriteArrayList<TreeNode>();
            for (int i = 0; i < treeNum; i++) {
                this.trees.add(new TreeNode(i, new Node(Node.ROOT_INDEX), 1d));
            }
        }
        if (this.isGBDT) {
            if (isContinuousEnabled) {
                TreeModel existingModel;
                try {
                    Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
                    existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
                    if (existingModel == null) {
                        // null means no existing model file or model file is in wrong format
                        this.trees = new CopyOnWriteArrayList<TreeNode>();
                        // learning rate is 1 for the first tree
                        this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1d));
                        LOG.info("Starting to train model from scratch and existing model is empty.");
                    } else {
                        this.trees = existingModel.getTrees();
                        this.existingTreeSize = this.trees.size();
                        // starting from existing models, first tree learning rate is current learning rate
                        this.trees.add(new TreeNode(this.existingTreeSize, new Node(Node.ROOT_INDEX), this.existingTreeSize == 0 ? 1d : this.learningRate));
                        LOG.info("Starting to train model from existing model {} with existing trees {}.", modelPath, existingTreeSize);
                    }
                } catch (IOException e) {
                    throw new GuaguaRuntimeException(e);
                }
            } else {
                this.trees = new CopyOnWriteArrayList<TreeNode>();
                // for GBDT, initialize the first tree; trees are trained sequentially and the first tree's learning rate is 1
                this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1.0d));
            }
        }
    } else {
        // recover all state after a master failover
        LOG.info("Recover master status from checkpoint file {}", this.checkpointOutput);
        recoverMasterStatus(sourceType);
    }
}
Also used : Configuration(org.apache.hadoop.conf.Configuration) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) Properties(java.util.Properties) Comparator(java.util.Comparator) TreeModel(ml.shifu.shifu.core.TreeModel) Path(org.apache.hadoop.fs.Path) IOException(java.io.IOException) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException)
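
Most of init follows one pattern: fetch a key from validParams, parse its toString(), and fall back to a default (MaxDepth to 10, MaxLeaves to -1, MaxBatchSplitSize to 32, and so on). A pair of hypothetical helpers, not part of Shifu, would make that pattern explicit:

// Hypothetical helpers (not in Shifu) for the repeated lookup-parse-default pattern.
private static int getInt(Map<String, Object> params, String key, int defaultValue) {
    Object value = params.get(key);
    return value == null ? defaultValue : Integer.parseInt(value.toString());
}

private static double getDouble(Map<String, Object> params, String key, double defaultValue) {
    Object value = params.get(key);
    return value == null ? defaultValue : Double.parseDouble(value.toString());
}

With these, the MaxDepth block above collapses to this.maxDepth = getInt(validParams, "MaxDepth", 10);, and likewise for the other integer and double parameters.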

Example 4 with GuaguaRuntimeException

use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.

the class GuaguaParquetRecordReader method buildContext.

/*
     * Build the context through reflection so the code stays compatible between Hadoop 1 and Hadoop 2.
     */
private TaskAttemptContext buildContext() {
    TaskAttemptID id = null;
    TaskAttemptContext context = null;
    try {
        if (isHadoop2()) {
            Class<?> taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType");
            Constructor<TaskAttemptID> constructor = TaskAttemptID.class.getDeclaredConstructor(String.class, Integer.TYPE, taskTypeClass, Integer.TYPE, Integer.TYPE);
            id = constructor.newInstance("mock", -1, fromEnumConstantName(taskTypeClass, "MAP"), -1, -1);
            Constructor<?> contextConstructor = Class.forName("org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl").getDeclaredConstructor(Configuration.class, TaskAttemptID.class);
            context = (TaskAttemptContext) contextConstructor.newInstance(this.conf, id);
        } else {
            Constructor<TaskAttemptID> constructor = TaskAttemptID.class.getDeclaredConstructor(String.class, Integer.TYPE, Boolean.TYPE, Integer.TYPE, Integer.TYPE);
            constructor.setAccessible(true);
            id = constructor.newInstance("mock", -1, false, -1, -1);
            Constructor<?> contextConstructor = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptContext").getDeclaredConstructor(Configuration.class, TaskAttemptID.class);
            context = (TaskAttemptContext) contextConstructor.newInstance(this.conf, id);
        }
    } catch (Throwable e) {
        throw new GuaguaRuntimeException(e);
    }
    return context;
}
Also used : TaskAttemptID(org.apache.hadoop.mapreduce.TaskAttemptID) TaskAttemptContext(org.apache.hadoop.mapreduce.TaskAttemptContext) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException)
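
The isHadoop2() check itself is not shown in this snippet. A common way to implement it, offered here as an assumption rather than the verified Shifu code, is to probe for a class that only exists in Hadoop 2:

// Hypothetical version probe (assumed, not the verified Shifu implementation).
// TaskAttemptContextImpl exists only in Hadoop 2, where TaskAttemptContext
// became an interface; in Hadoop 1 it was a concrete class.
private boolean isHadoop2() {
    try {
        Class.forName("org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl");
        return true;
    } catch (ClassNotFoundException e) {
        return false;
    }
}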

Example 5 with GuaguaRuntimeException

use of ml.shifu.guagua.GuaguaRuntimeException in project shifu by ShifuML.

the class NNMaster method initOrRecoverParams.

private NNParams initOrRecoverParams(MasterContext<NNParams, NNParams> context) {
    // read existing model weights
    NNParams params = null;
    try {
        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        BasicML basicML = ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        params = initWeights();
        BasicFloatNetwork existingModel = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(basicML);
        if (existingModel != null) {
            LOG.info("Starting to train model from existing model {}.", modelPath);
            int mspecCompareResult = new NNStructureComparator().compare(this.flatNetwork, existingModel.getFlat());
            if (mspecCompareResult == 0) {
                // same model structure
                params.setWeights(existingModel.getFlat().getWeights());
                this.fixedWeightIndexSet = getFixedWights(fixedLayers);
            } else if (mspecCompareResult == 1) {
                // new model structure is larger than existing one
                this.fixedWeightIndexSet = fitExistingModelIn(existingModel.getFlat(), this.flatNetwork, this.fixedLayers, this.fixedBias);
            } else {
                // new model structure is smaller, couldn't hold existing one
                throw new GuaguaRuntimeException("Network changed for recover or continuous training. " + "New network couldn't hold existing network!");
            }
        } else {
            LOG.info("Starting to train model from scratch.");
        }
    } catch (IOException e) {
        throw new GuaguaRuntimeException(e);
    }
    return params;
}
Also used : Path(org.apache.hadoop.fs.Path) BasicML(org.encog.ml.BasicML) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) IOException(java.io.IOException)
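
The three-way compare drives the recovery decision: 0 means an identical structure whose weights can be copied over, 1 means the new network is strictly larger and the existing weights are fitted in, and a negative result means the new network cannot hold the existing one, so the method fails fast. A hedged sketch of the call during master initialization (the surrounding code and the globalNNParams field are assumed, not taken from Shifu):

// Hypothetical call site (assumed): on the first iteration the master either
// recovers weights from an existing model or starts from initWeights().
if (context.isFirstIteration()) {
    this.globalNNParams = initOrRecoverParams(context);
}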

Aggregations

GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException) 10
IOException (java.io.IOException) 7
Path (org.apache.hadoop.fs.Path) 4
GridSearch (ml.shifu.shifu.core.dtrain.gs.GridSearch) 2
FSDataInputStream (org.apache.hadoop.fs.FSDataInputStream) 2
FileSystem (org.apache.hadoop.fs.FileSystem) 2
Tuple (org.apache.pig.data.Tuple) 2
ByteArrayOutputStream (java.io.ByteArrayOutputStream) 1
DataOutputStream (java.io.DataOutputStream) 1
Comparator (java.util.Comparator) 1
Properties (java.util.Properties) 1
ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig) 1
SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType) 1
LR (ml.shifu.shifu.core.LR) 1
TreeModel (ml.shifu.shifu.core.TreeModel) 1
BasicFloatMLData (ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLData) 1
BasicFloatMLDataPair (ml.shifu.shifu.core.dtrain.dataset.BasicFloatMLDataPair) 1
BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) 1
BufferedFloatMLDataSet (ml.shifu.shifu.core.dtrain.dataset.BufferedFloatMLDataSet) 1
FloatMLDataPair (ml.shifu.shifu.core.dtrain.dataset.FloatMLDataPair) 1