Search in sources:

Example 1 with TreeModel

use of ml.shifu.shifu.core.TreeModel in project shifu by ShifuML.

Class DTMaster, method init.

/**
 * Initializes the decision-tree master from the job properties.
 * <p>
 * Loads ModelConfig and the ColumnConfig list, resolves the training parameters (taking the grid-search
 * combination selected by the trainer id when hyper-parameters are present), and configures tree-growth
 * settings, the impurity, checkpointing and early stopping. On the first iteration it creates the initial
 * trees (RF: all trees at once; GBDT: one tree, optionally appended to an existing model when continuous
 * training is enabled). On later iterations it recovers the master state from the checkpoint.
 *
 * @param context
 *            master context providing job properties, app id and iteration info
 * @throws RuntimeException
 *             if the model config or column config cannot be loaded
 * @throws GuaguaRuntimeException
 *             if loading an existing GBDT model fails with an IOException in continuous-training mode
 */
@Override
public void init(MasterContext<DTMasterParams, DTWorkerParams> context) {
    Properties props = context.getProps();
    // init model config and column config list at first
    SourceType sourceType;
    try {
        sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    // worker number is used to estimate nodes per iteration for stats
    this.workerNumber = NumberFormatUtils.getInt(props.getProperty(GuaguaConstants.GUAGUA_WORKER_NUMBER), true);
    // check if variables are set final selected
    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    // numerical + categorical = # of all inputs
    this.inputNum = inputOutputIndex[0] + inputOutputIndex[1];
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    // cache all feature list for sampling features
    this.allFeatures = this.getAllFeatureList(columnConfigList, isAfterVarSelect);
    // trainer id picks which grid-search parameter combination this master uses (see gs.getParams below)
    int trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    // If grid search, select valid paramters, if not parameters is what in ModelConfig.json
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        validParams = gs.getParams(trainerId);
        LOG.info("Start grid search master with params: {}", validParams);
    }
    // optional validation tolerance; a non-numeric value disables it by resetting to 0
    Object vtObj = validParams.get("ValidationTolerance");
    if (vtObj != null) {
        try {
            validationTolerance = Double.parseDouble(vtObj.toString());
            LOG.warn("Validation by tolerance is enabled with value {}.", validationTolerance);
        } catch (NumberFormatException ee) {
            validationTolerance = 0d;
            LOG.warn("Validation by tolerance isn't enabled because of non numerical value of ValidationTolerance: {}.", vtObj);
        }
    } else {
        LOG.warn("Validation by tolerance isn't enabled.");
    }
    // tree related parameters initialization
    // FeatureSubsetStrategy is either a numeric rate (strategy set to null) or a named strategy
    Object fssObj = validParams.get("FeatureSubsetStrategy");
    if (fssObj != null) {
        try {
            this.featureSubsetRate = Double.parseDouble(fssObj.toString());
            // no need validate featureSubsetRate is in (0,1], as already validated in ModelInspector
            this.featureSubsetStrategy = null;
        } catch (NumberFormatException ee) {
            this.featureSubsetStrategy = FeatureSubsetStrategy.of(fssObj.toString());
        }
    } else {
        // NOTE(review): "TWOTHRIDS" in this log message is a typo for TWOTHIRDS
        LOG.warn("FeatureSubsetStrategy is not set, set to TWOTHRIDS by default in DTMaster.");
        this.featureSubsetStrategy = FeatureSubsetStrategy.TWOTHIRDS;
        this.featureSubsetRate = 0;
    }
    // max depth
    Object maxDepthObj = validParams.get("MaxDepth");
    if (maxDepthObj != null) {
        this.maxDepth = Integer.valueOf(maxDepthObj.toString());
    } else {
        this.maxDepth = 10;
    }
    // max leaves which is used for leaf-wised tree building, TODO add more benchmarks
    Object maxLeavesObj = validParams.get("MaxLeaves");
    if (maxLeavesObj != null) {
        this.maxLeaves = Integer.valueOf(maxLeavesObj.toString());
    } else {
        this.maxLeaves = -1;
    }
    // enable leaf wise tree building once maxLeaves is configured
    if (this.maxLeaves > 0) {
        this.isLeafWise = true;
    }
    // maxBatchSplitSize means each time split # of batch nodes
    Object maxBatchSplitSizeObj = validParams.get("MaxBatchSplitSize");
    if (maxBatchSplitSizeObj != null) {
        this.maxBatchSplitSize = Integer.valueOf(maxBatchSplitSizeObj.toString());
    } else {
        // by default split 32 at most in a batch
        this.maxBatchSplitSize = 32;
    }
    // only checked when the JVM runs with assertions enabled (-ea)
    assert this.maxDepth > 0 && this.maxDepth <= 20;
    // hide in parameters, this to avoid OOM issue for each iteration
    Object maxStatsMemoryMB = validParams.get("MaxStatsMemoryMB");
    if (maxStatsMemoryMB != null) {
        // NOTE(review): re-reads the same key instead of reusing maxStatsMemoryMB
        this.maxStatsMemory = Long.valueOf(validParams.get("MaxStatsMemoryMB").toString()) * 1024 * 1024;
        if (this.maxStatsMemory > ((2L * Runtime.getRuntime().maxMemory()) / 3)) {
            // if >= 2/3 max memory, take 2/3 max memory to avoid OOM
            this.maxStatsMemory = ((2L * Runtime.getRuntime().maxMemory()) / 3);
        }
    } else {
        // by default it is 1/2 of heap, about 1.5G setting in current Shifu
        this.maxStatsMemory = Runtime.getRuntime().maxMemory() / 2L;
    }
    // assert this.maxStatsMemory <= Math.min(Runtime.getRuntime().maxMemory() * 0.6, 800 * 1024 * 1024L);
    // TreeNum is required: a missing key throws NullPointerException here
    this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());
    this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    if (this.isGBDT) {
        // learning rate only effective in gbdt
        this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
    }
    // initialize impurity type according to regression or classfication
    String imStr = validParams.get("Impurity").toString();
    // 2 is kept for non-classification; for classification it is the number of tags
    int numClasses = 2;
    if (this.modelConfig.isClassification()) {
        numClasses = this.modelConfig.getTags().size();
    }
    // these two parameters is to stop tree growth parameters
    int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
    double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
    // any impurity name other than entropy/gini falls back to Variance
    if (imStr.equalsIgnoreCase("entropy")) {
        impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("gini")) {
        impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
    } else {
        impurity = new Variance(minInstancesPerNode, minInfoGain);
    }
    // checkpoint folder and interval (every # iterations to do checkpoint)
    this.checkpointInterval = NumberFormatUtils.getInt(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_INTERVAL, "20"));
    this.checkpointOutput = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));
    // cache conf to avoid new
    this.conf = new Configuration();
    // if continuous model training is enabled
    this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
    this.dtEarlyStopDecider = new DTEarlyStopDecider(this.maxDepth);
    if (validParams.containsKey("EnableEarlyStop") && Boolean.valueOf(validParams.get("EnableEarlyStop").toString().toLowerCase())) {
        this.enableEarlyStop = true;
    }
    LOG.info("Master init params: isAfterVarSel={}, featureSubsetStrategy={}, featureSubsetRate={} maxDepth={}, maxStatsMemory={}, " + "treeNum={}, impurity={}, workerNumber={}, minInstancesPerNode={}, minInfoGain={}, isRF={}, " + "isGBDT={}, isContinuousEnabled={}, enableEarlyStop={}.", isAfterVarSelect, featureSubsetStrategy, this.featureSubsetRate, maxDepth, maxStatsMemory, treeNum, imStr, this.workerNumber, minInstancesPerNode, minInfoGain, this.isRF, this.isGBDT, this.isContinuousEnabled, this.enableEarlyStop);
    this.toDoQueue = new LinkedList<TreeNode>();
    if (this.isLeafWise) {
        // leaf-wise mode: nodes are polled in descending order of (weighted count ratio * gain)
        this.toSplitQueue = new PriorityQueue<TreeNode>(64, new Comparator<TreeNode>() {

            @Override
            public int compare(TreeNode o1, TreeNode o2) {
                return Double.compare(o2.getNode().getWgtCntRatio() * o2.getNode().getGain(), o1.getNode().getWgtCntRatio() * o1.getNode().getGain());
            }
        });
    }
    // initialize trees
    if (context.isFirstIteration()) {
        if (this.isRF) {
            // for random forest, trees are trained in parallel
            this.trees = new CopyOnWriteArrayList<TreeNode>();
            for (int i = 0; i < treeNum; i++) {
                this.trees.add(new TreeNode(i, new Node(Node.ROOT_INDEX), 1d));
            }
        }
        if (this.isGBDT) {
            if (isContinuousEnabled) {
                TreeModel existingModel;
                try {
                    Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
                    existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
                    if (existingModel == null) {
                        // null means no existing model file or model file is in wrong format
                        this.trees = new CopyOnWriteArrayList<TreeNode>();
                        // learning rate is 1 for 1st
                        this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1d));
                        LOG.info("Starting to train model from scratch and existing model is empty.");
                    } else {
                        this.trees = existingModel.getTrees();
                        this.existingTreeSize = this.trees.size();
                        // starting from existing models, first tree learning rate is current learning rate
                        // (falls back to 1 when the existing model contains no trees)
                        this.trees.add(new TreeNode(this.existingTreeSize, new Node(Node.ROOT_INDEX), this.existingTreeSize == 0 ? 1d : this.learningRate));
                        LOG.info("Starting to train model from existing model {} with existing trees {}.", modelPath, existingTreeSize);
                    }
                } catch (IOException e) {
                    throw new GuaguaRuntimeException(e);
                }
            } else {
                this.trees = new CopyOnWriteArrayList<TreeNode>();
                // for GBDT, initialize the first tree. trees are trained sequentially,first tree learning rate is 1
                this.trees.add(new TreeNode(0, new Node(Node.ROOT_INDEX), 1.0d));
            }
        }
    } else {
        // recover all states once master is fail over
        LOG.info("Recover master status from checkpoint file {}", this.checkpointOutput);
        recoverMasterStatus(sourceType);
    }
}
Also used: Configuration(org.apache.hadoop.conf.Configuration) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) Properties(java.util.Properties) Comparator(java.util.Comparator) TreeModel(ml.shifu.shifu.core.TreeModel) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) Path(org.apache.hadoop.fs.Path) IOException(java.io.IOException) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch)

