use of org.dmg.pmml.OutputField in project jpmml-sparkml by jpmml.
the class ClassificationModelConverter method registerOutputFields.
/**
 * Builds the Spark-facing output fields of a classification model and registers the matching
 * features on the encoder.
 * <p>
 * Fields are returned in this order:
 * <ol>
 *     <li>{@code pmml(<predictionCol>)}: the raw categorical prediction, typed like the label;</li>
 *     <li>{@code <predictionCol>}: a {@code DOUBLE} transformed value that maps each label value to
 *     its zero-based index through an inline {@code MapValues} lookup table;</li>
 *     <li>when the model exposes a probability column, one {@code <probabilityCol>(<value>)}
 *     probability field per label value.</li>
 * </ol>
 * The index-valued prediction is registered under {@code predictionCol} as a categorical feature,
 * and the probability fields, if any, are registered as continuous features under {@code probabilityCol}.
 *
 * @param label the model label; cast to {@code CategoricalLabel}
 * @param encoder the encoder that receives the registered features
 * @return the output fields, in the order listed above
 */
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    T model = getTransformer();
    CategoricalLabel categoricalLabel = (CategoricalLabel) label;
    List<OutputField> outputFields = new ArrayList<>();

    // Raw prediction, exposed under a "pmml(...)" wrapper name.
    String predictionCol = model.getPredictionCol();
    FieldName pmmlPredictionName = FieldName.create("pmml(" + predictionCol + ")");
    OutputField pmmlPredictedField = ModelUtil.createPredictedField(pmmlPredictionName, categoricalLabel.getDataType(), OpType.CATEGORICAL);
    outputFields.add(pmmlPredictedField);

    // Lookup table: label value ("input") -> its index as a string ("output").
    List<String> categories = new ArrayList<>();
    DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();
    InlineTable inlineTable = new InlineTable();
    List<String> columns = Arrays.asList("input", "output");
    String inputColumn = columns.get(0);
    String outputColumn = columns.get(1);
    for (int index = 0; index < categoricalLabel.size(); index++) {
        String indexText = String.valueOf(index);
        categories.add(indexText);
        String value = categoricalLabel.getValue(index);
        inlineTable.addRows(DOMUtil.createRow(documentBuilder, columns, Arrays.asList(value, indexText)));
    }

    FieldColumnPair inputPair = new FieldColumnPair(pmmlPredictedField.getName(), inputColumn);
    MapValues mapValues = new MapValues()
        .addFieldColumnPairs(inputPair)
        .setOutputColumn(outputColumn)
        .setInlineTable(inlineTable);

    final OutputField predictedField = new OutputField(FieldName.create(predictionCol), DataType.DOUBLE)
        .setOpType(OpType.CATEGORICAL)
        .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
        .setExpression(mapValues);
    outputFields.add(predictedField);

    // Categorical view of the index-valued prediction; its continuous view reuses the same name/type.
    Feature predictionFeature = new CategoricalFeature(encoder, predictedField.getName(), predictedField.getDataType(), categories) {

        @Override
        public ContinuousFeature toContinuousFeature() {
            PMMLEncoder featureEncoder = ensureEncoder();
            return new ContinuousFeature(featureEncoder, getName(), getDataType());
        }
    };
    encoder.putOnlyFeature(predictionCol, predictionFeature);

    if (!(model instanceof HasProbabilityCol)) {
        return outputFields;
    }

    // One probability output (and continuous feature) per label value.
    String probabilityCol = ((HasProbabilityCol) model).getProbabilityCol();
    List<Feature> probabilityFeatures = new ArrayList<>();
    for (int index = 0; index < categoricalLabel.size(); index++) {
        String value = categoricalLabel.getValue(index);
        FieldName probabilityName = FieldName.create(probabilityCol + "(" + value + ")");
        OutputField probabilityField = ModelUtil.createProbabilityField(probabilityName, DataType.DOUBLE, value);
        outputFields.add(probabilityField);
        probabilityFeatures.add(new ContinuousFeature(encoder, probabilityField.getName(), probabilityField.getDataType()));
    }
    encoder.putFeatures(probabilityCol, probabilityFeatures);
    return outputFields;
}
use of org.dmg.pmml.OutputField in project jpmml-sparkml by jpmml.
the class ModelConverter method registerModel.
/**
 * Encodes the schema and the PMML model, then prepends the fields returned by
 * {@code registerOutputFields} to the {@code Output} of the model returned by {@code getLastModel}.
 * That may be a nested model rather than the returned one. An {@code Output} element is
 * created there if it is missing.
 *
 * @param encoder the encoder used for both schema encoding and output-field registration
 * @return the model produced by {@code encodeModel}
 */
