use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class TrainModelProcessor method checkAndCleanDataForTreeModels.
/**
 * For RF/GBT models no normalization is needed, but the data still has to be cleaned and filtered before the real
 * training starts. This method validates the binning information of every final-selected input column and then
 * (re)generates the cleaned training input when required.
 *
 * <p>
 * The cleaned data is regenerated when any of the following holds:
 * <ul>
 * <li>the {@code shifu.tree.regeninput} property is {@code true} (case-insensitive);</li>
 * <li>the cleaned training data path does not exist;</li>
 * <li>a validation data set path is configured and the cleaned validation data path does not exist.</li>
 * </ul>
 * For non-tree algorithms the method returns immediately without doing anything.
 *
 * @param isToShuffle
 *            if shuffle data before training
 * @throws IOException
 *             the io exception
 * @throws IllegalArgumentException
 *             if a final-selected numerical column has no binBoundary, or a final-selected categorical column has
 *             no binCategory
 */
protected void checkAndCleanDataForTreeModels(boolean isToShuffle) throws IOException {
    String alg = this.getModelConfig().getTrain().getAlgorithm();
    // only for tree models
    if (!CommonUtils.isTreeModel(alg)) {
        return;
    }

    // check if binBoundaries and binCategories are good and log error
    for (ColumnConfig columnConfig : columnConfigList) {
        // only final-selected, non-target, non-meta columns are model inputs worth validating
        if (!columnConfig.isFinalSelect() || columnConfig.isTarget() || columnConfig.isMeta()) {
            continue;
        }

        if (columnConfig.isNumerical()) {
            if (columnConfig.getBinBoundary() == null) {
                // a space is required before "column", otherwise the column name runs into the next word
                throw new IllegalArgumentException("Final select " + columnConfig.getColumnName() + " column but binBoundary in ColumnConfig.json is null.");
            }
            if (columnConfig.getBinBoundary().size() <= 1) {
                LOG.warn("Column {} {} with only one or zero element in binBoundary, such column will be ignored in tree model training.", columnConfig.getColumnNum(), columnConfig.getColumnName());
            }
        }

        if (columnConfig.isCategorical()) {
            if (columnConfig.getBinCategory() == null) {
                throw new IllegalArgumentException("Final select " + columnConfig.getColumnName() + " column but binCategory in ColumnConfig.json is null.");
            }
            if (columnConfig.getBinCategory().size() <= 0) {
                LOG.warn("Column {} {} with only zero element in binCategory, such column will be ignored in tree model training.", columnConfig.getColumnNum(), columnConfig.getColumnName());
            }
        }
    }

    // run cleaning data logic for model input
    SourceType sourceType = modelConfig.getDataSet().getSource();
    String cleanedDataPath = this.pathFinder.getCleanedDataPath();
    String needReGen = Environment.getProperty("shifu.tree.regeninput", Boolean.FALSE.toString());

    // tree ensemble model training
    if (Boolean.TRUE.toString().equalsIgnoreCase(needReGen) || !ShifuFileUtils.isFileExists(cleanedDataPath, sourceType) || (StringUtils.isNotBlank(modelConfig.getValidationDataSetRawPath()) && !ShifuFileUtils.isFileExists(pathFinder.getCleanedValidationDataPath(), sourceType))) {
        runDataClean(isToShuffle);
    } else {
        // no need regen data
        LOG.warn("For RF/GBT, training input in {} exists, no need to regenerate it.", cleanedDataPath);
        LOG.warn("Need regen it, please set shifu.tree.regeninput in shifuconfig to true.");
    }
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class InitModelProcessor method setCategoricalColumnsAndDistinctAccount.
/**
 * Walks over {@code columnConfigList} and, for every column that has an entry with a non-null distinct count in
 * {@code distinctCountMap}, optionally auto-detects the column type and optionally stores the distinct count.
 *
 * <p>
 * When {@code cateOn} is set, a column is marked numeric ({@code ColumnType.N}) if {@code isBinaryVariable} or
 * {@code isDoubleFrequentVariable} accepts its sampled items; otherwise it is marked categorical
 * ({@code ColumnType.C}) and counted.
 *
 * @param distinctCountMap
 *            per-column statistics keyed by column number
 * @param cateOn
 *            whether the column type should be updated by the auto type checking
 * @param distinctOn
 *            whether the distinct count should be written into the column stats
 * @return the number of columns that were set to categorical type by this call
 */
private int setCategoricalColumnsAndDistinctAccount(Map<Integer, Data> distinctCountMap, boolean cateOn, boolean distinctOn) {
    int categoricalTotal = 0;
    for (ColumnConfig column : columnConfigList) {
        Data stats = distinctCountMap.get(column.getColumnNum());
        if (stats == null) {
            // no statistics collected for this column
            continue;
        }

        Long distinct = stats.distinctCount;
        // disable auto type threshold
        if (distinct == null) {
            continue;
        }

        if (cateOn) {
            String[] sampledItems = stats.items;
            if (isBinaryVariable(distinct, sampledItems)) {
                log.info("Column {} with index {} is set to numeric type because of 0-1 variable. Distinct count {}, items {}.", column.getColumnName(), column.getColumnNum(), distinct, Arrays.toString(sampledItems));
                column.setColumnType(ColumnType.N);
            } else if (isDoubleFrequentVariable(sampledItems)) {
                log.info("Column {} with index {} is set to numeric type because of all sampled items are double(including blank). Distinct count {}.", column.getColumnName(), column.getColumnNum(), distinct);
                column.setColumnType(ColumnType.N);
            } else {
                column.setColumnType(ColumnType.C);
                categoricalTotal++;
                log.info("Column {} with index {} is set to categorical type according to auto type checking: distinct count {}.", column.getColumnName(), column.getColumnNum(), distinct);
            }
        }

        if (distinctOn) {
            column.getColumnStats().setDistinctCount(distinct);
        }
    }
    return categoricalTotal;
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class InitModelProcessor method initColumnConfigList.
/**
 * Initializes {@code columnConfigList} from the header file or, when no header path is configured, from the first
 * line of the raw data set.
 *
 * <p>
 * Column naming:
 * <ul>
 * <li>header path configured: names come from the header file, whose length must match the first data line;</li>
 * <li>no header path, and the first data line contains the target column name: the first line is treated as the
 * header and its fields become the names;</li>
 * <li>otherwise: columns are named by their index ("0", "1", "2", ...).</li>
 * </ul>
 * After the list is built the column flags are updated through {@code ColumnConfigUpdater} and the presence of a
 * target column is verified.
 *
 * @return 0 when a target column is present in the created list, 1 when none of the columns is flagged as target
 * @throws IOException
 *             declared for the header / data reads done through {@code CommonUtils} (presumably raised when those
 *             files cannot be read)
 * @throws IllegalArgumentException
 *             if the header length and the data length are not consistent
 */
private int initColumnConfigList() throws IOException {
    String[] fields = null;
    boolean isSchemaProvided = true;
    if (StringUtils.isNotBlank(modelConfig.getHeaderPath())) {
        fields = CommonUtils.getHeaders(modelConfig.getHeaderPath(), modelConfig.getHeaderDelimiter(), modelConfig.getDataSet().getSource());
        String[] dataInFirstLine = CommonUtils.takeFirstLine(modelConfig.getDataSetRawPath(), modelConfig.getDataSetDelimiter(), modelConfig.getDataSet().getSource());
        if (fields.length != dataInFirstLine.length) {
            throw new IllegalArgumentException("Header length and data length are not consistent - head length " + fields.length + ", while data length " + dataInFirstLine.length + ", please check you header setting and data set setting.");
        }
    } else {
        // without a header file, fall back to the data set delimiter when no header delimiter is configured
        String firstLineDelimiter = StringUtils.isBlank(modelConfig.getHeaderDelimiter()) ? modelConfig.getDataSetDelimiter() : modelConfig.getHeaderDelimiter();
        fields = CommonUtils.takeFirstLine(modelConfig.getDataSetRawPath(), firstLineDelimiter, modelConfig.getDataSet().getSource());
        if (StringUtils.join(fields, "").contains(modelConfig.getTargetColumnName())) {
            // if first line contains target column name, we guess it is csv format and first line is header.
            isSchemaProvided = true;
            // first line of data meaning second line in data files excluding first header line
            String[] dataInFirstLine = CommonUtils.takeFirstTwoLines(modelConfig.getDataSetRawPath(), firstLineDelimiter, modelConfig.getDataSet().getSource())[1];
            if (dataInFirstLine != null && fields.length != dataInFirstLine.length) {
                throw new IllegalArgumentException("Header length and data length are not consistent, please check you header setting and data set setting.");
            }
            log.warn("No header path is provided, we will try to read first line and detect schema.");
            log.warn("Schema in ColumnConfig.json are named as first line of data set path.");
        } else {
            isSchemaProvided = false;
            log.warn("No header path is provided, we will try to read first line and detect schema.");
            log.warn("Schema in ColumnConfig.json are named as index 0, 1, 2, 3 ...");
            log.warn("Please make sure weight column and tag column are also taking index as name.");
        }
    }

    columnConfigList = new ArrayList<ColumnConfig>();
    for (int i = 0; i < fields.length; i++) {
        ColumnConfig config = new ColumnConfig();
        config.setColumnNum(i);
        fields[i] = CommonUtils.normColumnName(fields[i]);
        if (isSchemaProvided) {
            config.setColumnName(fields[i]);
        } else {
            config.setColumnName(i + "");
        }
        columnConfigList.add(config);
    }

    ColumnConfigUpdater.updateColumnConfigFlags(modelConfig, columnConfigList, ModelStep.INIT);

    boolean hasTarget = false;
    for (ColumnConfig config : columnConfigList) {
        if (config.isTarget()) {
            hasTarget = true;
            // one target column is enough, no need to scan the rest
            break;
        }
    }

    if (!hasTarget) {
        log.error("Target is not valid: " + modelConfig.getTargetColumnName());
        // point the user at the source the column names were actually read from
        if (StringUtils.isNotBlank(modelConfig.getHeaderPath())) {
            log.error("Please check your header file {} and your header delimiter {}", modelConfig.getHeaderPath(), modelConfig.getHeaderDelimiter());
        } else {
            log.error("Please check your first line of data set file {}", modelConfig.getDataSetRawPath());
        }
        return 1;
    }
    return 0;
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class InitModelProcessor method setCategoricalColumnsByCountInfo.
/**
 * Updates distinct counts and, when the auto type threshold of the data set is positive, decides for every column
 * not already categorical whether it is numeric or categorical based on the ratio of valid numeric values.
 *
 * <p>
 * The ratio is {@code validNumcount / (count - invalidCount)}; a column whose ratio is above
 * {@code autoTypeThreshold / 100} becomes {@code ColumnType.N}, any other becomes {@code ColumnType.C}.
 *
 * @param distinctCountMap
 *            per-column statistics keyed by column number
 * @param distinctOn
 *            whether the distinct count should be written into the column stats
 * @return the number of columns that were switched to categorical type by this call
 */
private int setCategoricalColumnsByCountInfo(Map<Integer, Data> distinctCountMap, boolean distinctOn) {
    int categoricalTotal = 0;
    for (ColumnConfig column : columnConfigList) {
        Data stats = distinctCountMap.get(column.getColumnNum());
        if (stats == null) {
            // no statistics collected for this column
            continue;
        }

        long distinct = stats.distinctCount;
        if (distinctOn) {
            column.getColumnStats().setDistinctCount(distinct);
        }

        // only update categorical feature when autoTypeThreshold > 0, by default it is 0
        if (!(modelConfig.getDataSet().getAutoTypeThreshold() > 0)) {
            continue;
        }

        long total = stats.count;
        long invalid = stats.invalidCount;
        long validNumeric = stats.validNumcount;
        double numericRatio = (double) validNumeric / (total - invalid);

        // if numerical, check and set it
        if (column.isCategorical()) {
            continue;
        }

        if (numericRatio > modelConfig.getDataSet().getAutoTypeThreshold() / 100d) {
            column.setColumnType(ColumnType.N);
            log.info("Column {} with index {} is set to numeric type because of enough double values.", column.getColumnName(), column.getColumnNum());
        } else {
            categoricalTotal++;
            column.setColumnType(ColumnType.C);
            log.info("Column {} with index {} is set to categorical type because of not enough double values.", column.getColumnName(), column.getColumnNum());
        }
    }
    return categoricalTotal;
}
use of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
the class ExportModelProcessor method run.
/*
 * (non-Javadoc)
 *
 * Exports the trained model or column information in the format selected by the type field (PMML when the type is
 * blank). Returns 0 on success, -1 for an unsupported type, 2 when the correlation file is missing, or whatever
 * exportVariableCorr() reports for the correlation export. clearUp(ModelStep.EXPORT) runs on every path that
 * completes without an exception.
 *
 * @see ml.shifu.shifu.core.processor.Processor#run()
 */
@Override
public int run() throws Exception {
    setUp(ModelStep.EXPORT);
    int status = 0;
    File pmmls = new File("pmmls");
    FileUtils.forceMkdir(pmmls);
    if (StringUtils.isBlank(type)) {
        type = PMML;
    }
    String modelsPath = pathFinder.getModelsPath(SourceType.LOCAL);
    if (type.equalsIgnoreCase(ONE_BAGGING_MODEL)) {
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm()) && !CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging model is only supported in NN/GBT/RF algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            if (models.size() < 1) {
                log.warn("No model is found in {}.", modelsPath);
            } else {
                log.info("Convert nn models into one binary bagging model.");
                Configuration conf = new Configuration();
                Path output = new Path(pathFinder.getBaggingModelPath(SourceType.LOCAL), "model.b" + modelConfig.getAlgorithm());
                if ("nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    BinaryNNSerializer.save(modelConfig, columnConfigList, models, FileSystem.getLocal(conf), output);
                } else if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    List<List<TreeNode>> baggingTrees = new ArrayList<List<TreeNode>>();
                    for (int i = 0; i < models.size(); i++) {
                        TreeModel tm = (TreeModel) models.get(i);
                        // TreeModel only has one TreeNode instance although it is list inside
                        baggingTrees.add(tm.getIndependentTreeModel().getTrees().get(0));
                    }
                    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
                    // numerical + categorical = # of all input
                    int inputCount = inputOutputIndex[0] + inputOutputIndex[1];
                    BinaryDTSerializer.save(modelConfig, columnConfigList, baggingTrees, modelConfig.getParams().get("Loss").toString(), inputCount, FileSystem.getLocal(conf), output);
                }
                log.info("Please find one unified bagging model in local {}.", output);
            }
        }
    } else if (type.equalsIgnoreCase(PMML)) {
        // typical pmml generation
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), false);
        for (int index = 0; index < models.size(); index++) {
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + Integer.toString(index) + ".pmml";
            log.info("\t Start to generate " + path);
            PMML pmml = translator.build(Arrays.asList(new BasicML[] { models.get(index) }));
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(ONE_BAGGING_PMML_MODEL)) {
        // one unified bagging pmml generation
        log.info("Convert models into one bagging pmml model {} format", type);
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging pmml model is only supported in NN algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), true);
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + ".pmml";
            log.info("\t Start to generate one unified model to: " + path);
            PMML pmml = translator.build(models);
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(COLUMN_STATS)) {
        saveColumnStatus();
    } else if (type.equalsIgnoreCase(WOE_MAPPING)) {
        List<ColumnConfig> exportCatColumns = new ArrayList<ColumnConfig>();
        List<String> catVariables = getRequestVars();
        // an empty request list means "export every column"
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (CollectionUtils.isEmpty(catVariables) || isRequestColumn(catVariables, columnConfig)) {
                exportCatColumns.add(columnConfig);
            }
        }
        if (CollectionUtils.isNotEmpty(exportCatColumns)) {
            List<String> woeMappings = new ArrayList<String>();
            for (ColumnConfig columnConfig : exportCatColumns) {
                String woeMapText = rebinAndExportWoeMapping(columnConfig);
                woeMappings.add(woeMapText);
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(woeMappings, ",\n"));
        }
    } else if (type.equalsIgnoreCase(WOE)) {
        List<String> woeInfos = new ArrayList<String>();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (columnConfig.getBinLength() > 1 && ((columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) || (columnConfig.isNumerical() && CollectionUtils.isNotEmpty(columnConfig.getBinBoundary()) && columnConfig.getBinBoundary().size() > 1))) {
                List<String> varWoeInfos = generateWoeInfos(columnConfig);
                if (CollectionUtils.isNotEmpty(varWoeInfos)) {
                    woeInfos.addAll(varWoeInfos);
                    // blank line separates the blocks of two variables
                    woeInfos.add("");
                }
            }
        }
        // write the file once after all columns are collected instead of rewriting it for every column
        FileUtils.writeLines(new File("varwoe_info.txt"), woeInfos);
    } else if (type.equalsIgnoreCase(CORRELATION)) {
        // export correlation into mapping list
        // record the status instead of returning early so that the shared clean-up below still runs
        if (!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
            log.warn("The correlation file doesn't exist. Please make sure you have ran `shifu stats -c`.");
            status = 2;
        } else {
            status = exportVariableCorr();
        }
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }
    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}
Aggregations