Search in sources :

Example 96 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class TrainModelProcessor method checkAndCleanDataForTreeModels.

/**
 * For RF/GBT model, no need do normalizing, but clean and filter data is needed. Before real training, we have to
 * clean and filter data.
 *
 * <p>
 * The method first validates the binning info of every final-selected, non-target, non-meta column, then runs the
 * data cleaning step unless cleaned data already exists and regeneration is not requested.
 *
 * @param isToShuffle
 *            if shuffle data before training
 * @throws IOException
 *             the io exception
 * @throws IllegalArgumentException
 *             if a final-selected numerical column has no binBoundary, or a final-selected categorical column has
 *             no binCategory
 */
protected void checkAndCleanDataForTreeModels(boolean isToShuffle) throws IOException {
    String alg = this.getModelConfig().getTrain().getAlgorithm();
    // only for tree models
    if (!CommonUtils.isTreeModel(alg)) {
        return;
    }
    // check if binBoundaries and binCategories are good and log error
    for (ColumnConfig columnConfig : columnConfigList) {
        boolean isModelInput = columnConfig.isFinalSelect() && !columnConfig.isTarget() && !columnConfig.isMeta();
        if (!isModelInput) {
            continue;
        }
        if (columnConfig.isNumerical()) {
            if (columnConfig.getBinBoundary() == null) {
                // note the leading space in " column": without it the column name runs into the next word
                throw new IllegalArgumentException("Final select " + columnConfig.getColumnName() + " column but binBoundary in ColumnConfig.json is null.");
            }
            if (columnConfig.getBinBoundary().size() <= 1) {
                LOG.warn("Column {} {} with only one or zero element in binBounday, such column will be ignored in tree model training.", columnConfig.getColumnNum(), columnConfig.getColumnName());
            }
        }
        if (columnConfig.isCategorical()) {
            if (columnConfig.getBinCategory() == null) {
                throw new IllegalArgumentException("Final select " + columnConfig.getColumnName() + " column but binCategory in ColumnConfig.json is null.");
            }
            if (columnConfig.getBinCategory().size() <= 0) {
                LOG.warn("Column {} {} with only zero element in binCategory, such column will be ignored in tree model training.", columnConfig.getColumnNum(), columnConfig.getColumnName());
            }
        }
    }
    // run cleaning data logic for model input
    SourceType sourceType = modelConfig.getDataSet().getSource();
    String cleanedDataPath = this.pathFinder.getCleanedDataPath();
    String needReGen = Environment.getProperty("shifu.tree.regeninput", Boolean.FALSE.toString());
    // Regenerate when explicitly requested, when cleaned training data is missing, or when a validation set is
    // configured but its cleaned copy is missing. Kept as one short-circuit expression so the file-existence
    // checks are only performed when the earlier conditions do not already decide the outcome.
    if (Boolean.TRUE.toString().equalsIgnoreCase(needReGen)
            || !ShifuFileUtils.isFileExists(cleanedDataPath, sourceType)
            || (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath())
                    && !ShifuFileUtils.isFileExists(pathFinder.getCleanedValidationDataPath(), sourceType))) {
        runDataClean(isToShuffle);
    } else {
        // no need regen data
        LOG.warn("For RF/GBT, training input in {} exists, no need to regenerate it.", cleanedDataPath);
        LOG.warn("Need regen it, please set shifu.tree.regeninput in shifuconfig to true.");
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType)

Example 97 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class InitModelProcessor method setCategoricalColumnsAndDistinctAccount.

/**
 * Walks all column configs and, for each column that has an entry with a non-null distinct count in
 * {@code distinctCountMap}, optionally assigns an auto-detected column type and/or records the distinct count.
 *
 * @param distinctCountMap
 *            per-column data looked up by column number
 * @param cateOn
 *            if true, the column type is (re)assigned to numeric or categorical
 * @param distinctOn
 *            if true, the distinct count is written into the column stats
 * @return number of columns set to categorical type by this call
 */
private int setCategoricalColumnsAndDistinctAccount(Map<Integer, Data> distinctCountMap, boolean cateOn, boolean distinctOn) {
    int categoricalTotal = 0;
    for (ColumnConfig column : columnConfigList) {
        Data columnData = distinctCountMap.get(column.getColumnNum());
        if (columnData == null) {
            // nothing known about this column
            continue;
        }
        Long distinctCount = columnData.distinctCount;
        // disable auto type threshold
        if (distinctCount == null) {
            continue;
        }
        if (cateOn && assignAutoDetectedColumnType(column, distinctCount, columnData.items)) {
            categoricalTotal++;
        }
        if (distinctOn) {
            column.getColumnStats().setDistinctCount(distinctCount);
        }
    }
    return categoricalTotal;
}