Example 2 with TreeModel

use of ml.shifu.shifu.core.TreeModel in project shifu by ShifuML.

Class CommonUtils, method computeTreeModelFeatureImportance.

/**
 * Merges the feature importances reported by every tree model in a bagging set.
 * <p>
 * Entries of {@code models} that are not {@link TreeModel} instances are skipped silently.
 *
 * @param models
 *            the bagging models; only the {@link TreeModel} entries contribute
 * @return merged feature importance keyed by column id
 * @throws IllegalStateException
 *             if {@code models} contains no {@link TreeModel} at all
 */
public static Map<Integer, MutablePair<String, Double>> computeTreeModelFeatureImportance(List<BasicML> models) {
    List<Map<Integer, MutablePair<String, Double>>> perModelImportances = new ArrayList<Map<Integer, MutablePair<String, Double>>>();
    for (BasicML candidate : models) {
        if (!(candidate instanceof TreeModel)) {
            // non-tree models carry no feature importance
            continue;
        }
        perModelImportances.add(((TreeModel) candidate).getFeatureImportances());
    }
    if (perModelImportances.isEmpty()) {
        throw new IllegalStateException("Feature importance calculation abort due to no tree model found!!");
    }
    return mergeImportanceList(perModelImportances);
}
Also used : MutablePair(org.apache.commons.lang3.tuple.MutablePair) TreeModel(ml.shifu.shifu.core.TreeModel) BasicML(org.encog.ml.BasicML)

