Search in sources:

Example 6 with Model

Use of org.dmg.pmml.Model in project shifu by ShifuML.

The class PMMLConstructorFactory, method produce.

/**
 * Builds the {@link PMMLTranslator} that matches the training algorithm configured in {@code modelConfig}.
 *
 * <p>NN and LR models get a translator assembled from per-section creators: data dictionary, mining schema,
 * model stats, local transformations (chosen by the configured normalization type), the model itself and
 * the model-specific output section. GBT and RF models are delegated to a {@link TreeEnsemblePMMLTranslator}.
 * That translator is built without the {@code isConcise} and {@code isOutBaggingToOne} options.
 *
 * @param modelConfig       model configuration; its training algorithm (compared case-insensitively) selects
 *                          the translator
 * @param columnConfigList  column configurations handed to every creator
 * @param isConcise         forwarded to the NN/LR creators
 * @param isOutBaggingToOne forwarded to the NN/LR {@link PMMLTranslator}
 * @return the translator for the configured algorithm
 * @throws IllegalArgumentException if the training algorithm is not NN, LR, GBT or RF
 * @throws NullPointerException     if, for NN/LR, the configured normalization type is {@code null}
 */
public static PMMLTranslator produce(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, boolean isConcise, boolean isOutBaggingToOne) {
    // Read the configured algorithm once; every comparison below is case-insensitive.
    String algorithm = modelConfig.getTrain().getAlgorithm();

    final AbstractPmmlElementCreator<Model> modelCreator;
    final AbstractSpecifCreator specifCreator;
    if (ModelTrainConf.ALGORITHM.NN.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new NNPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new NNSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.LR.name().equalsIgnoreCase(algorithm)) {
        modelCreator = new RegressionPmmlModelCreator(modelConfig, columnConfigList, isConcise);
        specifCreator = new RegressionSpecifCreator(modelConfig, columnConfigList);
    } else if (ModelTrainConf.ALGORITHM.GBT.name().equalsIgnoreCase(algorithm)
            || ModelTrainConf.ALGORITHM.RF.name().equalsIgnoreCase(algorithm)) {
        // Tree ensembles use a dedicated translator with a tree-specific mining schema.
        TreeEnsemblePmmlCreator treeModelCreator = new TreeEnsemblePmmlCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<DataDictionary> treeDataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList);
        AbstractPmmlElementCreator<MiningSchema> treeMiningSchemaCreator = new TreeModelMiningSchemaCreator(modelConfig, columnConfigList);
        return new TreeEnsemblePMMLTranslator(treeModelCreator, treeDataDictionaryCreator, treeMiningSchemaCreator);
    } else {
        // IllegalArgumentException extends RuntimeException, so existing catch sites are unaffected.
        throw new IllegalArgumentException("Model not supported: " + algorithm);
    }

    AbstractPmmlElementCreator<DataDictionary> dataDictionaryCreator = new DataDictionaryCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<MiningSchema> miningSchemaCreator = new MiningSchemaCreator(modelConfig, columnConfigList, isConcise);
    AbstractPmmlElementCreator<ModelStats> modelStatsCreator = new ModelStatsCreator(modelConfig, columnConfigList, isConcise);

    ModelNormalizeConf.NormType normType = modelConfig.getNormalizeType();
    final AbstractPmmlElementCreator<LocalTransformations> localTransformationsCreator;
    // A switch on a null enum throws NullPointerException, matching the previous normType.equals(...) check.
    switch (normType) {
        case WOE:
        case WEIGHT_WOE:
            localTransformationsCreator = new WoeLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
        case WOE_ZSCORE:
        case WOE_ZSCALE:
            localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, false);
            break;
        case WEIGHT_WOE_ZSCORE:
        case WEIGHT_WOE_ZSCALE:
            localTransformationsCreator = new WoeZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise, true);
            break;
        case ZSCALE_ONEHOT:
            localTransformationsCreator = new ZscoreOneHotLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
        default:
            localTransformationsCreator = new ZscoreLocalTransformCreator(modelConfig, columnConfigList, isConcise);
            break;
    }
    return new PMMLTranslator(modelCreator, dataDictionaryCreator, miningSchemaCreator, modelStatsCreator, localTransformationsCreator, specifCreator, isOutBaggingToOne);
}
Also used: AbstractSpecifCreator (ml.shifu.shifu.core.pmml.builder.creator.AbstractSpecifCreator), TreeEnsemblePMMLTranslator (ml.shifu.shifu.core.pmml.TreeEnsemblePMMLTranslator), ModelNormalizeConf (ml.shifu.shifu.container.obj.ModelNormalizeConf), DataDictionary (org.dmg.pmml.DataDictionary), PMMLTranslator (ml.shifu.shifu.core.pmml.PMMLTranslator), LocalTransformations (org.dmg.pmml.LocalTransformations), MiningSchema (org.dmg.pmml.MiningSchema), ModelStats (org.dmg.pmml.ModelStats), Model (org.dmg.pmml.Model)

Example 7 with Model

Use of org.dmg.pmml.Model in project shifu by ShifuML.