/**
 * Sets the type of one column to numeric or categorical and logs the decision.
 *
 * @return true if the column was set to categorical type, false if it was set to numeric type
 */
private boolean assignAutoDetectedColumnType(ColumnConfig column, Long distinctCount, String[] items) {
    if (isBinaryVariable(distinctCount, items)) {
        log.info("Column {} with index {} is set to numeric type because of 0-1 variable. Distinct count {}, items {}.", column.getColumnName(), column.getColumnNum(), distinctCount, Arrays.toString(items));
        column.setColumnType(ColumnType.N);
        return false;
    }
    if (isDoubleFrequentVariable(items)) {
        log.info("Column {} with index {} is set to numeric type because of all sampled items are double(including blank). Distinct count {}.", column.getColumnName(), column.getColumnNum(), distinctCount);
        column.setColumnType(ColumnType.N);
        return false;
    }
    column.setColumnType(ColumnType.C);
    log.info("Column {} with index {} is set to categorical type according to auto type checking: distinct count {}.", column.getColumnName(), column.getColumnNum(), distinctCount);
    return true;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 98 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class InitModelProcessor method initColumnConfigList.

/**
 * initialize the columnConfig file
 *
 * <p>
 * Column names come from the header file when a header path is configured; otherwise the first line of the data
 * set is read. If that first line contains the target column name it is treated as a header, else columns are
 * named by their index. The resulting list is stored in {@code columnConfigList} and its flags are updated for
 * the INIT step.
 *
 * @return 0 if a target column is found after flag update, 1 otherwise
 * @throws IOException
 *             declared for the header/data reading and flag update calls
 * @throws IllegalArgumentException
 *             if the number of header fields differs from the number of fields in the first data line
 */
private int initColumnConfigList() throws IOException {
    String[] fields = null;
    boolean isSchemaProvided = true;
    if (StringUtils.isNotBlank(modelConfig.getHeaderPath())) {
        fields = CommonUtils.getHeaders(modelConfig.getHeaderPath(), modelConfig.getHeaderDelimiter(), modelConfig.getDataSet().getSource());
        String[] dataInFirstLine = CommonUtils.takeFirstLine(modelConfig.getDataSetRawPath(), modelConfig.getDataSetDelimiter(), modelConfig.getDataSet().getSource());
        if (fields.length != dataInFirstLine.length) {
            throw new IllegalArgumentException("Header length and data length are not consistent - head length " + fields.length + ", while data length " + dataInFirstLine.length + ", please check you header setting and data set setting.");
        }
    } else {
        // without a header file, split by the header delimiter if one is set, else by the data delimiter
        String firstLineDelimiter = StringUtils.isBlank(modelConfig.getHeaderDelimiter()) ? modelConfig.getDataSetDelimiter() : modelConfig.getHeaderDelimiter();
        fields = CommonUtils.takeFirstLine(modelConfig.getDataSetRawPath(), firstLineDelimiter, modelConfig.getDataSet().getSource());
        if (StringUtils.join(fields, "").contains(modelConfig.getTargetColumnName())) {
            // if first line contains target column name, we guess it is csv format and first line is header.
            isSchemaProvided = true;
            // first line of data meaning second line in data files excluding first header line
            String[] dataInFirstLine = CommonUtils.takeFirstTwoLines(modelConfig.getDataSetRawPath(), firstLineDelimiter, modelConfig.getDataSet().getSource())[1];
            if (dataInFirstLine != null && fields.length != dataInFirstLine.length) {
                throw new IllegalArgumentException("Header length and data length are not consistent, please check you header setting and data set setting.");
            }
            log.warn("No header path is provided, we will try to read first line and detect schema.");
            log.warn("Schema in ColumnConfig.json are named as first line of data set path.");
        } else {
            isSchemaProvided = false;
            log.warn("No header path is provided, we will try to read first line and detect schema.");
            log.warn("Schema in ColumnConfig.json are named as  index 0, 1, 2, 3 ...");
            log.warn("Please make sure weight column and tag column are also taking index as name.");
        }
    }
    columnConfigList = new ArrayList<ColumnConfig>();
    for (int i = 0; i < fields.length; i++) {
        ColumnConfig config = new ColumnConfig();
        config.setColumnNum(i);
        fields[i] = CommonUtils.normColumnName(fields[i]);
        // fall back to the positional index as name when no schema could be detected
        config.setColumnName(isSchemaProvided ? fields[i] : i + "");
        columnConfigList.add(config);
    }
    ColumnConfigUpdater.updateColumnConfigFlags(modelConfig, columnConfigList, ModelStep.INIT);
    boolean hasTarget = false;
    for (ColumnConfig config : columnConfigList) {
        if (config.isTarget()) {
            hasTarget = true;
            break;
        }
    }
    if (!hasTarget) {
        log.error("Target is not valid: " + modelConfig.getTargetColumnName());
        // Point the user at the place the column names were actually read from: the header file when a header
        // path is configured, otherwise the first line of the data set.
        if (StringUtils.isNotBlank(modelConfig.getHeaderPath())) {
            log.error("Please check your header file {} and your header delimiter {}", modelConfig.getHeaderPath(), modelConfig.getHeaderDelimiter());
        } else {
            log.error("Please check your first line of data set file {}", modelConfig.getDataSetRawPath());
        }
        return 1;
    }
    return 0;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 99 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class InitModelProcessor method setCategoricalColumnsByCountInfo.

/**
 * Walks all column configs and, for each column with an entry in {@code distinctCountMap}, optionally records
 * its distinct count and, when the data set's auto type threshold is positive, assigns numeric or categorical
 * type to columns that are not already categorical.
 *
 * @param distinctCountMap
 *            per-column data looked up by column number
 * @param distinctOn
 *            if true, the distinct count is written into the column stats
 * @return number of columns set to categorical type by this call
 */
private int setCategoricalColumnsByCountInfo(Map<Integer, Data> distinctCountMap, boolean distinctOn) {
    int categoricalTotal = 0;
    for (ColumnConfig column : columnConfigList) {
        Data columnData = distinctCountMap.get(column.getColumnNum());
        if (columnData == null) {
            continue;
        }

        long distinctCount = columnData.distinctCount;
        if (distinctOn) {
            column.getColumnStats().setDistinctCount(distinctCount);
        }

        // only update categorical feature when autoTypeThreshold > 0, by default it is 0
        if (!(modelConfig.getDataSet().getAutoTypeThreshold() > 0)) {
            continue;
        }

        long total = columnData.count;
        long invalid = columnData.invalidCount;
        long validNumeric = columnData.validNumcount;
        // share of valid numeric values among the values not counted as invalid
        double numericShare = validNumeric * 1d / (total - invalid);

        // columns already categorical are left untouched
        if (column.isCategorical()) {
            continue;
        }

        boolean enoughNumericValues = numericShare > modelConfig.getDataSet().getAutoTypeThreshold() / 100d;
        if (enoughNumericValues) {
            column.setColumnType(ColumnType.N);
            log.info("Column {} with index {} is set to numeric type because of enough double values.", column.getColumnName(), column.getColumnNum());
        } else {
            categoricalTotal++;
            column.setColumnType(ColumnType.C);
            log.info("Column {} with index {} is set to categorical type because of not enough double values.", column.getColumnName(), column.getColumnNum());
        }
    }
    return categoricalTotal;
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig)

Example 100 with ColumnConfig

use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.

the class ExportModelProcessor method run.

/**
 * Runs the export step: dispatches on the requested export {@code type} (defaulting to PMML when blank) and
 * writes the corresponding artifact(s).
 *
 * @return 0 on success; -1 if the export type is not supported; 2 if correlation export is requested but the
 *         local correlation file does not exist; for correlation export otherwise, the value returned by
 *         {@code exportVariableCorr()}
 * @throws Exception
 *             propagated from set up, model loading, serialization or file writing
 * @see ml.shifu.shifu.core.processor.Processor#run()
 */
@Override
public int run() throws Exception {
    setUp(ModelStep.EXPORT);
    int status = 0;
    File pmmls = new File("pmmls");
    FileUtils.forceMkdir(pmmls);
    if (StringUtils.isBlank(type)) {
        type = PMML;
    }
    String modelsPath = pathFinder.getModelsPath(SourceType.LOCAL);
    if (type.equalsIgnoreCase(ONE_BAGGING_MODEL)) {
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm()) && !CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging model is only supported in NN/GBT/RF algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            if (models.size() < 1) {
                log.warn("No model is found in {}.", modelsPath);
            } else {
                log.info("Convert nn models into one binary bagging model.");
                Configuration conf = new Configuration();
                Path output = new Path(pathFinder.getBaggingModelPath(SourceType.LOCAL), "model.b" + modelConfig.getAlgorithm());
                if ("nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    BinaryNNSerializer.save(modelConfig, columnConfigList, models, FileSystem.getLocal(conf), output);
                } else if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    List<List<TreeNode>> baggingTrees = new ArrayList<List<TreeNode>>();
                    for (int i = 0; i < models.size(); i++) {
                        TreeModel tm = (TreeModel) models.get(i);
                        // TreeModel only has one TreeNode instance although it is list inside
                        baggingTrees.add(tm.getIndependentTreeModel().getTrees().get(0));
                    }
                    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
                    // numerical + categorical = # of all input
                    int inputCount = inputOutputIndex[0] + inputOutputIndex[1];
                    BinaryDTSerializer.save(modelConfig, columnConfigList, baggingTrees, modelConfig.getParams().get("Loss").toString(), inputCount, FileSystem.getLocal(conf), output);
                }
                log.info("Please find one unified bagging model in local {}.", output);
            }
        }
    } else if (type.equalsIgnoreCase(PMML)) {
        // typical pmml generation: one pmml file per model
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), false);
        for (int index = 0; index < models.size(); index++) {
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + Integer.toString(index) + ".pmml";
            log.info("\t Start to generate " + path);
            PMML pmml = translator.build(Arrays.asList(new BasicML[] { models.get(index) }));
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(ONE_BAGGING_PMML_MODEL)) {
        // one unified bagging pmml generation
        log.info("Convert models into one bagging pmml model {} format", type);
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging pmml model is only supported in NN algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), true);
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + ".pmml";
            log.info("\t Start to generate one unified model to: " + path);
            PMML pmml = translator.build(models);
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(COLUMN_STATS)) {
        saveColumnStatus();
    } else if (type.equalsIgnoreCase(WOE_MAPPING)) {
        // export all columns when no variables are requested, otherwise only the requested ones
        List<ColumnConfig> exportCatColumns = new ArrayList<ColumnConfig>();
        List<String> catVariables = getRequestVars();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (CollectionUtils.isEmpty(catVariables) || isRequestColumn(catVariables, columnConfig)) {
                exportCatColumns.add(columnConfig);
            }
        }
        if (CollectionUtils.isNotEmpty(exportCatColumns)) {
            List<String> woeMappings = new ArrayList<String>();
            for (ColumnConfig columnConfig : exportCatColumns) {
                String woeMapText = rebinAndExportWoeMapping(columnConfig);
                woeMappings.add(woeMapText);
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(woeMappings, ",\n"));
        }
    } else if (type.equalsIgnoreCase(WOE)) {
        List<String> woeInfos = new ArrayList<String>();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (columnConfig.getBinLength() > 1 && ((columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) || (columnConfig.isNumerical() && CollectionUtils.isNotEmpty(columnConfig.getBinBoundary()) && columnConfig.getBinBoundary().size() > 1))) {
                List<String> varWoeInfos = generateWoeInfos(columnConfig);
                if (CollectionUtils.isNotEmpty(varWoeInfos)) {
                    woeInfos.addAll(varWoeInfos);
                    // blank line separates the woe info of consecutive variables
                    woeInfos.add("");
                }
            }
        }
        // write once after all columns are collected instead of rewriting the file on every iteration
        FileUtils.writeLines(new File("varwoe_info.txt"), woeInfos);
    } else if (type.equalsIgnoreCase(CORRELATION)) {
        // export correlation into mapping list; fall through to clearUp like the other branches instead of
        // returning early
        if (!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
            log.warn("The correlation file doesn't exist. Please make sure you have ran `shifu stats -c`.");
            status = 2;
        } else {
            status = exportVariableCorr();
        }
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }
    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}