public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder) {
    Schema schema = encodeSchema(encoder);
    Label label = schema.getLabel();
    org.dmg.pmml.Model pmmlModel = encodeModel(schema);

    List<OutputField> sparkOutputFields = registerOutputFields(label, encoder);
    if (sparkOutputFields == null || sparkOutputFields.isEmpty()) {
        // Nothing to expose: the encoded model is returned untouched.
        return pmmlModel;
    }

    org.dmg.pmml.Model lastModel = getLastModel(pmmlModel);
    Output output = lastModel.getOutput();
    if (output == null) {
        output = new Output();
        lastModel.setOutput(output);
    }
    // Insert at the front so the Spark fields precede any pre-existing output fields.
    output.getOutputFields().addAll(0, sparkOutputFields);
    return pmmlModel;
}
use of org.dmg.pmml.OutputField in project drools by kiegroup.
the class KiePMMLClassificationTableFactoryTest method getClassificationTableBuilder.
/**
 * Builds a small CAUCHIT-normalized regression model with two categorical tables
 * ("professional" and "clerical") and checks that
 * {@code KiePMMLClassificationTableFactory.getClassificationTableBuilder} returns a non-null entry.
 */
@Test
public void getClassificationTableBuilder() {
    RegressionTable professionalTable = getRegressionTable(3.5, "professional");
    RegressionTable clericalTable = getRegressionTable(27.4, "clerical");
    OutputField categoricalProbability = getOutputField("CAT-1", ResultFeature.PROBABILITY, "CatPred-1");
    OutputField numericProbability = getOutputField("NUM-1", ResultFeature.PROBABILITY, "NumPred-0");
    OutputField predictedValue = getOutputField("PREV", ResultFeature.PREDICTED_VALUE, null);

    // Categorical target declared in the data dictionary ...
    FieldName targetName = FieldName.create("targetField");
    DataField dataField = new DataField();
    dataField.setName(targetName);
    dataField.setOpType(OpType.CATEGORICAL);
    DataDictionary dataDictionary = new DataDictionary();
    dataDictionary.addDataFields(dataField);

    RegressionModel regressionModel = new RegressionModel();
    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.CAUCHIT);
    regressionModel.addRegressionTables(professionalTable, clericalTable);
    regressionModel.setModelName(getGeneratedClassName("RegressionModel"));
    Output output = new Output();
    output.addOutputFields(categoricalProbability, numericProbability, predictedValue);
    regressionModel.setOutput(output);

    // ... and flagged as TARGET in the mining schema.
    MiningField miningField = new MiningField();
    miningField.setUsageType(MiningField.UsageType.TARGET);
    miningField.setName(dataField.getName());
    MiningSchema miningSchema = new MiningSchema();
    miningSchema.addMiningFields(miningField);
    regressionModel.setMiningSchema(miningSchema);

    PMML pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.addModels(regressionModel);

    CommonCompilationDTO<RegressionModel> source = CommonCompilationDTO.fromGeneratedPackageNameAndFields(PACKAGE_NAME, pmml, regressionModel, new HasClassLoaderMock());
    RegressionCompilationDTO compilationDTO = RegressionCompilationDTO.fromCompilationDTORegressionTablesAndNormalizationMethod(source, regressionModel.getRegressionTables(), regressionModel.getNormalizationMethod());

    // Key each table by "<package>.<CATEGORY>", keeping insertion order.
    LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap = new LinkedHashMap<>();
    for (RegressionTable regressionTable : regressionModel.getRegressionTables()) {
        String category = regressionTable.getTargetCategory().toString();
        regressionTablesMap.put(compilationDTO.getPackageName() + "." + category.toUpperCase(),
                new KiePMMLTableSourceCategory("", category));
    }

    Map.Entry<String, String> retrieved = KiePMMLClassificationTableFactory.getClassificationTableBuilder(compilationDTO, regressionTablesMap);
    assertNotNull(retrieved);
}
use of org.dmg.pmml.OutputField in project drools by kiegroup.
the class KiePMMLClassificationTableFactoryTest method getOutputField.
/**
 * Creates a test {@code OutputField} with the given name and result feature.
 *
 * @param name the output field name
 * @param resultFeature the result feature to set
 * @param targetField the target field name, or {@code null} to leave it unset
 * @return the new output field
 */