Example 3 with TreeModel

use of ml.shifu.shifu.core.TreeModel in project shifu by ShifuML.

Class TreeEnsemblePMMLTranslator, method build.

/**
 * Builds a PMML document for a single tree-ensemble model.
 * <p>
 * The PMML header carries the Apache license text. When it can be read, it also carries the {@code vendor} and
 * {@code version} attributes from the manifest of the jar that contains this translator. Reading that manifest
 * is best-effort: if no jar or manifest is available, or reading fails, a warning is logged and the application
 * name/version stay unset.
 *
 * @param basicML
 *            the model to translate; it is cast to {@link TreeModel}, so any other type fails with
 *            {@link ClassCastException}
 * @return the PMML document holding the data dictionary and the converted mining model
 */
public PMML build(BasicML basicML) {
    PMML pmml = new PMML();
    Header header = new Header();
    pmml.setHeader(header);
    header.setCopyright(" Copyright [2013-2017] PayPal Software Foundation\n" + "\n" + " Licensed under the Apache License, Version 2.0 (the \"License\");\n" + " you may not use this file except in compliance with the License.\n" + " You may obtain a copy of the License at\n" + "\n" + "    http://www.apache.org/licenses/LICENSE-2.0\n" + "\n" + " Unless required by applicable law or agreed to in writing, software\n" + " distributed under the License is distributed on an \"AS IS\" BASIS,\n" + " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + " See the License for the specific language governing permissions and\n" + " limitations under the License.\n");
    Application application = new Application();
    header.setApplication(application);
    String findContainingJar = JarManager.findContainingJar(TreeEnsemblePMMLTranslator.class);
    if (findContainingJar == null) {
        // presumably the class was loaded from a classes directory rather than a jar (e.g. tests/IDE);
        // previously this surfaced as an NPE that was logged as the unhelpful message "null"
        LOG.warn("No jar found containing " + TreeEnsemblePMMLTranslator.class.getName()
                + ", PMML application name and version are left unset.");
    } else {
        JarFile jar = null;
        try {
            jar = new JarFile(findContainingJar);
            final Manifest manifest = jar.getManifest();
            if (manifest == null) {
                LOG.warn("No manifest found in " + findContainingJar
                        + ", PMML application name and version are left unset.");
            } else {
                application.setName(manifest.getMainAttributes().getValue("vendor"));
                application.setVersion(manifest.getMainAttributes().getValue("version"));
            }
        } catch (Exception e) {
            // best-effort: header metadata must never break PMML generation, but keep the cause for diagnosis
            LOG.warn("Failed to read application name and version from " + findContainingJar, e);
        } finally {
            if (jar != null) {
                try {
                    jar.close();
                } catch (IOException e) {
                    LOG.warn("Failed to close jar " + findContainingJar, e);
                }
            }
        }
    }
    pmml.setDataDictionary(dataDictionaryCreator.build(basicML));
    List<Model> models = pmml.getModels();
    Model miningModel = modelCreator.convert(((TreeModel) basicML).getIndependentTreeModel());
    models.add(miningModel);
    return pmml;
}
Also used: Header(org.dmg.pmml.Header) TreeModel(ml.shifu.shifu.core.TreeModel) Model(org.dmg.pmml.Model) PMML(org.dmg.pmml.PMML) IOException(java.io.IOException) JarFile(java.util.jar.JarFile) Manifest(java.util.jar.Manifest) Application(org.dmg.pmml.Application)

Example 4 with TreeModel

use of ml.shifu.shifu.core.TreeModel in project shifu by ShifuML.

Class DTWorker, method init.

/**
 * Initializes the decision-tree worker from the job properties.
 * <p>
 * Loads ModelConfig and the ColumnConfig list and builds the categorical value to bin-index mapping. It then
 * configures sampling, the worker thread pool (shut down through a completion callback) and the
 * memory-limited training/validation data holders, followed by the impurity, the loss and the GBDT-only
 * parameters. Finally it prepares the trees to recover on fail-over or when continuous training is enabled.
 *
 * @param context
 *            worker context providing job properties, app id, iteration info and the last master result
 * @throws RuntimeException
 *             if the model config or column config cannot be loaded
 */
@Override
public void init(WorkerContext<DTMasterParams, DTWorkerParams> context) {
    Properties props = context.getProps();
    try {
        SourceType sourceType = SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    // column number -> (flattened category value -> index of the bin category group it belongs to)
    this.columnCategoryIndexMapping = new HashMap<Integer, Map<String, Integer>>();
    for (ColumnConfig config : this.columnConfigList) {
        if (config.isCategorical()) {
            if (config.getBinCategory() != null) {
                Map<String, Integer> tmpMap = new HashMap<String, Integer>();
                for (int i = 0; i < config.getBinCategory().size(); i++) {
                    List<String> catVals = CommonUtils.flattenCatValGrp(config.getBinCategory().get(i));
                    for (String cval : catVals) {
                        tmpMap.put(cval, i);
                    }
                }
                this.columnCategoryIndexMapping.put(config.getColumnNum(), tmpMap);
            }
        }
    }
    this.hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    // create Splitter
    String delimiter = context.getProps().getProperty(Constants.SHIFU_OUTPUT_DATA_DELIMITER);
    this.splitter = MapReduceUtils.generateShifuOutputSplitter(delimiter);
    Integer kCrossValidation = this.modelConfig.getTrain().getNumKFold();
    if (kCrossValidation != null && kCrossValidation > 0) {
        isKFoldCV = true;
        LOG.info("Cross validation is enabled by kCrossValidation: {}.", kCrossValidation);
    }
    // up sampling applies to regression or one-vs-all classification when the weight differs from 1
    Double upSampleWeight = modelConfig.getTrain().getUpSampleWeight();
    if (Double.compare(upSampleWeight, 1d) != 0 && (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()))) {
        // set mean to upSampleWeight -1 and get sample + 1 to make sure no zero sample value
        LOG.info("Enable up sampling with weight {}.", upSampleWeight);
        this.upSampleRng = new PoissonDistribution(upSampleWeight - 1);
    }
    this.isContinuousEnabled = Boolean.TRUE.toString().equalsIgnoreCase(context.getProps().getProperty(CommonConstants.CONTINUOUS_TRAINING));
    this.workerThreadCount = modelConfig.getTrain().getWorkerThreadCount();
    this.threadPool = Executors.newFixedThreadPool(this.workerThreadCount);
    // enable shut down logic
    context.addCompletionCallBack(new WorkerCompletionCallBack<DTMasterParams, DTWorkerParams>() {

        @Override
        public void callback(WorkerContext<DTMasterParams, DTWorkerParams> context) {
            DTWorker.this.threadPool.shutdownNow();
            try {
                DTWorker.this.threadPool.awaitTermination(2, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                // restore the interrupt flag so callers can observe it
                Thread.currentThread().interrupt();
            }
        }
    });
    // trainer id picks which grid-search parameter combination this worker uses (see gs.getParams below)
    this.trainerId = Integer.valueOf(context.getProps().getProperty(CommonConstants.SHIFU_TRAINER_ID, "0"));
    this.isOneVsAll = modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll();
    GridSearch gs = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    Map<String, Object> validParams = this.modelConfig.getTrain().getParams();
    if (gs.hasHyperParam()) {
        validParams = gs.getParams(this.trainerId);
        LOG.info("Start grid search worker with params: {}", validParams);
    }
    this.treeNum = Integer.valueOf(validParams.get("TreeNum").toString());
    // share of the max heap reserved for in-memory data, 0.6 unless guagua.data.memoryFraction is set
    double memoryFraction = Double.valueOf(context.getProps().getProperty("guagua.data.memoryFraction", "0.6"));
    LOG.info("Max heap memory: {}, fraction: {}", Runtime.getRuntime().maxMemory(), memoryFraction);
    double validationRate = this.modelConfig.getValidSetRate();
    if (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())) {
        // fixed 0.6 and 0.4 of max memory for trainingData and validationData
        this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.6), new ArrayList<Data>());
        this.validationData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * 0.4), new ArrayList<Data>());
    } else {
        // split the data memory budget by the validation rate; no validation holder when the rate is 0
        if (Double.compare(validationRate, 0d) != 0) {
            this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * (1 - validationRate)), new ArrayList<Data>());
            this.validationData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction * validationRate), new ArrayList<Data>());
        } else {
            this.trainingData = new MemoryLimitedList<Data>((long) (Runtime.getRuntime().maxMemory() * memoryFraction), new ArrayList<Data>());
        }
    }
    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
    // numerical + categorical = # of all input
    this.inputCount = inputOutputIndex[0] + inputOutputIndex[1];
    // regression outputNodeCount is 1, binaryClassfication, it is 1, OneVsAll it is 1, Native classification it is
    // 1, with index of 0,1,2,3 denotes different classes
    this.isAfterVarSelect = (inputOutputIndex[3] == 1);
    this.isManualValidation = (modelConfig.getValidationDataSetRawPath() != null && !"".equals(modelConfig.getValidationDataSetRawPath()));
    int numClasses = this.modelConfig.isClassification() ? this.modelConfig.getTags().size() : 2;
    String imStr = validParams.get("Impurity").toString();
    int minInstancesPerNode = Integer.valueOf(validParams.get("MinInstancesPerNode").toString());
    double minInfoGain = Double.valueOf(validParams.get("MinInfoGain").toString());
    // any impurity name other than entropy/gini/friedmanmse falls back to Variance
    if (imStr.equalsIgnoreCase("entropy")) {
        impurity = new Entropy(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("gini")) {
        impurity = new Gini(numClasses, minInstancesPerNode, minInfoGain);
    } else if (imStr.equalsIgnoreCase("friedmanmse")) {
        impurity = new FriedmanMSE(minInstancesPerNode, minInfoGain);
    } else {
        impurity = new Variance(minInstancesPerNode, minInfoGain);
    }
    this.isRF = ALGORITHM.RF.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    this.isGBDT = ALGORITHM.GBT.toString().equalsIgnoreCase(modelConfig.getAlgorithm());
    String lossStr = validParams.get("Loss").toString();
    if (lossStr.equalsIgnoreCase("log")) {
        this.loss = new LogLoss();
    } else if (lossStr.equalsIgnoreCase("absolute")) {
        this.loss = new AbsoluteLoss();
    } else if (lossStr.equalsIgnoreCase("halfgradsquared")) {
        this.loss = new HalfGradSquaredLoss();
    } else if (lossStr.equalsIgnoreCase("squared")) {
        this.loss = new SquaredLoss();
    } else {
        // any other value is treated as a class name to instantiate as the Loss; only a missing class falls
        // back to SquaredLoss - NOTE(review): other failures (e.g. the (Loss) cast) are not caught here
        try {
            this.loss = (Loss) ClassUtils.newInstance(Class.forName(lossStr));
        } catch (ClassNotFoundException e) {
            LOG.warn("Class not found for {}, using default SquaredLoss", lossStr);
            this.loss = new SquaredLoss();
        }
    }
    if (this.isGBDT) {
        this.learningRate = Double.valueOf(validParams.get(CommonConstants.LEARNING_RATE).toString());
        Object swrObj = validParams.get("GBTSampleWithReplacement");
        if (swrObj != null) {
            this.gbdtSampleWithReplacement = Boolean.TRUE.toString().equalsIgnoreCase(swrObj.toString());
        }
        Object dropoutObj = validParams.get(CommonConstants.DROPOUT_RATE);
        if (dropoutObj != null) {
            this.dropOutRate = Double.valueOf(dropoutObj.toString());
        }
    }
    this.isStratifiedSampling = this.modelConfig.getTrain().getStratifiedSample();
    this.checkpointOutput = new Path(context.getProps().getProperty(CommonConstants.SHIFU_DT_MASTER_CHECKPOINT_FOLDER, "tmp/cp_" + context.getAppId()));
    LOG.info("Worker init params:isAfterVarSel={}, treeNum={}, impurity={}, loss={}, learningRate={}, gbdtSampleWithReplacement={}, isRF={}, isGBDT={}, isStratifiedSampling={}, isKFoldCV={}, kCrossValidation={}, dropOutRate={}", isAfterVarSelect, treeNum, impurity.getClass().getName(), loss.getClass().getName(), this.learningRate, this.gbdtSampleWithReplacement, this.isRF, this.isGBDT, this.isStratifiedSampling, this.isKFoldCV, kCrossValidation, this.dropOutRate);
    // for fail over, load existing trees
    if (!context.isFirstIteration()) {
        if (this.isGBDT) {
            // set flag here and recover later in doComputing, this is to make sure recover after load part which
            // can load latest trees in #doCompute
            isNeedRecoverGBDTPredict = true;
        } else {
            // RF , trees are recovered from last master results
            recoverTrees = context.getLastMasterResult().getTrees();
        }
    }
    // continuous GBDT training: seed recoverTrees from the existing model; an IOException while loading is
    // logged and treated the same as "no model found"
    if (context.isFirstIteration() && this.isContinuousEnabled && this.isGBDT) {
        Path modelPath = new Path(context.getProps().getProperty(CommonConstants.GUAGUA_OUTPUT));
        TreeModel existingModel = null;
        try {
            existingModel = (TreeModel) ModelSpecLoaderUtils.loadModel(modelConfig, modelPath, ShifuFileUtils.getFileSystemBySourceType(this.modelConfig.getDataSet().getSource()));
        } catch (IOException e) {
            LOG.error("Error in get existing model, will ignore and start from scratch", e);
        }
        if (existingModel == null) {
            LOG.warn("No model is found even set to continuous model training.");
            // leaves recoverTrees untouched; nothing else remains to initialize
            return;
        } else {
            recoverTrees = existingModel.getTrees();
            LOG.info("Loading existing {} trees", recoverTrees.size());
        }
    }
}
Also used: PoissonDistribution(org.apache.commons.math3.distribution.PoissonDistribution) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) ArrayList(java.util.ArrayList) Properties(java.util.Properties) TreeModel(ml.shifu.shifu.core.TreeModel) GuaguaRuntimeException(ml.shifu.guagua.GuaguaRuntimeException) Path(org.apache.hadoop.fs.Path) IOException(java.io.IOException) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch) Map(java.util.Map) ConcurrentMap(java.util.concurrent.ConcurrentMap)

