Search in sources :

Example 11 with ShifuException

use of ml.shifu.shifu.exception.ShifuException in project shifu by ShifuML.

Class TrainModelProcessor, method runAkkaTrain.

/**
 * Runs local Akka-based training with the given number of bags.
 * <p>
 * The local {@code models} directory is deleted and recreated. One trainer per bag is created for the configured
 * algorithm (NN, SVM or LR). Scanners over the training data are then opened and handed to the Akka executor.
 * The scanners are always released, even when job submission fails.
 *
 * @param numBags
 *            number of bags; one trainer is created and started per bag
 * @throws IOException
 *             if the local {@code models} directory cannot be deleted or recreated
 * @throws ShifuException
 *             with {@code ERROR_UNSUPPORT_ALG} if the algorithm is not NN, SVM or LR (and {@code numBags > 0}),
 *             or with {@code ERROR_INPUT_NOT_FOUND} if the training data cannot be opened
 */
private void runAkkaTrain(int numBags) throws IOException {
    File models = new File("models");
    FileUtils.deleteDirectory(models);
    FileUtils.forceMkdir(models);

    String algorithm = modelConfig.getAlgorithm();
    trainers.clear();
    for (int i = 0; i < numBags; i++) {
        AbstractTrainer trainer;
        if (algorithm.equalsIgnoreCase("NN")) {
            trainer = new NNTrainer(modelConfig, i, isDryTrain);
        } else if (algorithm.equalsIgnoreCase("SVM")) {
            trainer = new SVMTrainer(this.modelConfig, i, isDryTrain);
        } else if (algorithm.equalsIgnoreCase("LR")) {
            trainer = new LogisticRegressionTrainer(this.modelConfig, i, isDryTrain);
        } else {
            throw new ShifuException(ShifuErrorCode.ERROR_UNSUPPORT_ALG);
        }
        trainers.add(trainer);
    }

    // NOTE(review): "DT" is rejected by the trainer loop above whenever numBags > 0, so the decision-tree path
    // below is only reachable with numBags <= 0 - confirm whether a DT trainer was meant to be created there.
    boolean isDecisionTree = algorithm.equalsIgnoreCase("DT");
    List<Scanner> scanners;
    if (isDecisionTree) {
        // decision trees train on raw data, so log and report the raw path that is actually scanned
        LOG.info("Raw Data: " + modelConfig.getDataSetRawPath());
        try {
            scanners = ShifuFileUtils.getDataScanners(modelConfig.getDataSetRawPath(), modelConfig.getDataSet().getSource());
        } catch (IOException e) {
            throw new ShifuException(ShifuErrorCode.ERROR_INPUT_NOT_FOUND, e, modelConfig.getDataSetRawPath());
        }
    } else {
        LOG.info("Normalized Data: " + pathFinder.getNormalizedDataPath());
        try {
            scanners = ShifuFileUtils.getDataScanners(pathFinder.getNormalizedDataPath(), modelConfig.getDataSet().getSource());
        } catch (IOException e) {
            throw new ShifuException(ShifuErrorCode.ERROR_INPUT_NOT_FOUND, e, pathFinder.getNormalizedDataPath());
        }
    }

    try {
        if (CollectionUtils.isNotEmpty(scanners)) {
            if (isDecisionTree) {
                AkkaSystemExecutor.getExecutor().submitDecisionTreeTrainJob(modelConfig, columnConfigList, scanners, trainers);
            } else {
                AkkaSystemExecutor.getExecutor().submitModelTrainJob(modelConfig, columnConfigList, scanners, trainers);
            }
        }
    } finally {
        // release the scanners even if submitting the job fails
        closeScanners(scanners);
    }
}
Also used : NNTrainer(ml.shifu.shifu.core.alg.NNTrainer) SVMTrainer(ml.shifu.shifu.core.alg.SVMTrainer) LogisticRegressionTrainer(ml.shifu.shifu.core.alg.LogisticRegressionTrainer) AbstractTrainer(ml.shifu.shifu.core.AbstractTrainer) ShifuException(ml.shifu.shifu.exception.ShifuException)

