Search in sources :

Example 1 with PMMLTranslator

use of ml.shifu.shifu.core.pmml.PMMLTranslator in project shifu by ShifuML.

the class PMMLConstructorFactory method produce.

/**
 * Creates the PMML translator that matches the training algorithm configured in
 * {@code modelConfig}.
 *
 * <p>GBT and RF yield a {@link TreeEnsemblePMMLTranslator}; NN and LR yield a
 * {@link PMMLTranslator} whose local-transformation creator is picked from the
 * configured normalization type (plain z-score is the fallback).
 *
 * @param modelConfig       model configuration supplying the algorithm name and normalization type
 * @param columnConfigList  column configurations handed to every element creator
 * @param isConcise         forwarded to the NN/LR element creators
 * @param isOutBaggingToOne forwarded to the {@link PMMLTranslator} constructor
 * @return the translator for the configured algorithm
 * @throws RuntimeException if the algorithm is none of NN, LR, GBT or RF
 */
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
    final String algorithm = modelConfig.getTrain().getAlgorithm();

    // Tree ensembles use a dedicated translator and need neither stats nor local transformations.
    final boolean treeEnsemble = ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(algorithm)
            || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(algorithm);
    if (treeEnsemble) {
        TreeEnsemblePmmlCreator ensembleCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<DataDictionary> treeDictionary = new DataDictionaryCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<MiningSchema> treeSchema = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
        return new TreeEnsemblePMMLTranslator(ensembleCreator, treeDictionary, treeSchema);
    }

    final AbstractPmmlElementCreator<Model> modelCreator;
    final AbstractSpecifCreator specifCreator;
    if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
    } else {
        throw new RuntimeException("Model not supported: " + algorithm);
    }

    AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);

    final ModelNormalizeConf.NormType normType = modelConfig.getNormalizeType();
    // The weighted and unweighted WOE z-score variants share one creator; only the last flag differs.
    final boolean woeZ = normType == ModelNormalizeConf.NormType.WOE_ZSCORE
            || normType == ModelNormalizeConf.NormType.WOE_ZSCALE;
    final boolean weightedWoeZ = normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCORE
            || normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCALE;

    final AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator;
    if (normType.equals(ModelNormalizeConf.NormType.WOE) || normType.equals(ModelNormalizeConf.NormType.WEIGHT_WOE)) {
        localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    } else if (woeZ || weightedWoeZ) {
        localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, weightedWoeZ);
    } else if (normType == ModelNormalizeConf.NormType.ZSCALE_ONEHOT) {
        localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    } else {
        localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    }

    return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
Also used : AbstractSpecifCreator(ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator) TreeEnsemblePMMLTranslator(ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator) ModelNormalizeConf(ml.shifu.shifu.container.obj.ModelNormalizeConf) DataDictionary(org.dmg.pmml.DataDictionary) PMMLTranslator(ml.shifu.shifu.core.pmml.PMMLTranslator) LocalTransformations(org.dmg.pmml.LocalTransformations) MiningSchema(org.dmg.pmml.MiningSchema) ModelStats(org.dmg.pmml.ModelStats) Model(org.dmg.pmml.Model)

Example 2 with PMMLTranslator

use of ml.shifu.shifu.core.pmml.PMMLTranslator in project shifu by ShifuML.

the class ExportModelProcessor method run.

/*
 * (non-Javadoc)
 *
 * Exports artifacts of the trained model set in the format selected by the
 * `type` field (blank defaults to PMML): one unified binary bagging model,
 * one PMML file per model, one unified bagging PMML (NN only), column stats,
 * WOE mappings, per-variable WOE info, or variable correlations.
 *
 * Returns 0 by default, -1 for an unsupported type, 2 when the correlation
 * CSV is missing, and otherwise whatever exportVariableCorr() returns for
 * the correlation export. clearUp(ModelStep.EXPORT) is reached on every
 * non-exceptional path, including the correlation branch.
 *
 * @see ml.shifu.shifu.core.processor.Processor#run()
 */
