Use of ml.shifu.shifu.core.pmml.PMMLTranslator in the project shifu by ShifuML.
From the class PMMLConstructorFactory, method produce.
/**
 * Creates the PMML translator that matches the training algorithm configured in {@code modelConfig}.
 * <p>
 * GBT and RF yield a {@code TreeEnsemblePMMLTranslator}; the tree-ensemble creators are built without the
 * {@code isConcise} flag and {@code isOutBaggingToOne} is not used on that path. NN and LR yield a plain
 * {@code PMMLTranslator} whose local-transformation creator is picked from the configured normalization
 * type, with the z-score creator as the fallback for every type not listed explicitly.
 *
 * @param modelConfig       model configuration supplying the algorithm name and the normalization type
 * @param columnConfigList  column configurations handed to every element creator
 * @param isConcise         forwarded to the NN/LR element creators
 * @param isOutBaggingToOne forwarded to the {@code PMMLTranslator} constructor
 * @return the translator for the configured algorithm
 * @throws RuntimeException if the algorithm is none of NN, LR, GBT or RF (compared ignoring case)
 */
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
    final String algorithm = modelConfig.getTrain().getAlgorithm();

    // Tree ensembles use their own translator and creators, so handle them first and leave.
    final boolean isTreeEnsemble = ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(algorithm)
            || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(algorithm);
    if (isTreeEnsemble) {
        TreeEnsemblePmmlCreator ensembleCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<DataDictionary> treeDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<MiningSchema> treeSchemaCreator = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
        return new TreeEnsemblePMMLTranslator(ensembleCreator, treeDictionaryCreator, treeSchemaCreator);
    }

    final AbstractPmmlElementCreator<Model> modelCreator;
    final AbstractSpecifCreator specifCreator;
    if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
    } else {
        throw new RuntimeException("Model not supported: " + modelConfig.getTrain().getAlgorithm());
    }

    final AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
    final AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
    final AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);

    // Pick the local transformation that mirrors the normalization applied at training time.
    final AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator;
    final ModelNormalizeConf.NormType normType = modelConfig.getNormalizeType();
    switch (normType) {
        case WOE:
        case WEIGHT_WOE:
            localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
            break;
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
            break;
        case ZSCALE_ONEHOT:
            localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
        default:
            localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
    }

    return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
Use of ml.shifu.shifu.core.pmml.PMMLTranslator in the project shifu by ShifuML.
From the class ExportModelProcessor, method run.
/**
 * Runs the export step: writes the model, or statistics derived from the column configuration, in the
 * format named by {@code type}.
 * <p>
 * A blank {@code type} falls back to {@code PMML}. The local {@code pmmls} directory is created up front
 * regardless of the requested format. Supported formats and their outputs:
 * <ul>
 * <li>{@code ONE_BAGGING_MODEL} - one binary bagging model (NN, or tree models as decided by
 * {@code CommonUtils.isTreeModel}); other algorithms only log a warning.</li>
 * <li>{@code PMML} - one {@code pmmls/<modelSetName><index>.pmml} file per loaded model.</li>
 * <li>{@code ONE_BAGGING_PMML_MODEL} - a single {@code pmmls/<modelSetName>.pmml} file, NN only.</li>
 * <li>{@code COLUMN_STATS} - delegated to {@code saveColumnStatus()}.</li>
 * <li>{@code WOE_MAPPING} - {@code woemapping.txt}, written only when at least one column is selected.</li>
 * <li>{@code WOE} - {@code varwoe_info.txt}.</li>
 * <li>{@code CORRELATION} - delegated to {@code exportVariableCorr()}.</li>
 * </ul>
 *
 * @return {@code 0} on the normal path (including cases that only log a warning), {@code -1} for an
 *         unsupported {@code type}, {@code 2} when the correlation CSV is missing, or the result of
 *         {@code exportVariableCorr()} for a correlation export
 * @throws Exception propagated from set-up, model loading, PMML generation or file output
 * @see ml.shifu.shifu.core.processor.Processor#run()
 */
