Use of `org.dmg.pmml.LocalTransformations` in project shifu by ShifuML.
Class `PMMLConstructorFactory`, method `produce`:
/**
 * Builds the PMML translator that matches the training algorithm configured in {@code modelConfig}.
 *
 * <p>GBT and RF models get a {@link TreeEnsemblePMMLTranslator} with tree-specific creators. NN and LR
 * models get a {@link PMMLTranslator} whose local-transformation creator is selected from the
 * configured normalization type (WOE variants, WOE + z-score variants, z-scale one-hot, or plain
 * z-score as the fallback).
 *
 * @param modelConfig       model configuration; its train algorithm and normalize type select the creators
 * @param columnConfigList  column configurations handed to every element creator
 * @param isConcise         forwarded to the NN/LR element creators (not used for tree ensembles)
 * @param isOutBaggingToOne forwarded to the NN/LR {@link PMMLTranslator} (not used for tree ensembles)
 * @return the translator for the configured algorithm
 * @throws IllegalArgumentException if the algorithm is not NN, LR, GBT or RF, or if an NN/LR model
 *                                  has no normalize type configured
 */
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
    // Read the algorithm once instead of re-walking the config in every branch.
    String algorithm = modelConfig.getTrain().getAlgorithm();

    AbstractPmmlElementCreator<Model> modelCreator;
    AbstractSpecifCreator specifCreator;
    if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(algorithm)
            || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(algorithm)) {
        // Tree ensembles use their own translator and mining schema; isConcise/isOutBaggingToOne do not apply.
        TreeEnsemblePmmlCreator gbtmodelCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
        return new TreeEnsemblePMMLTranslator(gbtmodelCreator, dataDictionaryCreator, miningSchemaCreator);
    } else {
        throw new IllegalArgumentException("Model not supported: " + algorithm);
    }

    AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);

    ModelNormalizeConf.NormType normType = modelConfig.getNormalizeType();
    if (normType == null) {
        // Fail with a clear message rather than an NPE deep in the comparisons below.
        throw new IllegalArgumentException("Normalize type is not set in model config");
    }

    // Enum constants are compared with == consistently (null has been rejected above).
    AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator;
    if (normType == ModelNormalizeConf.NormType.WOE || normType == ModelNormalizeConf.NormType.WEIGHT_WOE) {
        localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    } else if (normType == ModelNormalizeConf.NormType.WOE_ZSCORE || normType == ModelNormalizeConf.NormType.WOE_ZSCALE) {
        localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
    } else if (normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCORE || normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCALE) {
        localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
    } else if (normType == ModelNormalizeConf.NormType.ZSCALE_ONEHOT) {
        localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    } else {
        localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    }
    return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
Use of `org.dmg.pmml.LocalTransformations` in project shifu by ShifuML.
Class `NeuralNetworkModelIntegrator`, method `getLocalTranformations`:
/**
 * Builds a new LocalTransformations containing the model's existing derived fields followed by a
 * constant bias field named {@code PluginConstants.biasValue} with value {@code PluginConstants.bias}.
 *
 * <p>The model's own LocalTransformations is left untouched: its derived fields are copied before the
 * bias field is appended.
 *
 * @param model neural network whose derived fields are copied; it may have no LocalTransformations
 * @return a fresh LocalTransformations with the bias field as its last derived field
 */
private LocalTransformations getLocalTranformations(NeuralNetwork model) {
    // NOTE(review): an earlier comment here read "delete target", but no field is removed —
    // confirm whether target removal is still required.
    LocalTransformations source = model.getLocalTransformations();

    // Copy instead of appending to the live list: adding to getDerivedFields() directly would also
    // inject the bias field into the input model (once more on every call).
    List<DerivedField> derivedFields = new java.util.ArrayList<DerivedField>();
    if (source != null) {
        derivedFields.addAll(source.getDerivedFields());
    }

    // Constant bias field.
    DerivedField biasField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).setName(new FieldName(PluginConstants.biasValue));
    biasField.setExpression(new Constant(String.valueOf(PluginConstants.bias)));
    derivedFields.add(biasField);

    return new LocalTransformations().addDerivedFields(derivedFields.toArray(new DerivedField[derivedFields.size()]));
}
Use of `org.dmg.pmml.LocalTransformations` in project drools by kiegroup.
Class `PMMLModelTestUtils`, method `getRandomLocalTransformations`:
/**
 * Creates a LocalTransformations holding three derived fields. Each field is obtained from
 * {@code getDerivedField} with the names "DerivedField-0", "DerivedField-1" and "DerivedField-2",
 * in that order.
 *
 * @return a new LocalTransformations with three generated derived fields
 */
public static LocalTransformations getRandomLocalTransformations() {
    LocalTransformations localTransformations = new LocalTransformations();
    for (int index = 0; index < 3; index++) {
        String fieldName = "DerivedField-" + index;
        localTransformations.addDerivedFields(getDerivedField(fieldName));
    }
    return localTransformations;
}
Use of `org.dmg.pmml.LocalTransformations` in project drools by kiegroup.
Class `ModelUtils`, method `getFieldsFromDataDictionaryTransformationDictionaryAndModel`:
/**
 * Collects the fields visible to {@code model}, in this order:
 * <ol>
 *   <li>the fields returned by {@code getFieldsFromDataDictionaryAndTransformationDictionary},</li>
 *   <li>the model's local derived fields, if any,</li>
 *   <li>the model's output fields, if any.</li>
 * </ol>
 *
 * @param dataDictionary           data dictionary forwarded to the dictionary helper
 * @param transformationDictionary transformation dictionary forwarded to the dictionary helper
 * @param model                    model whose LocalTransformations and Output are appended
 * @return the list produced by the dictionary helper, extended with the model's fields
 */
public static List<Field<?>> getFieldsFromDataDictionaryTransformationDictionaryAndModel(final DataDictionary dataDictionary, final TransformationDictionary transformationDictionary, final Model model) {
    final List<Field<?>> toReturn = getFieldsFromDataDictionaryAndTransformationDictionary(dataDictionary, transformationDictionary);
    final LocalTransformations localTransformations = model.getLocalTransformations();
    if (localTransformations != null && localTransformations.hasDerivedFields()) {
        // DerivedField is itself a Field subtype, so addAll is type-safe without a raw Field cast.
        toReturn.addAll(localTransformations.getDerivedFields());
    }
    final Output output = model.getOutput();
    if (output != null && output.hasOutputFields()) {
        toReturn.addAll(output.getOutputFields());
    }
    return toReturn;
}
Use of `org.dmg.pmml.LocalTransformations` in project drools by kiegroup.
Class `KiePMMLSegmentFactory`, method `getFieldsFromModel`:
/**
 * Collects the fields declared by {@code model} itself: its local derived fields (if any) followed
 * by its output fields (if any).
 *
 * @param model model whose LocalTransformations and Output are read
 * @return a new mutable list; empty when the model declares neither derived nor output fields
 */
static List<Field<?>> getFieldsFromModel(final Model model) {
    final List<Field<?>> toReturn = new ArrayList<>();
    final LocalTransformations localTransformations = model.getLocalTransformations();
    if (localTransformations != null && localTransformations.hasDerivedFields()) {
        // DerivedField is itself a Field subtype, so addAll is type-safe without a raw Field cast.
        toReturn.addAll(localTransformations.getDerivedFields());
    }
    final Output output = model.getOutput();
    if (output != null && output.hasOutputFields()) {
        toReturn.addAll(output.getOutputFields());
    }
    return toReturn;
}
Aggregations