Example 12 with ShifuException

use of ml.shifu.shifu.exception.ShifuException in project shifu by ShifuML.

Class TrainModelProcessor, method run.

/**
 * Entry point of the training step.
 * <p>
 * Sets up the TRAIN step and, in debug mode, makes sure the logs directory exists. It then trains either in
 * distributed mode (DIST/MAPRED, TensorFlow or regular) or locally, syncs data to HDFS and cleans up. Any
 * exception raised along the way is logged and turned into a {@code -1} return value.
 *
 * @return the status of the distributed training run (0 for local training), or -1 on any failure
 */
@Override
public int run() throws Exception {
    if (!this.isForVarSelect()) {
        LOG.info("Step Start: train");
    }
    long startMillis = System.currentTimeMillis();
    int trainStatus = 0;
    try {
        setUp(ModelStep.TRAIN);

        if (isDebug) {
            File logsDir = new File(LOGS);
            boolean logsDirReady = logsDir.exists() || logsDir.mkdir();
            if (!logsDirReady) {
                throw new RuntimeException("logs file is created failed.");
            }
        }

        switch (super.modelConfig.getBasic().getRunMode()) {
            case DIST:
            case MAPRED:
                validateDistributedTrain();
                syncDataToHdfs(super.modelConfig.getDataSet().getSource());
                checkAndCleanDataForTreeModels(this.isToShuffle);
                trainStatus = Constants.TENSORFLOW.equalsIgnoreCase(modelConfig.getAlgorithm())
                        ? runTensorflowDistributedTrain()
                        : runDistributedTrain();
                break;
            default:
                // LOCAL, and any other run mode, trains on this machine
                runLocalTrain();
                break;
        }

        syncDataToHdfs(modelConfig.getDataSet().getSource());
        clearUp(ModelStep.TRAIN);
    } catch (ShifuException e) {
        LOG.error("Error:" + e.getError().toString() + "; msg:" + e.getMessage(), e);
        return -1;
    } catch (Exception e) {
        LOG.error("Error:" + e.getMessage(), e);
        return -1;
    }

    if (!this.isForVarSelect()) {
        LOG.info("Step Finished: train with {} ms", (System.currentTimeMillis() - startMillis));
    }
    return trainStatus;
}
Also used: ParquetRuntimeException(parquet.ParquetRuntimeException) RunMode(ml.shifu.shifu.container.obj.ModelBasicConf.RunMode) ShifuException(ml.shifu.shifu.exception.ShifuException) RecognitionException(org.antlr.runtime.RecognitionException)

Example 13 with ShifuException

use of ml.shifu.shifu.exception.ShifuException in project shifu by ShifuML.

Class TrainModelProcessor, method runTensorflowLocalTrain.

/**
 * Runs TensorFlow training locally on the normalized data.
 * <p>
 * Scanners over the normalized data are opened first, and training only starts when at least one scanner is
 * returned. The scanners are released afterwards, even if training throws.
 *
 * @throws IOException
 *             presumably propagated from {@code TensorflowTrainer#train()} - TODO confirm
 * @throws ShifuException
 *             with {@code ERROR_INPUT_NOT_FOUND} if the normalized data cannot be opened
 */
private void runTensorflowLocalTrain() throws IOException {
    TensorflowTrainer trainer = new TensorflowTrainer(modelConfig, columnConfigList);
    LOG.info("Normalized data for training {}.", pathFinder.getNormalizedDataPath());
    List<Scanner> scanners;
    try {
        scanners = ShifuFileUtils.getDataScanners(pathFinder.getNormalizedDataPath(), modelConfig.getDataSet().getSource());
    } catch (IOException e) {
        throw new ShifuException(ShifuErrorCode.ERROR_INPUT_NOT_FOUND, e, pathFinder.getNormalizedDataPath());
    }
    try {
        // NOTE(review): the scanners are never handed to the trainer; they appear to serve only as a
        // "normalized data is present" check - confirm this is intended.
        if (CollectionUtils.isNotEmpty(scanners)) {
            trainer.train();
        }
    } finally {
        // release the scanners even if training fails
        closeScanners(scanners);
    }
}
Also used : TensorflowTrainer(ml.shifu.shifu.core.alg.TensorflowTrainer) ShifuException(ml.shifu.shifu.exception.ShifuException)