The class NNPmmlModelCreator, method build.

/**
 * Creates the PMML {@link NeuralNetwork} model element and populates its mining function and targets.
 *
 * <p>The mining function is always {@link MiningFunction#REGRESSION}. Nothing in this method inspects the
 * model configuration, so this also applies to classification setups.
 *
 * @param basicML the trained network; not read by this method
 * @return a new {@code NeuralNetwork} whose mining function and targets are set
 */
@Override
public Model build(BasicML basicML) {
    Model model = new NeuralNetwork();
    // NOTE(review): a native multi-class CLASSIFICATION variant was once considered here and left disabled;
    // confirm REGRESSION is intended for every configuration.
    model.setMiningFunction(MiningFunction.REGRESSION);
    model.setTargets(createTargets());
    return model;
}
Also used: Model (org.dmg.pmml.Model), NeuralNetwork (org.dmg.pmml.neural_network.NeuralNetwork)

Example 8 with Model

Use of org.dmg.pmml.Model in project shifu by ShifuML.

The class PMMLVerifySuit, method evalLRPmml.

/**
 * Scores every record of a delimited data file with the first model of a PMML document.
 * Writes one integer score per line to an output file.
 *
 * <p>The output begins with a header line containing {@code scoreName}. Each score is the value of the
 * {@link NNSpecifCreator#FINAL_RESULT} field of the evaluation result, truncated via {@link Number#intValue()}.
 *
 * @param pmmlPath  path of the PMML file to load
 * @param DataPath  path of the input data file
 * @param OutPath   path of the output file (created or truncated, written as UTF-8)
 * @param sep       field separator of the input data file
 * @param scoreName header written on the first output line
 * @throws IllegalStateException if a result lacks a numeric final-result field
 * @throws java.io.IOException   if writing to the output file fails
 * @throws Exception             propagated from PMML loading, data loading or model evaluation
 */
private void evalLRPmml(String pmmlPath, String DataPath, String OutPath, String sep, String scoreName) throws Exception {
    PMML pmml = PMMLUtils.loadPMML(pmmlPath);
    Model m = pmml.getModels().get(0);
    ModelEvaluator<?> evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml, m);
    // Load the input before opening the output, so an unreadable input does not truncate an existing file.
    List<Map<FieldName, FieldValue>> input = CsvUtil.load(evaluator, DataPath, sep);
    FieldName resultField = new FieldName(NNSpecifCreator.FINAL_RESULT);
    // try-with-resources closes the writer even if evaluation throws part-way through.
    try (PrintWriter writer = new PrintWriter(OutPath, "UTF-8")) {
        writer.println(scoreName);
        for (Map<FieldName, FieldValue> record : input) {
            Map<?, ?> results = evaluator.evaluate(record);
            Object score = results.get(resultField);
            if (!(score instanceof Number)) {
                throw new IllegalStateException("Missing or non-numeric '" + NNSpecifCreator.FINAL_RESULT + "' in evaluation result: " + score);
            }
            writer.println(((Number) score).intValue());
        }
        // PrintWriter never throws on write failures; surface them instead of leaving a truncated file behind.
        if (writer.checkError()) {
            throw new java.io.IOException("Failed to write scores to " + OutPath);
        }
    }
}
Also used: Model (org.dmg.pmml.Model), PMML (org.dmg.pmml.PMML), FieldValue (org.jpmml.evaluator.FieldValue), HashMap (java.util.HashMap), Map (java.util.Map), FieldName (org.dmg.pmml.FieldName), PrintWriter (java.io.PrintWriter)

Example 9 with Model

Use of org.dmg.pmml.Model in project jpmml-r by jpmml.

The class ModelConverter, method encode.

/**
 * Encodes this converter's model against {@code schema}.
 *
 * <p>If this converter also implements {@link HasFeatureImportances} and reports a non-empty importance map,
 * each (feature, importance) pair is registered on the encoded model through the schema's encoder.
 *
 * @param schema the schema to encode against; its encoder is cast to {@link ModelEncoder} when importances
 *               are reported
 * @return the encoded model
 */
public Model encode(Schema schema) {
    Model encodedModel = encodeModel(schema);
    if (!(this instanceof HasFeatureImportances)) {
        return encodedModel;
    }
    FeatureImportanceMap importances = ((HasFeatureImportances) this).getFeatureImportances(schema);
    // Nothing to record: skip the encoder lookup entirely.
    if (importances == null || importances.isEmpty()) {
        return encodedModel;
    }
    ModelEncoder modelEncoder = (ModelEncoder) schema.getEncoder();
    for (Map.Entry<Feature, Number> importance : importances.entrySet()) {
        modelEncoder.addFeatureImportance(encodedModel, importance.getKey(), importance.getValue());
    }
    return encodedModel;
}
Also used: ModelEncoder (org.jpmml.converter.ModelEncoder), Model (org.dmg.pmml.Model), LinkedHashMap (java.util.LinkedHashMap), Map (java.util.Map), Feature (org.jpmml.converter.Feature)

Example 10 with Model

Use of org.dmg.pmml.Model in project jpmml-r by jpmml.

