use of org.dmg.pmml.ModelStats in project shifu by ShifuML.
the class ModelStatsCreator method build.
@Override
public ModelStats build(BasicML basicML) {
ModelStats modelStats = new ModelStats();
if (basicML instanceof BasicFloatNetwork) {
BasicFloatNetwork bfn = (BasicFloatNetwork) basicML;
Set<Integer> featureSet = bfn.getFeatureSet();
for (ColumnConfig columnConfig : columnConfigList) {
if (columnConfig.isFinalSelect() && (CollectionUtils.isEmpty(featureSet) || featureSet.contains(columnConfig.getColumnNum()))) {
UnivariateStats univariateStats = new UnivariateStats();
// here, no need to consider if column is in segment expansion
// as we need to address new stats variable
// set simple column name in PMML
univariateStats.setField(FieldName.create(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
if (columnConfig.isCategorical()) {
DiscrStats discrStats = new DiscrStats();
Array countArray = createCountArray(columnConfig);
discrStats.addArrays(countArray);
if (!isConcise) {
List<Extension> extensions = createExtensions(columnConfig);
discrStats.addExtensions(extensions.toArray(new Extension[extensions.size()]));
}
univariateStats.setDiscrStats(discrStats);
} else {
// numerical column
univariateStats.setNumericInfo(createNumericInfo(columnConfig));
if (!isConcise) {
univariateStats.setContStats(createConStats(columnConfig));
}
}
modelStats.addUnivariateStats(univariateStats);
}
}
} else {
for (ColumnConfig columnConfig : columnConfigList) {
if (columnConfig.isFinalSelect()) {
UnivariateStats univariateStats = new UnivariateStats();
// here, no need to consider if column is in segment expansion as we need to address new stats
// variable
univariateStats.setField(FieldName.create(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
if (columnConfig.isCategorical()) {
DiscrStats discrStats = new DiscrStats();
Array countArray = createCountArray(columnConfig);
discrStats.addArrays(countArray);
if (!isConcise) {
List<Extension> extensions = createExtensions(columnConfig);
discrStats.addExtensions(extensions.toArray(new Extension[extensions.size()]));
}
univariateStats.setDiscrStats(discrStats);
} else {
// numerical column
univariateStats.setNumericInfo(createNumericInfo(columnConfig));
if (!isConcise) {
univariateStats.setContStats(createConStats(columnConfig));
}
}
modelStats.addUnivariateStats(univariateStats);
}
}
}
return modelStats;
}
use of org.dmg.pmml.ModelStats in project shifu by ShifuML.
the class PMMLConstructorFactory method produce.
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
AbstractPmmlElementCreator<Model> modelCreator = null;
AbstractSpecifCreator specifCreator = null;
if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm())) {
modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
} else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm())) {
modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
} else if (ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm()) || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm())) {
TreeEnsemblePmmlCreator gbtmodelCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList);
AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
return new TreeEnsemblePMMLTranslator(gbtmodelCreator, dataDictionaryCreator, miningSchemaCreator);
} else {
throw new RuntimeException("Model not supported: " + modelConfig.getTrain().getAlgorithm());
}
AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);
AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator = null;
ModelNormalizeConf.NormType normType = modelConfig.getNormalizeType();
if (normType.equals(ModelNormalizeConf.NormType.WOE) || normType.equals(ModelNormalizeConf.NormType.WEIGHT_WOE)) {
localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
} else if (normType == ModelNormalizeConf.NormType.WOE_ZSCORE || normType == ModelNormalizeConf.NormType.WOE_ZSCALE) {
localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
} else if (normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCORE || normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCALE) {
localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
} else if (normType == ModelNormalizeConf.NormType.ZSCALE_ONEHOT) {
localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
} else {
localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
}
return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
Aggregations