Example 1 with LocalTransformations

Use of org.dmg.pmml.LocalTransformations in project shifu by ShifuML.

From the class PMMLConstructorFactory, method produce.

public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
    AbstractPmmlElementCreator<Model> modelCreator = null;
    AbstractSpecifCreator specifCreator = null;
    if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm())) {
        modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm())) {
        modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm()) || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(modelConfig.getTrain().getAlgorithm())) {
        TreeEnsemblePmmlCreator gbtmodelCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
        return new TreeEnsemblePMMLTranslator(gbtmodelCreator, dataDictionaryCreator, miningSchemaCreator);
    } else {
        throw new RuntimeException("Model not supported: " + modelConfig.getTrain().getAlgorithm());
    }
    AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator = null;
    ModelNormalizeConf.NormType normType = modelConfig.getNormalizeType();
    if (normType.equals(ModelNormalizeConf.NormType.WOE) || normType.equals(ModelNormalizeConf.NormType.WEIGHT_WOE)) {
        localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    } else if (normType == ModelNormalizeConf.NormType.WOE_ZSCORE || normType == ModelNormalizeConf.NormType.WOE_ZSCALE) {
        localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
    } else if (normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCORE || normType == ModelNormalizeConf.NormType.WEIGHT_WOE_ZSCALE) {
        localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
    } else if (normType == ModelNormalizeConf.NormType.ZSCALE_ONEHOT) {
        localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    } else {
        localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
    }
    return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
Also used : AbstractSpecifCreator(ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator) TreeEnsemblePMMLTranslator(ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator) ModelNormalizeConf(ml.shifu.shifu.container.obj.ModelNormalizeConf) DataDictionary(org.dmg.pmml.DataDictionary) PMMLTranslator(ml.shifu.shifu.core.pmml.PMMLTranslator) LocalTransformations(org.dmg.pmml.LocalTransformations) MiningSchema(org.dmg.pmml.MiningSchema) ModelStats(org.dmg.pmml.ModelStats) Model(org.dmg.pmml.Model)
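In produce(), the normalization type alone decides which LocalTransformations creator is built, and the WOE-plus-Z-score variants share WoeZscoreLocalTransformCreator, with a boolean constructor argument that is true only for the WEIGHT_* types. The sketch below restates that mapping as a Java switch expression. It uses a hypothetical, trimmed-down NormType enum (only the constants named in produce(), plus ZSCALE as an example of a value that reaches the default branch) and returns creator names instead of constructing Shifu objects. It also rejects a null type explicitly; in produce() a null normalize type would instead fail with a NullPointerException at the first normType.equals(...) call.

```java
import java.util.Objects;

class NormDispatchSketch {

    // Hypothetical, trimmed-down stand-in for Shifu's ModelNormalizeConf.NormType:
    // only the constants referenced in produce(), plus ZSCALE for the default branch.
    enum NormType {
        WOE, WEIGHT_WOE,
        WOE_ZSCORE, WOE_ZSCALE,
        WEIGHT_WOE_ZSCORE, WEIGHT_WOE_ZSCALE,
        ZSCALE_ONEHOT, ZSCALE
    }

    // Names the LocalTransformations creator that produce() would pick for the type.
    // "weighted" mirrors the boolean fourth argument of WoeZscoreLocalTransformCreator.
    static String creatorFor(NormType normType) {
        // A switch on a null enum throws NullPointerException; fail early with a message.
        Objects.requireNonNull(normType, "normalize type must be set");
        return switch (normType) {
            case WOE, WEIGHT_WOE -> "WoeLocalTransformCreator";
            case WOE_ZSCORE, WOE_ZSCALE -> "WoeZscoreLocalTransformCreator(weighted=false)";
            case WEIGHT_WOE_ZSCORE, WEIGHT_WOE_ZSCALE -> "WoeZscoreLocalTransformCreator(weighted=true)";
            case ZSCALE_ONEHOT -> "ZscoreOneHotLocalTransformCreator";
            default -> "ZscoreLocalTransformCreator";
        };
    }

    public static void main(String[] args) {
        for (NormType t : NormType.values()) {
            System.out.println(t + " -> " + creatorFor(t));
        }
    }
}
```

Grouping the enum constants in case labels keeps each creator on one line, so adding a new normalization type means adding it to exactly one label.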

Example 2 with LocalTransformations

Use of org.dmg.pmml.LocalTransformations in project shifu by ShifuML.

From the class NeuralNetworkModelIntegrator, method getLocalTranformations.

private LocalTransformations getLocalTranformations(NeuralNetwork model) {
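    // Copy the model's derived fields: getDerivedFields() returns the model's own list,
    // so appending to it directly would also add the bias field to the source model.
    List<DerivedField> derivedFields = new ArrayList<>(model.getLocalTransformations().getDerivedFields());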
    // add bias
    DerivedField field = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).setName(new FieldName(PluginConstants.biasValue));
    field.setExpression(new Constant(String.valueOf(PluginConstants.bias)));
    derivedFields.add(field);
    return new LocalTransformations().addDerivedFields(derivedFields.toArray(new DerivedField[derivedFields.size()]));
}
Also used : LocalTransformations(org.dmg.pmml.LocalTransformations) Constant(org.dmg.pmml.Constant) DerivedField(org.dmg.pmml.DerivedField) FieldName(org.dmg.pmml.FieldName)
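getLocalTranformations() appends a constant bias field to the derived fields taken from the model. In JPMML-Model's generated classes, getDerivedFields() follows the usual JAXB pattern of returning the live backing list, so appending to that list directly also changes the source model's LocalTransformations. The sketch below shows a copy-then-append approach using a hypothetical DerivedFieldStub record in place of org.dmg.pmml.DerivedField; the name "bias" and the value 1.0 are placeholders, since the values of PluginConstants.biasValue and PluginConstants.bias are not shown here.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class BiasFieldSketch {

    // Hypothetical stand-in for a PMML DerivedField: a name plus a textual expression.
    record DerivedFieldStub(String name, String expression) { }

    // Returns a new, read-only list: the existing fields followed by a constant bias field.
    // Copying first leaves the caller's list (for example a model's live derived-field
    // list) unchanged.
    static List<DerivedFieldStub> withBias(List<DerivedFieldStub> existing, String biasName, double bias) {
        List<DerivedFieldStub> result = new ArrayList<>(existing);
        result.add(new DerivedFieldStub(biasName, String.valueOf(bias)));
        return Collections.unmodifiableList(result);
    }

    public static void main(String[] args) {
        List<DerivedFieldStub> original = new ArrayList<>();
        original.add(new DerivedFieldStub("x1_zscore", "(x1 - mean) / stddev"));
        List<DerivedFieldStub> extended = withBias(original, "bias", 1.0);
        System.out.println(original.size() + " " + extended.size()); // prints "1 2"
    }
}
```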

Example 3 with LocalTransformations

Use of org.dmg.pmml.LocalTransformations in project drools by kiegroup.

From the class PMMLModelTestUtils, method getRandomLocalTransformations.

public static LocalTransformations getRandomLocalTransformations() {
    final LocalTransformations toReturn = new LocalTransformations();
    IntStream.range(0, 3).forEach(i -> {
        toReturn.addDerivedFields(getDerivedField("DerivedField-" + i));
    });
    return toReturn;
}
Also used : LocalTransformations(org.dmg.pmml.LocalTransformations)
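getRandomLocalTransformations() adds three derived fields named DerivedField-0, DerivedField-1 and DerivedField-2; the fields themselves come from a getDerivedField helper in PMMLModelTestUtils whose body is not shown. The naming loop can be isolated as a stream pipeline. In this sketch the count parameter is an assumption (the original fixes it at 3), and only the names are produced, not PMML objects.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class TestFieldNamesSketch {

    // Builds "DerivedField-0", "DerivedField-1", ... up to count - 1.
    // The count parameter is hypothetical; the original hardcodes 3.
    static List<String> derivedFieldNames(int count) {
        if (count < 0) {
            throw new IllegalArgumentException("count must be >= 0");
        }
        return IntStream.range(0, count)
                .mapToObj(i -> "DerivedField-" + i)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(derivedFieldNames(3)); // prints [DerivedField-0, DerivedField-1, DerivedField-2]
    }
}
```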

Example 4 with LocalTransformations

Use of org.dmg.pmml.LocalTransformations in project drools by kiegroup.

From the class ModelUtils, method getFieldsFromDataDictionaryTransformationDictionaryAndModel.

public static List<Field<?>> getFieldsFromDataDictionaryTransformationDictionaryAndModel(final DataDictionary dataDictionary, final TransformationDictionary transformationDictionary, final Model model) {
    final List<Field<?>> toReturn = getFieldsFromDataDictionaryAndTransformationDictionary(dataDictionary, transformationDictionary);
    LocalTransformations localTransformations = model.getLocalTransformations();
    if (localTransformations != null && localTransformations.hasDerivedFields()) {
        localTransformations.getDerivedFields().stream().map(Field.class::cast).forEach(toReturn::add);
    }
    Output output = model.getOutput();
    if (output != null && output.hasOutputFields()) {
        output.getOutputFields().stream().map(Field.class::cast).forEach(toReturn::add);
    }
    return toReturn;
}
Also used : OutputField(org.dmg.pmml.OutputField) TargetField(org.kie.pmml.api.models.TargetField) Field(org.dmg.pmml.Field) DerivedField(org.dmg.pmml.DerivedField) MiningField(org.dmg.pmml.MiningField) DataField(org.dmg.pmml.DataField) ParameterField(org.dmg.pmml.ParameterField) LocalTransformations(org.dmg.pmml.LocalTransformations) Output(org.dmg.pmml.Output)
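getFieldsFromDataDictionaryTransformationDictionaryAndModel() starts from the dictionary-level fields and then appends the model's local derived fields and its output fields, skipping either section when its container is null or has no fields. A minimal sketch of that ordering and null handling, assuming plain String field names and a hypothetical ModelStub record instead of the PMML model classes (a null list stands for an absent LocalTransformations or Output element):

```java
import java.util.ArrayList;
import java.util.List;

class FieldCollectionSketch {

    // Hypothetical simplified model: each section is a possibly-null list of field names,
    // standing in for LocalTransformations.getDerivedFields() and Output.getOutputFields().
    record ModelStub(List<String> localDerivedFields, List<String> outputFields) { }

    // Appends every element of a section that may be absent (null) or empty.
    static void addIfPresent(List<String> target, List<String> section) {
        if (section != null && !section.isEmpty()) {
            target.addAll(section);
        }
    }

    // Same ordering as the ModelUtils method: dictionary-level fields first,
    // then the model's local derived fields, then its output fields.
    static List<String> collectFields(List<String> dictionaryFields, ModelStub model) {
        List<String> result = new ArrayList<>(dictionaryFields);
        addIfPresent(result, model.localDerivedFields());
        addIfPresent(result, model.outputFields());
        return result;
    }

    public static void main(String[] args) {
        ModelStub model = new ModelStub(List.of("scaled_age"), null);
        System.out.println(collectFields(List.of("age", "income"), model));
        // prints [age, income, scaled_age]
    }
}
```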

Example 5 with LocalTransformations

Use of org.dmg.pmml.LocalTransformations in project drools by kiegroup.

From the class KiePMMLSegmentFactory, method getFieldsFromModel.

static List<Field<?>> getFieldsFromModel(final Model model) {
    final List<Field<?>> toReturn = new ArrayList<>();
    LocalTransformations localTransformations = model.getLocalTransformations();
    if (localTransformations != null && localTransformations.hasDerivedFields()) {
        localTransformations.getDerivedFields().stream().map(Field.class::cast).forEach(toReturn::add);
    }
    Output output = model.getOutput();
    if (output != null && output.hasOutputFields()) {
        output.getOutputFields().stream().map(Field.class::cast).forEach(toReturn::add);
    }
    return toReturn;
}
Also used : Field(org.dmg.pmml.Field) LocalTransformations(org.dmg.pmml.LocalTransformations) Output(org.dmg.pmml.Output) ArrayList(java.util.ArrayList)
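getFieldsFromModel() repeats the local-transformations and output steps of the ModelUtils method in Example 4, only without the dictionary fields. One way to avoid that duplication, sketched here with hypothetical stub records rather than the drools or JPMML types, is to keep a single model-level helper, express the dictionary variant as "dictionary fields, then the model fields", and route the null checks through Optional:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

class ModelFieldsSketch {

    // Hypothetical stand-ins for LocalTransformations, Output and Model; any container may be absent.
    record LocalTransformationsStub(List<String> derivedFields) { }
    record OutputStub(List<String> outputFields) { }
    record ModelStub(LocalTransformationsStub localTransformations, OutputStub output) { }

    // Null-safe extraction: an absent container, or a null list inside it, contributes nothing.
    static <C> List<String> fieldsOf(C container, Function<C, List<String>> getter) {
        return Optional.ofNullable(container).map(getter).orElse(List.of());
    }

    // Same contract as getFieldsFromModel: local derived fields, then output fields.
    static List<String> fieldsFromModel(ModelStub model) {
        List<String> result = new ArrayList<>();
        result.addAll(fieldsOf(model.localTransformations(), LocalTransformationsStub::derivedFields));
        result.addAll(fieldsOf(model.output(), OutputStub::outputFields));
        return result;
    }

    // The dictionary-plus-model variant reuses fieldsFromModel instead of repeating its checks.
    static List<String> fieldsFromDictionariesAndModel(List<String> dictionaryFields, ModelStub model) {
        List<String> result = new ArrayList<>(dictionaryFields);
        result.addAll(fieldsFromModel(model));
        return result;
    }

    public static void main(String[] args) {
        ModelStub model = new ModelStub(null, new OutputStub(List.of("probability")));
        System.out.println(fieldsFromDictionariesAndModel(List.of("age"), model));
        // prints [age, probability]
    }
}
```

Whether the real KiePMMLSegmentFactory and ModelUtils could share code this way depends on package visibility and module boundaries that these excerpts do not show.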

Aggregations

LocalTransformations (org.dmg.pmml.LocalTransformations): 9
DerivedField (org.dmg.pmml.DerivedField): 5
Field (org.dmg.pmml.Field): 2
FieldName (org.dmg.pmml.FieldName): 2
MiningField (org.dmg.pmml.MiningField): 2
Output (org.dmg.pmml.Output): 2
Test (org.junit.Test): 2
KiePMMLLocalTransformations (org.kie.pmml.commons.transformations.KiePMMLLocalTransformations): 2
BlockStmt (com.github.javaparser.ast.stmt.BlockStmt): 1
Statement (com.github.javaparser.ast.stmt.Statement): 1
ArrayList (java.util.ArrayList): 1
Collections (java.util.Collections): 1
HashMap (java.util.HashMap): 1
List (java.util.List): 1
Optional (java.util.Optional): 1
ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1
ModelNormalizeConf (ml.shifu.shifu.container.obj.ModelNormalizeConf): 1
BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 1
PMMLTranslator (ml.shifu.shifu.core.pmml.PMMLTranslator): 1
TreeEnsemblePMMLTranslator (ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator): 1