The class CaretEnsembleConverter, method encodePMML.

/**
 * Encodes a caretEnsemble object as a PMML model chain.
 *
 * <p>Each base model in the {@code models} element is encoded as a segment with a single, non-final output
 * field named after the model:
 * <ul>
 *   <li>the predicted value, for regression models;</li>
 *   <li>the probability of the second category of a two-category target, for classification models.</li>
 * </ul>
 * The {@code ens_model} element is encoded last and appended as the final model passed to
 * {@code MiningModelUtil.createModelChain}.
 *
 * @param encoder the encoder that collects every segment's fields and produces the PMML document
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a base model has an unsupported label type or mining function
 */
@Override
public PMML encodePMML(RExpEncoder encoder) {
    RGenericVector caretEnsemble = getObject();
    RGenericVector models = caretEnsemble.getGenericElement("models");
    RGenericVector ensModel = caretEnsemble.getGenericElement("ens_model");
    RStringVector modelNames = models.names();
    List<Model> segmentationModels = new ArrayList<>();

    // Base models see an anonymized schema for continuous targets; categorical schemas are passed through.
    Function<Schema, Schema> segmentSchemaFunction = schema -> {
        Label label = schema.getLabel();
        if (label instanceof ContinuousLabel) {
            return schema.toAnonymousSchema();
        }
        // XXX: Ideally, the categorical target field should also be anonymized
        if (label instanceof CategoricalLabel) {
            return schema;
        }
        throw new IllegalArgumentException("Unsupported label type: " + (label != null ? label.getClass().getName() : null));
    };

    for (int i = 0; i < models.size(); i++) {
        RGenericVector model = models.getGenericValue(i);
        Conversion conversion = encodeTrainModel(model, segmentSchemaFunction);
        RExpEncoder segmentEncoder = conversion.getEncoder();
        encoder.addFields(segmentEncoder);
        Schema segmentSchema = conversion.getSchema();
        Model segmentModel = conversion.getModel();
        String name = modelNames.getValue(i);

        OutputField outputField;
        MiningFunction miningFunction = segmentModel.requireMiningFunction();
        switch (miningFunction) {
            case REGRESSION:
                outputField = ModelUtil.createPredictedField(name, OpType.CONTINUOUS, DataType.DOUBLE).setFinalResult(Boolean.FALSE);
                break;
            case CLASSIFICATION: {
                CategoricalLabel categoricalLabel = (CategoricalLabel) segmentSchema.getLabel();
                SchemaUtil.checkSize(2, categoricalLabel);
                outputField = ModelUtil.createProbabilityField(name, DataType.DOUBLE, categoricalLabel.getValue(1)).setFinalResult(Boolean.FALSE);
                break;
            }
            default:
                throw new IllegalArgumentException("Unsupported mining function " + miningFunction + " for model '" + name + "'");
        }
        segmentModel.setOutput(new Output().addOutputFields(outputField));
        segmentationModels.add(segmentModel);
    }

    // The ensemble model is encoded with the unmodified schema and becomes the last model of the chain.
    Conversion ensembleConversion = encodeTrainModel(ensModel, null);
    Model ensembleModel = ensembleConversion.getModel();
    segmentationModels.add(ensembleModel);
    MiningModel miningModel = MiningModelUtil.createModelChain(segmentationModels);
    return encoder.encodePMML(miningModel);
}
Also used: Schema (org.jpmml.converter.Schema), ArrayList (java.util.ArrayList), CategoricalLabel (org.jpmml.converter.CategoricalLabel), ContinuousLabel (org.jpmml.converter.ContinuousLabel), Label (org.jpmml.converter.Label), Function (java.util.function.Function), MiningFunction (org.dmg.pmml.MiningFunction), MiningModel (org.dmg.pmml.mining.MiningModel), Output (org.dmg.pmml.Output), Model (org.dmg.pmml.Model), OutputField (org.dmg.pmml.OutputField), PMML (org.dmg.pmml.PMML)

Aggregations

Model (org.dmg.pmml.Model): 58, Test (org.junit.Test): 33, MiningModel (org.dmg.pmml.mining.MiningModel): 23, RegressionModel (org.dmg.pmml.regression.RegressionModel): 23, DataField (org.dmg.pmml.DataField): 22, MiningField (org.dmg.pmml.MiningField): 21, PMML (org.dmg.pmml.PMML): 21, DataDictionary (org.dmg.pmml.DataDictionary): 18, MiningSchema (org.dmg.pmml.MiningSchema): 18, CommonTestingUtils.getFieldsFromDataDictionary (org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary): 17, PMMLModelTestUtils.getDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField): 17, PMMLModelTestUtils.getRandomDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField): 17, PMMLModelTestUtils.getMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField): 15, PMMLModelTestUtils.getRandomMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField): 15, InputStream (java.io.InputStream): 13, ArrayList (java.util.ArrayList): 13, HashMap (java.util.HashMap): 11, Map (java.util.Map): 11, FileUtils.getFileInputStream (org.kie.test.util.filesystem.FileUtils.getFileInputStream): 11, List (java.util.List): 10