Example 5 with TreeModel

use of ml.shifu.shifu.core.TreeModel in project shifu by ShifuML.

Class ExportModelProcessor, method run.

/**
 * Exports artifacts of the current model set according to {@code type}, which defaults to PMML when blank.
 * <p>
 * Supported types:
 * <ul>
 * <li>one binary bagging model (NN/GBT/RF);</li>
 * <li>one PMML per model;</li>
 * <li>one bagging PMML (NN only);</li>
 * <li>column stats, WOE mapping, WOE info and variable correlation.</li>
 * </ul>
 *
 * @return 0 on success, -1 for an unsupported type, 2 when the correlation file is missing, or the status
 *         returned by the correlation export
 * @throws Exception
 *             propagated from set-up, model loading, serialization or file writing
 * @see ml.shifu.shifu.core.processor.Processor#run()
 */
@Override
public int run() throws Exception {
    setUp(ModelStep.EXPORT);
    int status = 0;
    File pmmls = new File("pmmls");
    FileUtils.forceMkdir(pmmls);
    if (StringUtils.isBlank(type)) {
        type = PMML;
    }
    String modelsPath = pathFinder.getModelsPath(SourceType.LOCAL);
    if (type.equalsIgnoreCase(ONE_BAGGING_MODEL)) {
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm()) && !CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging model is only supported in NN/GBT/RF algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            if (models.size() < 1) {
                log.warn("No model is found in {}.", modelsPath);
            } else {
                // report the real algorithm: tree models take this path too, not only nn
                log.info("Convert {} models into one binary bagging model.", modelConfig.getAlgorithm());
                Configuration conf = new Configuration();
                Path output = new Path(pathFinder.getBaggingModelPath(SourceType.LOCAL), "model.b" + modelConfig.getAlgorithm());
                if ("nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    BinaryNNSerializer.save(modelConfig, columnConfigList, models, FileSystem.getLocal(conf), output);
                } else if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    List<List<TreeNode>> baggingTrees = new ArrayList<List<TreeNode>>();
                    for (int i = 0; i < models.size(); i++) {
                        TreeModel tm = (TreeModel) models.get(i);
                        // TreeModel only has one TreeNode instance although it is list inside
                        baggingTrees.add(tm.getIndependentTreeModel().getTrees().get(0));
                    }
                    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
                    // numerical + categorical = # of all input
                    int inputCount = inputOutputIndex[0] + inputOutputIndex[1];
                    BinaryDTSerializer.save(modelConfig, columnConfigList, baggingTrees, modelConfig.getParams().get("Loss").toString(), inputCount, FileSystem.getLocal(conf), output);
                }
                log.info("Please find one unified bagging model in local {}.", output);
            }
        }
    } else if (type.equalsIgnoreCase(PMML)) {
        // typical pmml generation
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), false);
        for (int index = 0; index < models.size(); index++) {
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + Integer.toString(index) + ".pmml";
            log.info("\t Start to generate " + path);
            PMML pmml = translator.build(Arrays.asList(new BasicML[] { models.get(index) }));
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(ONE_BAGGING_PMML_MODEL)) {
        // one unified bagging pmml generation
        log.info("Convert models into one bagging pmml model {} format", type);
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging pmml model is only supported in NN algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), true);
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + ".pmml";
            log.info("\t Start to generate one unified model to: " + path);
            PMML pmml = translator.build(models);
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(COLUMN_STATS)) {
        saveColumnStatus();
    } else if (type.equalsIgnoreCase(WOE_MAPPING)) {
        List<ColumnConfig> exportCatColumns = new ArrayList<ColumnConfig>();
        List<String> catVariables = getRequestVars();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (CollectionUtils.isEmpty(catVariables) || isRequestColumn(catVariables, columnConfig)) {
                exportCatColumns.add(columnConfig);
            }
        }
        if (CollectionUtils.isNotEmpty(exportCatColumns)) {
            List<String> woeMappings = new ArrayList<String>();
            for (ColumnConfig columnConfig : exportCatColumns) {
                String woeMapText = rebinAndExportWoeMapping(columnConfig);
                woeMappings.add(woeMapText);
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(woeMappings, ",\n"));
        }
    } else if (type.equalsIgnoreCase(WOE)) {
        List<String> woeInfos = new ArrayList<String>();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            // only binned columns: categorical with categories, or numerical with more than one boundary
            if (columnConfig.getBinLength() > 1 && ((columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) || (columnConfig.isNumerical() && CollectionUtils.isNotEmpty(columnConfig.getBinBoundary()) && columnConfig.getBinBoundary().size() > 1))) {
                List<String> varWoeInfos = generateWoeInfos(columnConfig);
                if (CollectionUtils.isNotEmpty(varWoeInfos)) {
                    woeInfos.addAll(varWoeInfos);
                    // blank line separates variables in the output file
                    woeInfos.add("");
                }
            }
        }
        // write once after collecting every column; writing inside the loop rewrote the whole file per column
        FileUtils.writeLines(new File("varwoe_info.txt"), woeInfos);
    } else if (type.equalsIgnoreCase(CORRELATION)) {
        // export correlation into mapping list
        // NOTE(review): both returns below skip clearUp(ModelStep.EXPORT) - confirm this is intended
        if (!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
            log.warn("The correlation file doesn't exist. Please make sure you have ran `shifu stats -c`.");
            return 2;
        }
        return exportVariableCorr();
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }
    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}