@Override
public int run() throws Exception {
    setUp(ModelStep.EXPORT);
    int status = 0;
    File pmmls = new File("pmmls");
    FileUtils.forceMkdir(pmmls);
    if (StringUtils.isBlank(type)) {
        type = PMML;
    }
    String modelsPath = pathFinder.getModelsPath(SourceType.LOCAL);
    if (type.equalsIgnoreCase(ONE_BAGGING_MODEL)) {
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm()) && !CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging model is only supported in NN/GBT/RF algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            if (models.size() < 1) {
                log.warn("No model is found in {}.", modelsPath);
            } else {
                log.info("Convert nn models into one binary bagging model.");
                Configuration conf = new Configuration();
                Path output = new Path(pathFinder.getBaggingModelPath(SourceType.LOCAL), "model.b" + modelConfig.getAlgorithm());
                if ("nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    BinaryNNSerializer.save(modelConfig, columnConfigList, models, FileSystem.getLocal(conf), output);
                } else if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    List<List<TreeNode>> baggingTrees = new ArrayList<List<TreeNode>>();
                    for (int i = 0; i < models.size(); i++) {
                        TreeModel tm = (TreeModel) models.get(i);
                        // TreeModel only has one TreeNode instance although it is list inside
                        baggingTrees.add(tm.getIndependentTreeModel().getTrees().get(0));
                    }
                    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
                    // numerical + categorical = # of all input
                    int inputCount = inputOutputIndex[0] + inputOutputIndex[1];
                    BinaryDTSerializer.save(modelConfig, columnConfigList, baggingTrees, modelConfig.getParams().get("Loss").toString(), inputCount, FileSystem.getLocal(conf), output);
                }
                log.info("Please find one unified bagging model in local {}.", output);
            }
        }
    } else if (type.equalsIgnoreCase(PMML)) {
        // typical pmml generation: one PMML file per trained model
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), false);
        for (int index = 0; index < models.size(); index++) {
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + index + ".pmml";
            log.info("\t Start to generate " + path);
            PMML pmml = translator.build(Arrays.asList(models.get(index)));
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(ONE_BAGGING_PMML_MODEL)) {
        // one unified bagging pmml generation
        log.info("Convert models into one bagging pmml model {} format", type);
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging pmml model is only supported in NN algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), true);
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + ".pmml";
            log.info("\t Start to generate one unified model to: " + path);
            PMML pmml = translator.build(models);
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(COLUMN_STATS)) {
        saveColumnStatus();
    } else if (type.equalsIgnoreCase(WOE_MAPPING)) {
        // An empty request list selects every column; otherwise only the requested ones are exported.
        List<ColumnConfig> exportCatColumns = new ArrayList<ColumnConfig>();
        List<String> catVariables = getRequestVars();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (CollectionUtils.isEmpty(catVariables) || isRequestColumn(catVariables, columnConfig)) {
                exportCatColumns.add(columnConfig);
            }
        }
        if (CollectionUtils.isNotEmpty(exportCatColumns)) {
            List<String> woeMappings = new ArrayList<String>();
            for (ColumnConfig columnConfig : exportCatColumns) {
                woeMappings.add(rebinAndExportWoeMapping(columnConfig));
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(woeMappings, ",\n"));
        }
    } else if (type.equalsIgnoreCase(WOE)) {
        List<String> woeInfos = new ArrayList<String>();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            // Only columns with more than one bin and non-trivial bin definitions carry WOE information.
            if (columnConfig.getBinLength() > 1 && ((columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) || (columnConfig.isNumerical() && CollectionUtils.isNotEmpty(columnConfig.getBinBoundary()) && columnConfig.getBinBoundary().size() > 1))) {
                List<String> varWoeInfos = generateWoeInfos(columnConfig);
                if (CollectionUtils.isNotEmpty(varWoeInfos)) {
                    woeInfos.addAll(varWoeInfos);
                    // blank line separates the blocks of consecutive variables
                    woeInfos.add("");
                }
            }
        }
        // Write once after all columns are collected; writing inside the loop rewrote the whole file
        // for every column and produced no file at all when the column list was empty.
        FileUtils.writeLines(new File("varwoe_info.txt"), woeInfos);
    } else if (type.equalsIgnoreCase(CORRELATION)) {
        // export correlation into mapping list
        // NOTE(review): both returns below bypass clearUp(ModelStep.EXPORT) and the final "Done." log;
        // kept as-is because clearUp's effect is not visible here - confirm whether that is intended.
        if (!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
            log.warn("The correlation file doesn't exist. Please make sure you have ran `shifu stats -c`.");
            return 2;
        }
        return exportVariableCorr();
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }
    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}
Aggregations