Use of org.dmg.pmml.MapValues in the jpmml-sparkml project (by jpmml).
Class ClassificationModelConverter, method registerOutputFields.
/**
 * Registers the prediction and (optionally) probability outputs of a classification model.
 *
 * <p>Prediction column:
 * <ol>
 *   <li>A categorical, non-final predicted field named {@code pmml(<predictionCol>)} is created
 *       and registered via {@code encoder.createDerivedField}.</li>
 *   <li>That field's label values are mapped onto the target categories with a {@code MapValues}
 *       expression typed as {@code DOUBLE}. The result is exposed as a continuous
 *       {@code TRANSFORMED_VALUE} output field named {@code <predictionCol>}, which is also
 *       registered via {@code encoder.createDerivedField}.</li>
 *   <li>The transformed field is published to the encoder as an {@code IndexFeature} under the
 *       prediction column name.</li>
 * </ol>
 * {@code LabelUtil.createTargetCategories(n)} presumably yields the indices {@code 0..n-1};
 * confirm against its implementation.
 *
 * <p>Probability column (only when the model implements {@code HasProbabilityCol}): one
 * {@code DOUBLE} probability output field is created per label value. The fields are collected
 * into the returned list, and all of them are published together as {@code ContinuousFeature}s
 * under the probability column name.
 *
 * @param label the model label; must be a {@code CategoricalLabel}, otherwise the cast throws
 *     {@code ClassCastException}
 * @param pmmlModel the PMML model the derived output fields are attached to
 * @param encoder the encoder that receives the derived fields and features
 * @return a mutable list holding only the probability output fields; it is empty when the model
 *     has no probability column. The prediction fields are not included; they are handed to the
 *     encoder instead.
 */
@Override
public List<OutputField> registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder) {
    T transformer = getTransformer();

    CategoricalLabel classLabel = (CategoricalLabel) label;

    List<Integer> targetCategories = LabelUtil.createTargetCategories(classLabel.size());

    String predictionColumn = transformer.getPredictionCol();

    // NOTE(review): presumably controls whether the derived prediction fields are retained
    // in the final PMML; the exact semantics live in SparkMLEncoder.createDerivedField.
    Boolean keepPrediction = (Boolean) getOption(HasPredictionModelOptions.OPTION_KEEP_PREDICTIONCOL, Boolean.TRUE);

    // Intermediate categorical prediction, named "pmml(<predictionCol>)", marked non-final.
    OutputField pmmlPredictionOutput = ModelUtil.createPredictedField(FieldNameUtil.create("pmml", predictionColumn), OpType.CATEGORICAL, classLabel.getDataType());
    pmmlPredictionOutput.setFinalResult(false);

    DerivedOutputField pmmlPredictionField = encoder.createDerivedField(pmmlModel, pmmlPredictionOutput, keepPrediction);

    // Translate each label value into its target category, typed as DOUBLE.
    MapValues indexMapping = PMMLUtil.createMapValues(pmmlPredictionField.getName(), classLabel.getValues(), targetCategories);
    indexMapping.setDataType(DataType.DOUBLE);

    OutputField predictionOutput = new OutputField(FieldName.create(predictionColumn), OpType.CONTINUOUS, DataType.DOUBLE);
    predictionOutput.setResultFeature(ResultFeature.TRANSFORMED_VALUE);
    predictionOutput.setExpression(indexMapping);

    DerivedOutputField predictionField = encoder.createDerivedField(pmmlModel, predictionOutput, keepPrediction);

    encoder.putOnlyFeature(predictionColumn, new IndexFeature(encoder, predictionField, targetCategories));

    List<OutputField> outputFields = new ArrayList<>();

    // Without a probability column there is nothing further to register.
    if (!(transformer instanceof HasProbabilityCol)) {
        return outputFields;
    }

    String probabilityColumn = ((HasProbabilityCol) transformer).getProbabilityCol();

    List<Feature> probabilityFeatures = new ArrayList<>();

    for (int index = 0, count = classLabel.size(); index < count; index++) {
        Object classValue = classLabel.getValue(index);

        OutputField probabilityOutput = ModelUtil.createProbabilityField(FieldNameUtil.create(probabilityColumn, classValue), DataType.DOUBLE, classValue);

        outputFields.add(probabilityOutput);
        probabilityFeatures.add(new ContinuousFeature(encoder, probabilityOutput));
    }

    // NOTE(review): the original carried an unexplained "XXX" marker here; confirm whether
    // publishing all per-class probability features under one column name is intended.
    encoder.putFeatures(probabilityColumn, probabilityFeatures);

    return outputFields;
}
Aggregations