Search in sources:

Example 16 with OutputField

Use of org.dmg.pmml.OutputField in project jpmml-sparkml by jpmml.

Class ClassificationModelConverter, method registerOutputFields.

/**
 * Declares the Spark-specific output fields of a classification model and registers the
 * corresponding features with the encoder.
 * <p>
 * Fields are returned in this order:
 * <ol>
 *     <li>{@code pmml(<predictionCol>)}: the predicted value, typed like the label;</li>
 *     <li>{@code <predictionCol>}: a double-typed transformed value that maps each class label to its
 *     zero-based index through an inline {@code MapValues} table;</li>
 *     <li>only when the model implements {@code HasProbabilityCol}: one
 *     {@code <probabilityCol>(<value>)} probability field per class label.</li>
 * </ol>
 * The prediction column is registered as a single categorical feature over the index strings.
 * The probability columns, if any, are registered as a list of continuous features.
 *
 * @param label the model label; cast to {@code CategoricalLabel}, so any other label type fails
 * @param encoder the encoder that receives the prediction and probability features
 * @return the newly created output fields, in declaration order
 */
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    T model = getTransformer();
    CategoricalLabel categoricalLabel = (CategoricalLabel) label;
    String predictionCol = model.getPredictionCol();
    int categoryCount = categoricalLabel.size();

    List<OutputField> outputFields = new ArrayList<>();

    // Raw PMML prediction, carrying the original class label
    OutputField pmmlPredictedField = ModelUtil.createPredictedField(FieldName.create("pmml(" + predictionCol + ")"), categoricalLabel.getDataType(), OpType.CATEGORICAL);
    outputFields.add(pmmlPredictedField);

    // Lookup table: class label ("input") -> its zero-based index as a string ("output").
    // NOTE(review): presumably mirrors Spark ML's double-valued label index; confirm against Spark docs.
    List<String> tableColumns = Arrays.asList("input", "output");
    String inputColumn = tableColumns.get(0);
    String outputColumn = tableColumns.get(1);

    DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();
    InlineTable inlineTable = new InlineTable();
    List<String> categoryIndices = new ArrayList<>();
    for (int index = 0; index < categoryCount; index++) {
        String categoryIndex = Integer.toString(index);
        categoryIndices.add(categoryIndex);
        inlineTable.addRows(DOMUtil.createRow(documentBuilder, tableColumns, Arrays.asList(categoricalLabel.getValue(index), categoryIndex)));
    }

    FieldColumnPair fieldColumnPair = new FieldColumnPair(pmmlPredictedField.getName(), inputColumn);
    MapValues mapValues = new MapValues()
        .addFieldColumnPairs(fieldColumnPair)
        .setOutputColumn(outputColumn)
        .setInlineTable(inlineTable);

    OutputField predictedField = new OutputField(FieldName.create(predictionCol), DataType.DOUBLE)
        .setOpType(OpType.CATEGORICAL)
        .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
        .setExpression(mapValues);
    outputFields.add(predictedField);

    // Categorical view over the index strings; can be reinterpreted as a continuous feature on demand
    Feature predictionFeature = new CategoricalFeature(encoder, predictedField.getName(), predictedField.getDataType(), categoryIndices) {

        @Override
        public ContinuousFeature toContinuousFeature() {
            PMMLEncoder featureEncoder = ensureEncoder();
            return new ContinuousFeature(featureEncoder, getName(), getDataType());
        }
    };
    encoder.putOnlyFeature(predictionCol, predictionFeature);

    if (model instanceof HasProbabilityCol) {
        String probabilityCol = ((HasProbabilityCol) model).getProbabilityCol();
        List<Feature> probabilityFeatures = new ArrayList<>();
        for (int index = 0; index < categoryCount; index++) {
            String value = categoricalLabel.getValue(index);
            OutputField probabilityField = ModelUtil.createProbabilityField(FieldName.create(probabilityCol + "(" + value + ")"), DataType.DOUBLE, value);
            outputFields.add(probabilityField);
            probabilityFeatures.add(new ContinuousFeature(encoder, probabilityField.getName(), probabilityField.getDataType()));
        }
        encoder.putFeatures(probabilityCol, probabilityFeatures);
    }
    return outputFields;
}
Also used: InlineTable (org.dmg.pmml.InlineTable), HasProbabilityCol (org.apache.spark.ml.param.shared.HasProbabilityCol), PMMLEncoder (org.jpmml.converter.PMMLEncoder), ArrayList (java.util.ArrayList), FieldColumnPair (org.dmg.pmml.FieldColumnPair), ResultFeature (org.dmg.pmml.ResultFeature), ContinuousFeature (org.jpmml.converter.ContinuousFeature), Feature (org.jpmml.converter.Feature), CategoricalFeature (org.jpmml.converter.CategoricalFeature), DocumentBuilder (javax.xml.parsers.DocumentBuilder), MapValues (org.dmg.pmml.MapValues), CategoricalLabel (org.jpmml.converter.CategoricalLabel), OutputField (org.dmg.pmml.OutputField), Row (org.dmg.pmml.Row)