@Override
public int run() throws Exception {
    setUp(ModelStep.EXPORT);
    int status = 0;
    File pmmls = new File("pmmls");
    FileUtils.forceMkdir(pmmls);
    if (StringUtils.isBlank(type)) {
        type = PMML;
    }
    String modelsPath = pathFinder.getModelsPath(SourceType.LOCAL);
    if (type.equalsIgnoreCase(ONE_BAGGING_MODEL)) {
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm()) && !CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging model is only supported in NN/GBT/RF algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            if (models.isEmpty()) {
                log.warn("No model is found in {}.", modelsPath);
            } else {
                log.info("Convert nn models into one binary bagging model.");
                Configuration conf = new Configuration();
                Path output = new Path(pathFinder.getBaggingModelPath(SourceType.LOCAL), "model.b" + modelConfig.getAlgorithm());
                if ("nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    BinaryNNSerializer.save(modelConfig, columnConfigList, models, FileSystem.getLocal(conf), output);
                } else if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    List<List<TreeNode>> baggingTrees = new ArrayList<List<TreeNode>>();
                    for (BasicML model : models) {
                        TreeModel tm = (TreeModel) model;
                        // TreeModel only has one TreeNode instance although it is list inside
                        baggingTrees.add(tm.getIndependentTreeModel().getTrees().get(0));
                    }
                    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
                    // numerical + categorical = # of all input
                    int inputCount = inputOutputIndex[0] + inputOutputIndex[1];
                    BinaryDTSerializer.save(modelConfig, columnConfigList, baggingTrees, modelConfig.getParams().get("Loss").toString(), inputCount, FileSystem.getLocal(conf), output);
                }
                log.info("Please find one unified bagging model in local {}.", output);
            }
        }
    } else if (type.equalsIgnoreCase(PMML)) {
        // typical pmml generation: one PMML file per trained model
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), false);
        for (int index = 0; index < models.size(); index++) {
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + index + ".pmml";
            log.info("\t Start to generate " + path);
            PMML pmml = translator.build(Arrays.asList(models.get(index)));
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(ONE_BAGGING_PMML_MODEL)) {
        // one unified bagging pmml generation
        log.info("Convert models into one bagging pmml model {} format", type);
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging pmml model is only supported in NN algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), true);
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + ".pmml";
            log.info("\t Start to generate one unified model to: " + path);
            PMML pmml = translator.build(models);
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(COLUMN_STATS)) {
        saveColumnStatus();
    } else if (type.equalsIgnoreCase(WOE_MAPPING)) {
        // An empty request list selects every column.
        List<ColumnConfig> exportCatColumns = new ArrayList<ColumnConfig>();
        List<String> catVariables = getRequestVars();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (CollectionUtils.isEmpty(catVariables) || isRequestColumn(catVariables, columnConfig)) {
                exportCatColumns.add(columnConfig);
            }
        }
        if (CollectionUtils.isNotEmpty(exportCatColumns)) {
            List<String> woeMappings = new ArrayList<String>();
            for (ColumnConfig columnConfig : exportCatColumns) {
                String woeMapText = rebinAndExportWoeMapping(columnConfig);
                woeMappings.add(woeMapText);
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(woeMappings, ",\n"));
        }
    } else if (type.equalsIgnoreCase(WOE)) {
        List<String> woeInfos = new ArrayList<String>();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (columnConfig.getBinLength() > 1 && ((columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) || (columnConfig.isNumerical() && CollectionUtils.isNotEmpty(columnConfig.getBinBoundary()) && columnConfig.getBinBoundary().size() > 1))) {
                List<String> varWoeInfos = generateWoeInfos(columnConfig);
                if (CollectionUtils.isNotEmpty(varWoeInfos)) {
                    woeInfos.addAll(varWoeInfos);
                    // blank line separates the info blocks of consecutive variables
                    woeInfos.add("");
                }
            }
        }
        // Write once after all columns are collected. The write used to sit inside the loop,
        // rewriting the whole file for every column and producing no file for an empty column list.
        FileUtils.writeLines(new File("varwoe_info.txt"), woeInfos);
    } else if (type.equalsIgnoreCase(CORRELATION)) {
        // export correlation into mapping list
        // Record the outcome in status instead of returning early, so that clearUp() below
        // still runs, as it does for every other export type.
        if (!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
            log.warn("The correlation file doesn't exist. Please make sure you have ran `shifu stats -c`.");
            status = 2;
        } else {
            status = exportVariableCorr();
        }
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }
    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}
Also used : Path(org.apache.hadoop.fs.Path) Configuration(org.apache.hadoop.conf.Configuration) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BasicML(org.encog.ml.BasicML) PMMLTranslator(ml.shifu.shifu.core.pmml.PMMLTranslator) TreeModel(ml.shifu.shifu.core.TreeModel) TreeNode(ml.shifu.shifu.core.dtrain.dt.TreeNode) PMML(org.dmg.pmml.PMML) File(java.io.File)

Aggregations

PMMLTranslator (ml.shifu.shifu.core.pmml.PMMLTranslator): 2, File (java.io.File): 1, ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1, ModelNormalizeConf (ml.shifu.shifu.container.obj.ModelNormalizeConf): 1, TreeModel (ml.shifu.shifu.core.TreeModel): 1, TreeNode (ml.shifu.shifu.core.dtrain.dt.TreeNode): 1, TreeEnsemblePMMLTranslator (ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator): 1, AbstractSpecifCreator (ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator): 1, Configuration (org.apache.hadoop.conf.Configuration): 1, Path (org.apache.hadoop.fs.Path): 1, DataDictionary (org.dmg.pmml.DataDictionary): 1, LocalTransformations (org.dmg.pmml.LocalTransformations): 1, MiningSchema (org.dmg.pmml.MiningSchema): 1, Model (org.dmg.pmml.Model): 1, ModelStats (org.dmg.pmml.ModelStats): 1, PMML (org.dmg.pmml.PMML): 1, BasicML (org.encog.ml.BasicML): 1