Also used : Path(org.apache.hadoop.fs.Path) Configuration(org.apache.hadoop.conf.Configuration) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BasicML(org.encog.ml.BasicML) PMMLTranslator(ml.shifu.shifu.core.pmml.PMMLTranslator) TreeModel(ml.shifu.shifu.core.TreeModel) TreeNode(ml.shifu.shifu.core.dtrain.dt.TreeNode) PMML(org.dmg.pmml.PMML) File(java.io.File)

Aggregations

TreeModel (ml.shifu.shifu.core.TreeModel)5 IOException (java.io.IOException)3 Path (org.apache.hadoop.fs.Path)3 Properties (java.util.Properties)2 GuaguaRuntimeException (ml.shifu.guagua.GuaguaRuntimeException)2 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)2 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)2 GridSearch (ml.shifu.shifu.core.dtrain.gs.GridSearch)2 Configuration (org.apache.hadoop.conf.Configuration)2 PMML (org.dmg.pmml.PMML)2 BasicML (org.encog.ml.BasicML)2 File (java.io.File)1 ArrayList (java.util.ArrayList)1 Comparator (java.util.Comparator)1 HashMap (java.util.HashMap)1 Map (java.util.Map)1 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)1 ConcurrentMap (java.util.concurrent.ConcurrentMap)1 JarFile (java.util.jar.JarFile)1 Manifest (java.util.jar.Manifest)1