Also used : Path(org.apache.hadoop.fs.Path) Configuration(org.apache.hadoop.conf.Configuration) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BasicML(org.encog.ml.BasicML) PMMLTranslator(ml.shifu.shifu.core.pmml.PMMLTranslator) TreeModel(ml.shifu.shifu.core.TreeModel) TreeNode(ml.shifu.shifu.core.dtrain.dt.TreeNode) PMML(org.dmg.pmml.PMML) File(java.io.File)

Aggregations

ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)131 ArrayList (java.util.ArrayList)36 Test (org.testng.annotations.Test)17 IOException (java.io.IOException)16 HashMap (java.util.HashMap)12 Tuple (org.apache.pig.data.Tuple)10 File (java.io.File)8 NSColumn (ml.shifu.shifu.column.NSColumn)8 ModelConfig (ml.shifu.shifu.container.obj.ModelConfig)8 ShifuException (ml.shifu.shifu.exception.ShifuException)8 Path (org.apache.hadoop.fs.Path)8 List (java.util.List)7 Scanner (java.util.Scanner)7 DataBag (org.apache.pig.data.DataBag)7 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)5 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)5 TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet)5 BasicMLData (org.encog.ml.data.basic.BasicMLData)5 BufferedWriter (java.io.BufferedWriter)3 FileInputStream (java.io.FileInputStream)3