Example 17 with OutputField

Use of org.dmg.pmml.OutputField in project jpmml-sparkml by jpmml.

Class ModelConverter, method registerModel.

/**
 * Encodes this converter's model and prepends its Spark-specific output fields.
 * <p>
 * The output fields produced by {@code registerOutputFields} are inserted at the front of the
 * {@code Output} element of the model returned by {@code getLastModel}. That element is created
 * if it does not exist yet. When no output fields are produced, the model is returned untouched.
 *
 * @param encoder the encoder used to build the schema and to register features
 * @return the encoded (top-level) model
 */
public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder) {
    Schema schema = encodeSchema(encoder);
    Label label = schema.getLabel();
    org.dmg.pmml.Model model = encodeModel(schema);

    List<OutputField> sparkOutputFields = registerOutputFields(label, encoder);
    if (sparkOutputFields == null || sparkOutputFields.isEmpty()) {
        // Nothing Spark-specific to expose
        return model;
    }

    // Output fields are attached to the model selected by getLastModel, not necessarily the top-level one
    org.dmg.pmml.Model lastModel = getLastModel(model);
    Output output = lastModel.getOutput();
    if (output == null) {
        output = new Output();
        lastModel.setOutput(output);
    }
    // Prepend, so the Spark fields precede any output fields the model already declared
    output.getOutputFields().addAll(0, sparkOutputFields);
    return model;
}
Also used: Schema (org.jpmml.converter.Schema), Output (org.dmg.pmml.Output), ContinuousLabel (org.jpmml.converter.ContinuousLabel), CategoricalLabel (org.jpmml.converter.CategoricalLabel), Label (org.jpmml.converter.Label), OutputField (org.dmg.pmml.OutputField)

Example 18 with OutputField

Use of org.dmg.pmml.OutputField in project drools by kiegroup.

Class KiePMMLClassificationTableFactoryTest, method getClassificationTableBuilder.

/**
 * Builds a minimal CAUCHIT-normalized classification {@code RegressionModel} inside a {@code PMML}
 * document. The model has two target categories and three output fields. The test checks that
 * {@code KiePMMLClassificationTableFactory.getClassificationTableBuilder} returns a non-null entry for it.
 */