Example 14 with ShifuException

use of ml.shifu.shifu.exception.ShifuException in project shifu by ShifuML.

Class PostTrainModelProcessor, method runPigPostTrain.

/**
 * Runs the post-train step as a Pig job and then updates the column configs with bin average scores.
 * <p>
 * Previous outputs at the train-scores and bin-average-score paths are deleted first. Next,
 * {@code scripts/PostTrain.pig} is submitted with the header path/delimiter and data delimiter parameters.
 * Finally the column config list is refreshed and saved.
 *
 * @throws IOException
 *             for any io exception
 * @throws ShifuException
 *             with {@code ERROR_RUNNING_PIG_JOB} if submitting the Pig job fails with an {@link IOException}
 */
@SuppressWarnings("unused")
private void runPigPostTrain() throws IOException {
    SourceType sourceType = modelConfig.getDataSet().getSource();
    ShifuFileUtils.deleteFile(pathFinder.getTrainScoresPath(), sourceType);
    ShifuFileUtils.deleteFile(pathFinder.getBinAvgScorePath(), sourceType);
    // prepare special parameters and execute pig
    Map<String, String> paramsMap = new HashMap<String, String>();
    paramsMap.put("pathHeader", modelConfig.getHeaderPath());
    paramsMap.put("pathDelimiter", CommonUtils.escapePigString(modelConfig.getHeaderDelimiter()));
    paramsMap.put("delimiter", CommonUtils.escapePigString(modelConfig.getDataSetDelimiter()));
    try {
        PigExecutor.getExecutor().submitJob(modelConfig, pathFinder.getScriptPath("scripts/PostTrain.pig"), paramsMap);
    } catch (IOException e) {
        throw new ShifuException(ShifuErrorCode.ERROR_RUNNING_PIG_JOB, e);
    } catch (RuntimeException e) {
        // propagate unchecked exceptions (including ShifuException) unchanged so callers can still match on type
        throw e;
    } catch (Exception e) {
        // only checked, non-IO exceptions are wrapped; JVM Errors are no longer swallowed into a RuntimeException
        throw new RuntimeException(e);
    }
    // Sync Down
    columnConfigList = updateColumnConfigWithBinAvgScore(columnConfigList);
    saveColumnConfigList();
}
Also used : HashMap(java.util.HashMap) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) IOException(java.io.IOException) ShifuException(ml.shifu.shifu.exception.ShifuException)

Example 15 with ShifuException

use of ml.shifu.shifu.exception.ShifuException in project shifu by ShifuML.

Class ShifuCLI, method main.

/**
 * Command-line entry point of Shifu.
 * <p>
 * Cleans and parses {@code args}, dispatches on the first argument (new, cp, init, stats, norm, varsel, train,
 * encode, combo, posttrain, save, switch, show, eval, export, test, convert, analysis) and terminates the JVM
 * through {@link System#exit(int)} with the resulting status. Failures go through {@code exceptionExit}.
 *
 * @param args
 *            raw command-line arguments
 */
