Search in sources:

Example 1 with TreeEnsemblePMMLTranslator

Use of ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator in the shifu project by ShifuML.

Class PMMLConstructorFactory, method produce.

/**
 * Builds the PMML translator that matches the model's configured training algorithm.
 *
 * <p>The algorithm name is matched case-insensitively against {@code ModelTrainConf.ALGORITHM}
 * constant names:
 * <ul>
 *   <li>NN: {@code NNPmmlModelCreator} + {@code NNSpecifCreator}</li>
 *   <li>LR: {@code RegressionPmmlModelCreator} + {@code RegressionSpecifCreator}</li>
 *   <li>GBT / RF: a {@link TreeEnsemblePMMLTranslator} built from a tree-ensemble creator, a
 *       data-dictionary creator and a tree-model mining-schema creator</li>
 * </ul>
 *
 * <p>For NN and LR, the returned {@link PMMLTranslator} also gets data-dictionary, mining-schema
 * and model-stats creators. It also gets a local-transformations creator chosen from the model's
 * normalization type:
 * <ul>
 *   <li>WOE / WEIGHT_WOE: WOE</li>
 *   <li>WOE_ZSCORE / WOE_ZSCALE: unweighted WOE-zscore</li>
 *   <li>WEIGHT_WOE_ZSCORE / WEIGHT_WOE_ZSCALE: weighted WOE-zscore</li>
 *   <li>ZSCALE_ONEHOT: zscore one-hot</li>
 *   <li>anything else: plain zscore</li>
 * </ul>
 *
 * @param modelConfig model configuration; its train algorithm and normalize type drive the choice
 * @param columnConfigList column configurations handed to every creator
 * @param isConcise forwarded to the NN/LR creators; not used on the GBT/RF path
 * @param isOutBaggingToOne forwarded to {@link PMMLTranslator}; not used on the GBT/RF path
 * @return the translator for the configured algorithm
 * @throws RuntimeException if the algorithm is not NN, LR, GBT or RF
 * @throws NullPointerException if an NN/LR model has no normalize type configured
 */
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
    // Read once and reuse for every algorithm comparison below.
    final String algorithm = modelConfig.getTrain().getAlgorithm();

    final AbstractPmmlElementCreator<Model> modelCreator;
    final AbstractSpecifCreator specifCreator;
    if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(algorithm)
            || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(algorithm)) {
        // Tree ensembles use a dedicated translator and none of the shared creators below.
        // NOTE(review): isConcise / isOutBaggingToOne are not forwarded here — confirm intended.
        TreeEnsemblePmmlCreator gbtmodelCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
        return new TreeEnsemblePMMLTranslator(gbtmodelCreator, dataDictionaryCreator, miningSchemaCreator);
    } else {
        throw new RuntimeException("Model not supported: " + algorithm);
    }

    AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);

    ModelNormalizeConf.NormType normType = modelConfig.getNormalizeType();
    if (normType == null) {
        // With consistent '==' comparisons a missing norm type would otherwise silently fall
        // through to the plain zscore default; fail fast with context instead.
        throw new NullPointerException("Normalize type is not configured for algorithm: " + algorithm);
    }

    final AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator;
    if (normType == ModelNormalizeConf.NormType.WOE || normType == ModelNormalizeConf.NormType.WEIGHT_WOE) {
        localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    } else if (normType == ModelNormalizeConf.NormType.WOE_ZSCORE || normType == ModelNormalizeConf.NormType.WOE_ZSCALE) {
        // Last argument selects the weighted variant: false here.
        localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
    } else if (normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCORE || normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCALE) {
        localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
    } else if (normType == ModelNormalizeConf.NormType.ZSCALE_ONEHOT) {
        localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    } else {
        localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    }

    return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
Also used: AbstractSpecifCreator (ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator), TreeEnsemblePMMLTranslator (ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator), ModelNormalizeConf (ml.shifu.shifu.container.obj.ModelNormalizeConf), DataDictionary (org.dmg.pmml.DataDictionary), PMMLTranslator (ml.shifu.shifu.core.pmml.PMMLTranslator), LocalTransformations (org.dmg.pmml.LocalTransformations), MiningSchema (org.dmg.pmml.MiningSchema), ModelStats (org.dmg.pmml.ModelStats), Model (org.dmg.pmml.Model)

Aggregations

ModelNormalizeConf (ml.shifu.shifu.container.obj.ModelNormalizeConf): 1, PMMLTranslator (ml.shifu.shifu.core.pmml.PMMLTranslator): 1, TreeEnsemblePMMLTranslator (ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator): 1, AbstractSpecifCreator (ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator): 1, DataDictionary (org.dmg.pmml.DataDictionary): 1, LocalTransformations (org.dmg.pmml.LocalTransformations): 1, MiningSchema (org.dmg.pmml.MiningSchema): 1, Model (org.dmg.pmml.Model): 1, ModelStats (org.dmg.pmml.ModelStats): 1