@Test
public void getClassificationTableBuilder() {
    // One regression table per target category
    RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    RegressionTable clericalTable = getRegressionTable(27.4, "clerical");

    // Two probability outputs bound to a target field, plus one predicted-value output
    OutputField catProbability = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    OutputField numProbability = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    OutputField previousPrediction = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);

    // Data dictionary with a single categorical target
    String targetFieldName = "targetField";
    DataField dataField = new DataField();
    dataField.setName(FieldName.create(targetFieldName));
    dataField.setOpType(OpType.CATEGORICAL);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);

    Output output = new Output();
    output.addOutputFields(catProbability, numProbability, previousPrediction);

    MiningField miningField = new MiningField();
    miningField.setUsageType(MiningField.UsageType.TARGET);
    miningField.setName(dataField.getName());
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(miningField);

    RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(professionalTable, clericalTable);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    regressionModel.setOutput(output);
    regressionModel.setMiningSchema(miningSchema);

    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(regressionModel);

    CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source, regressionModel.getRegressionTables(), regressionModel.getNormalizationMethod());

    // Keys follow the "<package>.<UPPER-CASED TARGET CATEGORY>" pattern
    LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap = new LinkedHashMap<>();
    for (RegressionTable regressionTable : regressionModel.getRegressionTables()) {
        String targetCategory = regressionTable.getTargetCategory().toString();
        regressionTablesMap.put(compilationDTO.getPackageName() + "." + targetCategory.toUpperCase(), new KiePMMLTableSourceCategory("", targetCategory));
    }

    Map.Entry<String, String> retrieved = KiePMMLClassificationTableFactory.getClassificationTableBuilder(compilationDTO, regressionTablesMap);
    assertNotNull(retrieved);
}
Also used: MiningField (org.dmg.pmml.MiningField), DataDictionary (org.dmg.pmml.DataDictionary), HasClassLoaderMock (org.kie.pmml.compiler.commons.mocks.HasClassLoaderMock), RegressionTable (org.dmg.pmml.regression.RegressionTable), RegressionModel (org.dmg.pmml.regression.RegressionModel), LinkedHashMap (java.util.LinkedHashMap), DataField (org.dmg.pmml.DataField), MiningSchema (org.dmg.pmml.MiningSchema), Output (org.dmg.pmml.Output), KiePMMLTableSourceCategory (org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory), OutputField (org.dmg.pmml.OutputField), PMML (org.dmg.pmml.PMML), RegressionCompilationDTO (org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO), Map (java.util.Map), Test (org.junit.Test)

Example 19 with OutputField

Use of org.dmg.pmml.OutputField in project drools by kiegroup.

Class KiePMMLClassificationTableFactoryTest, method getOutputField.

/**
 * Creates an {@code OutputField} for test fixtures.
 *
 * @param name the output field name
 * @param resultFeature the result feature to assign
 * @param targetField the optional target field name; when {@code null}, the target field is not set
 * @return the configured output field
 */
private OutputField getOutputField(String name, ResultFeature resultFeature, String targetField) {
    OutputField outputField = new OutputField();
    outputField.setName(FieldName.create(name));
    outputField.setResultFeature(resultFeature);
    if (targetField == null) {
        // No target requested: leave the target field untouched
        return outputField;
    }
    outputField.setTargetField(FieldName.create(targetField));
    return outputField;
}
Also used : OutputField(org.dmg.pmml.OutputField)

Example 20 with OutputField

Use of org.dmg.pmml.OutputField in project drools by kiegroup.

Class KiePMMLUtil, method populateMissingOutputFieldDataType.

/**
 * Populates the <b>dataType</b> property of every <code>OutputField</code> that lacks one.
 * Such property was optional until the 4.4.1 spec, so older documents may omit it.
 * <p>
 * Output fields that already declare a <code>dataType</code> are left untouched. For the others:
 * <ol>
 *     <li>a declared <code>targetField</code> must match one of the target mining fields
 *     (as selected by <code>getMiningTargetFields</code>), otherwise a {@link KiePMMLException} is thrown;</li>
 *     <li>a <code>probability</code> output always gets <code>double</code>, whatever the type of the
 *     target it refers to;</li>
 *     <li>otherwise the type is copied from the <code>DataField</code> of the referenced target. That target is
 *     the declared <code>targetField</code> or, for predicted values, the first target mining field;</li>
 *     <li>if no target can be resolved, the data type stays unset.</li>
 * </ol>
 * @param toPopulate the output fields to update in place
 * @param miningFields the mining fields from which the target fields are selected
 * @param dataFields the data fields used to look up the type of a referenced target
 * @throws KiePMMLException if a declared <code>targetField</code> matches no target mining field,
 * or if a referenced target has no matching <code>DataField</code>
 */
static void populateMissingOutputFieldDataType(List<OutputField> toPopulate, List<MiningField> miningFields, List<DataField> dataFields) {
    // partial implementation to fix missing "dataType" inside OutputField; "dataType" became mandatory only in 4.4.1 version
    List<MiningField> targetFields = getMiningTargetFields(miningFields);
    toPopulate.stream()
            .filter(outputField -> outputField.getDataType() == null)
            .forEach(outputField -> populateOutputFieldDataType(outputField, targetFields, dataFields));
}