public static void main(String[] args) {
    String[] cleanedArgs = cleanArgs(args);
    // invalid input and help options
    if (cleanedArgs.length < 1 || (isHelpOption(cleanedArgs[0]))) {
        printUsage();
        System.exit(cleanedArgs.length < 1 ? -1 : 0);
    }
    // process -v and -version conditions manually
    if (isVersionOption(cleanedArgs[0])) {
        printLogoAndVersion();
        System.exit(0);
    }
    CommandLineParser parser = new GnuParser();
    Options opts = buildModelSetOptions();
    CommandLine cmd = null;
    try {
        cmd = parser.parse(opts, cleanedArgs);
    } catch (ParseException e) {
        log.error("Invalid command options. Please check help message.");
        printUsage();
        System.exit(1);
    }
    int status = 0;
    try {
        if (cleanedArgs[0].equals(NEW) && cleanedArgs.length >= 2 && StringUtils.isNotEmpty(cleanedArgs[1])) {
            // modelset step
            String modelName = cleanedArgs[1];
            status = createNewModel(modelName, cmd.getOptionValue(MODELSET_CMD_TYPE), cmd.getOptionValue(MODELSET_CMD_M));
            if (status == 0) {
                printModelSetCreatedSuccessfulLog(modelName);
            } else {
                log.warn("Error in create new model set, please check your shifu config or report issue");
            }
            System.exit(status);
        } else {
            if (cleanedArgs[0].equals(MODELSET_CMD_CP) && cleanedArgs.length >= 3 && StringUtils.isNotEmpty(cleanedArgs[1]) && StringUtils.isNotEmpty(cleanedArgs[2])) {
                String newModelSetName = cleanedArgs[2];
                // modelset step
                copyModel(new String[] { cleanedArgs[1], newModelSetName });
                printModelSetCopiedSuccessfulLog(newModelSetName);
            } else if (cleanedArgs[0].equals(INIT_CMD)) {
                // init step
                if (cmd.getOptions() == null || cmd.getOptions().length == 0) {
                    status = initializeModel();
                    if (status == 0) {
                        log.info("ModelSet initialization is successful. Please continue next step by using 'shifu stats'.");
                    } else {
                        log.warn("Error in ModelSet initialization, please check your shifu config or report issue");
                    }
                } else if (cmd.hasOption(INIT_CMD_MODEL)) {
                    initializeModelParam();
                } else {
                    log.error("Invalid command, please check help message.");
                    printUsage();
                }
            } else if (cleanedArgs[0].equals(STATS_CMD)) {
                boolean isCorr = cmd.hasOption(CORRELATION) || cmd.hasOption(SHORT_CORRELATION);
                boolean isPsi = cmd.hasOption(PSI) || cmd.hasOption(SHORT_PSI);
                Map<String, Object> params = new HashMap<String, Object>();
                params.put(Constants.IS_COMPUTE_CORR, isCorr);
                params.put(Constants.IS_REBIN, cmd.hasOption(REBIN));
                params.put(Constants.REQUEST_VARS, cmd.getOptionValue(VARS));
                params.put(Constants.EXPECTED_BIN_NUM, cmd.getOptionValue(N));
                params.put(Constants.IV_KEEP_RATIO, cmd.getOptionValue(IVR));
                params.put(Constants.MINIMUM_BIN_INST_CNT, cmd.getOptionValue(BIC));
                params.put(Constants.IS_COMPUTE_PSI, isPsi);
                // stats step
                status = calModelStats(params);
                if (status == 0) {
                    if (isCorr) {
                        log.info("Do model set correlation computing successfully. Please continue next step by using 'shifu normalize or shifu norm'. For tree ensemble model, no need do norm, please continue next step by using 'shifu varsel'");
                    }
                    if (isPsi) {
                        log.info("Do model set psi computing successfully. Please continue next step by using 'shifu normalize or shifu norm'. For tree ensemble model, no need do norm, please continue next step by using 'shifu varsel'");
                    }
                    // plain statistics message only when neither correlation nor psi was requested
                    if (!isCorr && !isPsi) {
                        log.info("Do model set statistic successfully. Please continue next step by using 'shifu normalize or shifu norm or shifu transform'. For tree ensemble model, no need do norm, please continue next step by using 'shifu varsel'");
                    }
                } else {
                    log.warn("Error in model set stats computation, please report issue on http://github.com/shifuml/shifu/issues.");
                }
            } else if (cleanedArgs[0].equals(NORMALIZE_CMD) || cleanedArgs[0].equals(NORM_CMD) || cleanedArgs[0].equals(TRANSFORM_CMD)) {
                // normalize step
                Map<String, Object> params = new HashMap<String, Object>();
                params.put(Constants.IS_TO_SHUFFLE_DATA, cmd.hasOption(SHUFFLE));
                status = normalizeTrainData(params);
                if (status == 0) {
                    log.info("Do model set normalization successfully. Please continue next step by using 'shifu varselect or shifu varsel'.");
                } else {
                    log.warn("Error in model set normalization, please report issue on http://github.com/shifuml/shifu/issues.");
                }
            } else if (cleanedArgs[0].equals(VARSELECT_CMD) || cleanedArgs[0].equals(VARSEL_CMD)) {
                Map<String, Object> params = new HashMap<String, Object>();
                params.put(Constants.IS_TO_RESET, cmd.hasOption(RESET));
                params.put(Constants.IS_TO_LIST, cmd.hasOption(LIST));
                params.put(Constants.IS_TO_FILTER_AUTO, cmd.hasOption(FILTER_AUTO));
                params.put(Constants.IS_TO_RECOVER_AUTO, cmd.hasOption(RECOVER_AUTO));
                params.put(Constants.RECURSIVE_CNT, cmd.getOptionValue(RECURSIVE));
                // variable selected step
                status = selectModelVar(params);
                if (status == 0) {
                    log.info("Do model set variables selection successfully. Please continue next step by using 'shifu train'.");
                } else {
                    log.info("Do variable selection with error, please check error message or report issue.");
                }
            } else if (cleanedArgs[0].equals(TRAIN_CMD)) {
                // train step
                status = trainModel(cmd.hasOption(TRAIN_CMD_DRY), cmd.hasOption(TRAIN_CMD_DEBUG), cmd.hasOption(SHUFFLE));
                if (status == 0) {
                    log.info("Do model set training successfully. Please continue next step by using 'shifu posttrain' or if no need posttrain you can go through with 'shifu eval'.");
                } else {
                    log.info("Do model training with error, please check error message or report issue.");
                }
            } else if (cleanedArgs[0].equals(CMD_ENCODE)) {
                Map<String, Object> params = new HashMap<String, Object>();
                params.put(ModelDataEncodeProcessor.ENCODE_DATA_SET, cmd.getOptionValue(EVAL_CMD_RUN));
                params.put(ModelDataEncodeProcessor.ENCODE_REF_MODEL, cmd.getOptionValue(REF));
                status = runEncode(params);
            } else if (cleanedArgs[0].equals(CMD_COMBO)) {
                if (cmd.hasOption(MODELSET_CMD_NEW)) {
                    log.info("Create new commbo models");
                    status = createNewCombo(cmd.getOptionValue(MODELSET_CMD_NEW));
                } else if (cmd.hasOption(INIT_CMD)) {
                    log.info("Init commbo models");
                    status = initComboModels();
                } else if (cmd.hasOption(EVAL_CMD_RUN)) {
                    // train combo models; log what the user actually passed (cmd), not what is defined (opts)
                    log.info("Run combo model - with toShuffle: {}, with toResume: {}", cmd.hasOption(SHUFFLE), cmd.hasOption(RESUME));
                    status = runComboModels(cmd.hasOption(SHUFFLE), cmd.hasOption(RESUME));
                } else if (cmd.hasOption(EVAL_CMD)) {
                    log.info("Eval combo model.");
                    // eval combo model performance
                    status = evalComboModels(cmd.hasOption(RESUME));
                } else {
                    log.error("Invalid command usage.");
                    printUsage();
                }
            } else if (cleanedArgs[0].equals(POSTTRAIN_CMD)) {
                // post train step
                status = postTrainModel();
                if (status == 0) {
                    log.info("Do model set post-training successfully. Please configure your eval set in ModelConfig.json and continue next step by using 'shifu eval' or 'shifu eval -new <eval set>' to create a new eval set.");
                } else {
                    log.info("Do model post training with error, please check error message or report issue.");
                }
            } else if (cleanedArgs[0].equals(SAVE)) {
                String newModelSetName = cleanedArgs.length >= 2 ? cleanedArgs[1] : null;
                saveCurrentModel(newModelSetName);
            } else if (cleanedArgs[0].equals(SWITCH)) {
                // 'switch' needs a target model set name; report usage instead of failing with an index error
                if (cleanedArgs.length < 2) {
                    log.error("Invalid command, please check help message.");
                    printUsage();
                } else {
                    String newModelSetName = cleanedArgs[1];
                    switchCurrentModel(newModelSetName);
                }
            } else if (cleanedArgs[0].equals(SHOW)) {
                ManageModelProcessor p = new ManageModelProcessor(ModelAction.SHOW, null);
                p.run();
            } else if (cleanedArgs[0].equals(EVAL_CMD)) {
                Map<String, Object> params = new HashMap<String, Object>();
                // eval step
                if (cleanedArgs.length == 1) {
                    // run everything
                    status = runEvalSet(cmd.hasOption(TRAIN_CMD_DRY));
                    if (status == 0) {
                        log.info("Run eval performance with all eval sets successfully.");
                    } else {
                        log.info("Do evaluation with error, please check error message or report issue.");
                    }
                } else if (cmd.getOptionValue(MODELSET_CMD_NEW) != null) {
                    // create new eval
                    createNewEvalSet(cmd.getOptionValue(MODELSET_CMD_NEW));
                    log.info("Create eval set successfully. You can configure EvalConfig.json or directly run 'shifu eval -run <evalSetName>' to get performance info.");
                } else if (cmd.hasOption(EVAL_CMD_RUN)) {
                    runEvalSet(cmd.getOptionValue(EVAL_CMD_RUN), cmd.hasOption(TRAIN_CMD_DRY));
                    log.info("Finish run eval performance with eval set {}.", cmd.getOptionValue(EVAL_CMD_RUN));
                } else if (cmd.hasOption(SCORE)) {
                    params.put(EvalModelProcessor.NOSORT, cmd.hasOption(NOSORT));
                    // run score
                    runEvalScore(cmd.getOptionValue(SCORE), params);
                    log.info("Finish run score with eval set {}.", cmd.getOptionValue(SCORE));
                } else if (cmd.hasOption(CONFMAT)) {
                    // run confusion matrix
                    runEvalConfMat(cmd.getOptionValue(CONFMAT));
                    log.info("Finish run confusion matrix with eval set {}.", cmd.getOptionValue(CONFMAT));
                } else if (cmd.hasOption(PERF)) {
                    // run performance
                    runEvalPerf(cmd.getOptionValue(PERF));
                    log.info("Finish run performance maxtrix with eval set {}.", cmd.getOptionValue(PERF));
                } else if (cmd.hasOption(LIST)) {
                    // list all evaluation sets
                    listEvalSet();
                } else if (cmd.hasOption(DELETE)) {
                    // delete some evaluation set
                    deleteEvalSet(cmd.getOptionValue(DELETE));
                } else if (cmd.hasOption(NORM)) {
                    runEvalNorm(cmd.getOptionValue(NORM), cmd.hasOption(STRICT));
                } else {
                    log.error("Invalid command, please check help message.");
                    printUsage();
                }
            } else if (cleanedArgs[0].equals(CMD_EXPORT)) {
                Map<String, Object> params = new HashMap<String, Object>();
                params.put(ExportModelProcessor.IS_CONCISE, cmd.hasOption(EXPORT_CONCISE));
                params.put(ExportModelProcessor.REQUEST_VARS, cmd.getOptionValue(VARS));
                params.put(ExportModelProcessor.EXPECTED_BIN_NUM, cmd.getOptionValue(N));
                params.put(ExportModelProcessor.IV_KEEP_RATIO, cmd.getOptionValue(IVR));
                params.put(ExportModelProcessor.MINIMUM_BIN_INST_CNT, cmd.getOptionValue(BIC));
                status = exportModel(cmd.getOptionValue(MODELSET_CMD_TYPE), params);
                if (status == 0) {
                    log.info("Export models/columnstats/corr successfully.");
                } else {
                    log.warn("Fail to export models/columnstats/corr, please check or report issue.");
                }
            } else if (cleanedArgs[0].equals(CMD_TEST)) {
                Map<String, Object> params = new HashMap<String, Object>();
                params.put(ShifuTestProcessor.IS_TO_TEST_FILTER, cmd.hasOption(FILTER));
                params.put(ShifuTestProcessor.TEST_TARGET, cmd.getOptionValue(FILTER));
                params.put(ShifuTestProcessor.TEST_RECORD_CNT, cmd.getOptionValue(N));
                status = runShifuTest(params);
                if (status == 0) {
                    log.info("Run test for Shifu Successfully.");
                } else {
                    log.warn("Fail to run Shifu test.");
                }
            } else if (cleanedArgs[0].equals(CMD_CONVERT)) {
                int optType = -1;
                if (cmd.hasOption(TO_ZIPB)) {
                    optType = 1;
                } else if (cmd.hasOption(TO_TREEB)) {
                    optType = 2;
                }
                String[] convertArgs = new String[2];
                int j = 0;
                for (int i = 1; i < cleanedArgs.length; i++) {
                    if (!cleanedArgs[i].startsWith("-")) {
                        convertArgs[j++] = cleanedArgs[i];
                    }
                }
                if (optType < 0 || StringUtils.isBlank(convertArgs[0]) || StringUtils.isBlank(convertArgs[1])) {
                    printUsage();
                } else {
                    status = runShifuConvert(optType, convertArgs[0], convertArgs[1]);
                }
            } else if (cleanedArgs[0].equals(CMD_ANALYSIS)) {
                if (cmd.hasOption(FI)) {
                    String modelPath = cmd.getOptionValue(FI);
                    analysisModelFi(modelPath);
                }
            } else {
                log.error("Invalid command, please check help message.");
                printUsage();
            }
        }
        // for some case jvm cannot stop
        System.exit(status);
    } catch (ShifuException e) {
        // need define error code in each step.
        // log the exception itself (its cause chain is included), not just e.getCause(), which may be null
        log.error(e.getError().toString() + "; msg: " + e.getMessage(), e);
        exceptionExit(e);
    } catch (Exception e) {
        exceptionExit(e);
    }
}
Also used: Options(org.apache.commons.cli.Options) GnuParser(org.apache.commons.cli.GnuParser) IOException(java.io.IOException) ShifuException(ml.shifu.shifu.exception.ShifuException) ParseException(org.apache.commons.cli.ParseException) CommandLine(org.apache.commons.cli.CommandLine) ManageModelProcessor(ml.shifu.shifu.core.processor.ManageModelProcessor) CommandLineParser(org.apache.commons.cli.CommandLineParser)

