use of ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator in project shifu by ShifuML.
the class PMMLConstructorFactory method produce.
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
AbstractPmmlElementCreator<Model> modelCreator = null;
AbstractSpecifCreator specifCreator = null;
if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm())) {
modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
} else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm())) {
modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
} else if (ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm()) || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm())) {
TreeEnsemblePmmlCreator gbtmodelCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList);
AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
return new TreeEnsemblePMMLTranslator(gbtmodelCreator, dataDictionaryCreator, miningSchemaCreator);
} else {
throw new RuntimeException("Model not supported: " + modelConfig.getTrain().getAlgorithm());
}
AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);
AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator = null;
ModelNormalizeConf.NormType normType = modelConfig.getNormalizeType();
if (normType.equals(ModelNormalizeConf.NormType.WOE) || normType.equals(ModelNormalizeConf.NormType.WEIGHT_WOE)) {
localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
} else if (normType == ModelNormalizeConf.NormType.WOE_ZSCORE || normType == ModelNormalizeConf.NormType.WOE_ZSCALE) {
localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
} else if (normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCORE || normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCALE) {
localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
} else if (normType == ModelNormalizeConf.NormType.ZSCALE_ONEHOT) {
localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
} else {
localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
}
return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
Aggregations