/**
 * Resolves and sets the <b>dataType</b> of a single <code>OutputField</code> whose data type is missing.
 *
 * @param outputField the output field to update
 * @param targetFields the target mining fields of the model (possibly empty)
 * @param dataFields the data fields used to look up the type of a referenced target
 * @throws KiePMMLException if a declared <code>targetField</code> matches no target mining field,
 * or if a referenced target has no matching <code>DataField</code>
 */
private static void populateOutputFieldDataType(OutputField outputField, List<MiningField> targetFields, List<DataField> dataFields) {
    MiningField referencedField = null;
    if (outputField.getTargetField() != null) {
        // An explicit targetField must point to an existing target mining field
        referencedField = targetFields.stream()
                .filter(targetField -> outputField.getTargetField().equals(targetField.getName()))
                .findFirst()
                .orElseThrow(() -> new KiePMMLException("Failed to find a target field for OutputField " + outputField.getName().getValue()));
    }
    if (ResultFeature.PROBABILITY.equals(outputField.getResultFeature())) {
        // A probability is always numeric, so "double" applies even when a (e.g. string) target is referenced
        outputField.setDataType(DataType.DOUBLE);
        return;
    }
    if (referencedField == null && (outputField.getResultFeature() == null || ResultFeature.PREDICTED_VALUE.equals(outputField.getResultFeature()))) {
        // Default predictedValue: fall back to the first target.
        // It is allowed to not have any "target" field inside MiningSchema
        referencedField = targetFields.stream().findFirst().orElse(null);
    }
    if (referencedField != null) {
        FieldName targetFieldName = referencedField.getName();
        DataField dataField = dataFields.stream()
                .filter(df -> df.getName().equals(targetFieldName))
                .findFirst()
                .orElseThrow(() -> new KiePMMLException("Failed to find a DataField field for MiningField " + targetFieldName.getValue()));
        outputField.setDataType(dataField.getDataType());
    }
}
Also used: PMML (org.dmg.pmml.PMML), Model (org.dmg.pmml.Model), OutputField (org.dmg.pmml.OutputField), Targets (org.dmg.pmml.Targets), DataType (org.dmg.pmml.DataType), ResultFeature (org.dmg.pmml.ResultFeature), MiningSchema (org.dmg.pmml.MiningSchema), Collectors (java.util.stream.Collectors), JAXBException (javax.xml.bind.JAXBException), Target (org.dmg.pmml.Target), DataField (org.dmg.pmml.DataField), FieldName (org.dmg.pmml.FieldName), OpType (org.dmg.pmml.OpType), List (java.util.List), Segment (org.dmg.pmml.mining.Segment), ByteArrayInputStream (java.io.ByteArrayInputStream), SAXException (org.xml.sax.SAXException), Optional (java.util.Optional), MiningFunction (org.dmg.pmml.MiningFunction), MiningField (org.dmg.pmml.MiningField), KiePMMLException (org.kie.pmml.api.exceptions.KiePMMLException), MathContext (org.dmg.pmml.MathContext), InputStream (java.io.InputStream), MiningModel (org.dmg.pmml.mining.MiningModel)

Aggregations

OutputField (org.dmg.pmml.OutputField) 28, Test (org.junit.Test) 10, DataField (org.dmg.pmml.DataField) 9, MiningField (org.dmg.pmml.MiningField) 9, MiningSchema (org.dmg.pmml.MiningSchema) 7, Output (org.dmg.pmml.Output) 7, PMML (org.dmg.pmml.PMML) 7, DataDictionary (org.dmg.pmml.DataDictionary) 4, FieldName (org.dmg.pmml.FieldName) 4, ResultFeature (org.dmg.pmml.ResultFeature) 4, MiningModel (org.dmg.pmml.mining.MiningModel) 4, DATA_TYPE (org.kie.pmml.api.enums.DATA_TYPE) 4, ByteArrayInputStream (java.io.ByteArrayInputStream) 3, InputStream (java.io.InputStream) 3, ArrayList (java.util.ArrayList) 3, LinkedHashMap (java.util.LinkedHashMap) 3, Collectors (java.util.stream.Collectors) 3, Model (org.dmg.pmml.Model) 3, OpType (org.dmg.pmml.OpType) 3, RegressionModel (org.dmg.pmml.regression.RegressionModel) 3