Aggregations

ShifuException (ml.shifu.shifu.exception.ShifuException)39 IOException (java.io.IOException)22 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)12 HashMap (java.util.HashMap)8 ArrayList (java.util.ArrayList)5 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)5 File (java.io.File)4 Scanner (java.util.Scanner)4 Path (org.apache.hadoop.fs.Path)4 SourceFile (ml.shifu.shifu.fs.SourceFile)3 JobStats (org.apache.pig.tools.pigstats.JobStats)3 BufferedReader (java.io.BufferedReader)2 ConfusionMatrixObject (ml.shifu.shifu.container.ConfusionMatrixObject)2 EvalConfig (ml.shifu.shifu.container.obj.EvalConfig)2 RawSourceData (ml.shifu.shifu.container.obj.RawSourceData)2 AbstractStatsExecutor (ml.shifu.shifu.core.processor.stats.AbstractStatsExecutor)2 AkkaStatsWorker (ml.shifu.shifu.core.processor.stats.AkkaStatsWorker)2 DIBStatsExecutor (ml.shifu.shifu.core.processor.stats.DIBStatsExecutor)2 MunroPatIStatsExecutor (ml.shifu.shifu.core.processor.stats.MunroPatIStatsExecutor)2 MunroPatStatsExecutor (ml.shifu.shifu.core.processor.stats.MunroPatStatsExecutor)2