Search in sources :

Example 1 with ModelStats

use of org.dmg.pmml.ModelStats in project shifu by ShifuML.

The build method of the class ModelStatsCreator.

/**
 * Builds the PMML {@link ModelStats} element from the final-selected columns in
 * {@code columnConfigList}, adding one {@link UnivariateStats} entry per included column.
 *
 * <p>When {@code basicML} is a {@link BasicFloatNetwork} whose feature set is non-empty, only
 * final-selected columns whose column number is contained in that feature set are included.
 * For any other model type, or when the feature set is empty, every final-selected column is
 * included.
 *
 * @param basicML the trained model; only inspected to obtain an optional feature subset
 * @return a new {@link ModelStats} holding the statistics of the included columns
 */
@Override
public ModelStats build(BasicML basicML) {
    // Decide once whether the feature set restricts the columns; the answer does not change
    // from column to column.
    Set<Integer> featureSet = null;
    boolean restrictToFeatureSet = false;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
        restrictToFeatureSet = !CollectionUtils.isEmpty(featureSet);
    }

    ModelStats modelStats = new ModelStats();
    for (ColumnConfig columnConfig : columnConfigList) {
        if (!columnConfig.isFinalSelect()) {
            continue;
        }
        if (restrictToFeatureSet && !featureSet.contains(columnConfig.getColumnNum())) {
            continue;
        }
        modelStats.addUnivariateStats(buildUnivariateStats(columnConfig));
    }
    return modelStats;
}

/**
 * Creates the {@link UnivariateStats} element for a single column.
 *
 * <p>Categorical columns get a {@link DiscrStats} with the count array (plus extensions unless
 * {@code isConcise} is set); all other columns get numeric info (plus continuous stats unless
 * {@code isConcise} is set).
 *
 * @param columnConfig the column to describe
 * @return the populated statistics element for that column
 */
private UnivariateStats buildUnivariateStats(ColumnConfig columnConfig) {
    UnivariateStats univariateStats = new UnivariateStats();
    // here, no need to consider if column is in segment expansion
    // as we need to address new stats variable
    // set simple column name in PMML
    univariateStats.setField(FieldName.create(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
    if (columnConfig.isCategorical()) {
        DiscrStats discrStats = new DiscrStats();
        Array countArray = createCountArray(columnConfig);
        discrStats.addArrays(countArray);
        if (!isConcise) {
            List<Extension> extensions = createExtensions(columnConfig);
            discrStats.addExtensions(extensions.toArray(new Extension[extensions.size()]));
        }
        univariateStats.setDiscrStats(discrStats);
    } else {
        // numerical column
        univariateStats.setNumericInfo(createNumericInfo(columnConfig));
        if (!isConcise) {
            univariateStats.setContStats(createConStats(columnConfig));
        }
    }
    return univariateStats;
}
Also used : Array(org.dmg.pmml.Array) Extension(org.dmg.pmml.Extension) DiscrStats(org.dmg.pmml.DiscrStats) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) UnivariateStats(org.dmg.pmml.UnivariateStats) ModelStats(org.dmg.pmml.ModelStats) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Example 2 with ModelStats

use of org.dmg.pmml.ModelStats in project shifu by ShifuML.

The produce method of the class PMMLConstructorFactory.

/**
 * Creates the PMML translator matching the training algorithm configured in {@code modelConfig}.
 *
 * <p>GBT and RF yield a {@link TreeEnsemblePMMLTranslator}. NN and LR yield a
 * {@link PMMLTranslator} whose local-transformation creator is chosen from the configured
 * normalization type. Algorithm names are compared case-insensitively.
 *
 * @param modelConfig the model configuration supplying the algorithm and normalization type
 * @param columnConfigList the column configurations handed to every element creator
 * @param isConcise forwarded to the element creators of the NN/LR translator
 * @param isOutBaggingToOne forwarded to the {@link PMMLTranslator} constructor
 * @return the translator for the configured algorithm
 * @throws RuntimeException if the configured algorithm is none of NN, LR, GBT or RF
 */
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
    final String algorithm = modelConfig.getTrain().getAlgorithm();

    // Tree ensembles use their own translator and need none of the creators built below.
    final boolean isTreeEnsemble = ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(algorithm)
            || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(algorithm);
    if (isTreeEnsemble) {
        TreeEnsemblePmmlCreator treeModelCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<DataDictionary> treeDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<MiningSchema> treeSchemaCreator = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
        return new TreeEnsemblePMMLTranslator(treeModelCreator, treeDictionaryCreator, treeSchemaCreator);
    }

    final AbstractPmmlElementCreator<Model> modelCreator;
    final AbstractSpecifCreator specifCreator;
    if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
    } else {
        throw new RuntimeException("Model not supported: " + algorithm);
    }

    AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);

    // Pick the local transformation from the normalization type; any type not listed
    // explicitly falls back to the plain z-score transformation.
    final AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator;
    switch (modelConfig.getNormalizeType()) {
        case WOE:
        case WEIGHT_WOE:
            localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
            break;
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
            break;
        case ZSCALE_ONEHOT:
            localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
        default:
            localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
    }

    return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
Also used : AbstractSpecifCreator(ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator) TreeEnsemblePMMLTranslator(ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator) ModelNormalizeConf(ml.shifu.shifu.container.obj.ModelNormalizeConf) DataDictionary(org.dmg.pmml.DataDictionary) PMMLTranslator(ml.shifu.shifu.core.pmml.PMMLTranslator) TreeEnsemblePMMLTranslator(ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator) LocalTransformations(org.dmg.pmml.LocalTransformations) MiningSchema(org.dmg.pmml.MiningSchema) ModelStats(org.dmg.pmml.ModelStats) Model(org.dmg.pmml.Model)

Aggregations

ModelStats (org.dmg.pmml.ModelStats)2 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)1 ModelNormalizeConf (ml.shifu.shifu.container.obj.ModelNormalizeConf)1 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)1 PMMLTranslator (ml.shifu.shifu.core.pmml.PMMLTranslator)1 TreeEnsemblePMMLTranslator (ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator)1 AbstractSpecifCreator (ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator)1 Array (org.dmg.pmml.Array)1 DataDictionary (org.dmg.pmml.DataDictionary)1 DiscrStats (org.dmg.pmml.DiscrStats)1 Extension (org.dmg.pmml.Extension)1 LocalTransformations (org.dmg.pmml.LocalTransformations)1 MiningSchema (org.dmg.pmml.MiningSchema)1 Model (org.dmg.pmml.Model)1 UnivariateStats (org.dmg.pmml.UnivariateStats)1