private OutputField getOutputField(String name, ResultFeature resultFeature, String targetField) {
    OutputField outputField = new OutputField().setResultFeature(resultFeature);
    outputField.setName(FieldName.create(name));
    if (targetField == null) {
        return outputField;
    }
    outputField.setTargetField(FieldName.create(targetField));
    return outputField;
}
use of org.dmg.pmml.OutputField in project drools by kiegroup.
the class KiePMMLUtil method populateMissingOutputFieldDataType.
/**
 * Populates the missing <b>dataType</b> property of <code>OutputField</code>s.
 * Such property was optional until the 4.4.1 spec, where it became mandatory.
 * <p>
 * Only output fields whose dataType is {@code null} are modified. Resolution rules, in order:
 * <ol>
 *     <li>if the output field declares a <b>targetField</b>, it must match one of the mining
 *     target fields (as returned by {@code getMiningTargetFields}); otherwise a
 *     {@link KiePMMLException} is thrown;</li>
 *     <li>a <b>probability</b> output field always gets {@link DataType#DOUBLE}, whether or not it
 *     references a target field, because a probability is numeric whatever the target's type;</li>
 *     <li>an output field without targetField whose result feature is <b>predictedValue</b> or
 *     unset defaults to the first mining target field, if any;</li>
 *     <li>when a mining field was resolved, the dataType of the <code>DataField</code> with the
 *     same name is copied.</li>
 * </ol>
 * Output fields matching none of these rules keep a {@code null} dataType.
 *
 * @param toPopulate the output fields to inspect; matching elements are modified in place
 * @param miningFields the model's mining fields, from which the target fields are selected
 * @param dataFields the data dictionary fields used to look up the referenced field's dataType
 * @throws KiePMMLException if a declared targetField does not match any mining target field,
 *                          or if no DataField exists for the resolved mining field
 */
static void populateMissingOutputFieldDataType(List<OutputField> toPopulate, List<MiningField> miningFields, List<DataField> dataFields) {
    // partial implementation to fix missing "dataType" inside OutputField; "dataType" became mandatory only in 4.4.1 version
    List<MiningField> targetFields = getMiningTargetFields(miningFields);
    toPopulate.stream().filter(outputField -> outputField.getDataType() == null).forEach(outputField -> {
        ResultFeature resultFeature = outputField.getResultFeature();
        MiningField referencedField = null;
        if (outputField.getTargetField() != null) {
            // An explicit reference is always validated, even for probability fields.
            referencedField = targetFields.stream()
                    .filter(targetField -> outputField.getTargetField().equals(targetField.getName()))
                    .findFirst()
                    .orElseThrow(() -> new KiePMMLException("Failed to find a target field for OutputField " + outputField.getName().getValue()));
        }
        if (ResultFeature.PROBABILITY.equals(resultFeature)) {
            // A probability is always numeric, even when it refers to a (possibly string-typed)
            // categorical target: never inherit the target's dataType here.
            outputField.setDataType(DataType.DOUBLE);
            return;
        }
        if (referencedField == null && (resultFeature == null || ResultFeature.PREDICTED_VALUE.equals(resultFeature))) {
            // Default predictedValue: use the first target field. A MiningSchema is allowed to
            // have no "target" field at all, hence the null fallback.
            referencedField = targetFields.stream().findFirst().orElse(null);
        }
        if (referencedField != null) {
            FieldName targetFieldName = referencedField.getName();
            DataField dataField = dataFields.stream()
                    .filter(df -> df.getName().equals(targetFieldName))
                    .findFirst()
                    .orElseThrow(() -> new KiePMMLException("Failed to find a DataField field for " + "MiningField " + targetFieldName.toString()));
            outputField.setDataType(dataField.getDataType());
        }
    });
}
Aggregations