Usage of org.jpmml.converter.DerivedOutputField in the jpmml-sparkml project (by jpmml).
Class RegressionModelConverter, method registerOutputFields.
/**
 * Publishes the regression prediction as a continuous feature keyed by the model's prediction column.
 *
 * <p>A predicted output field is named after the prediction column and typed with the label's data type.
 * It is wrapped into a derived field on {@code pmmlModel`, and that derived field is registered with the
 * encoder as the sole feature of the prediction column.
 *
 * @param label the model label; only its data type is consulted here
 * @param pmmlModel the PMML model that receives the derived output field
 * @param encoder the encoder that builds the derived field and stores the resulting feature
 * @return always an empty list; the prediction is exposed through the encoder rather than returned
 */
@Override
public List<OutputField> registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder) {
    T regressor = getTransformer();
    String column = regressor.getPredictionCol();

    // NOTE(review): forwarded to createDerivedField; presumably controls whether the prediction
    // column is retained in the final output. Confirm against SparkMLEncoder.
    Boolean keepColumn = (Boolean) getOption(HasPredictionModelOptions.OPTION_KEEP_PREDICTIONCOL, Boolean.TRUE);

    DerivedOutputField derivedPrediction = encoder.createDerivedField(
        pmmlModel,
        ModelUtil.createPredictedField(FieldName.create(column), OpType.CONTINUOUS, label.getDataType()),
        keepColumn
    );

    encoder.putOnlyFeature(column, new ContinuousFeature(encoder, derivedPrediction));

    return Collections.emptyList();
}
Usage of org.jpmml.converter.DerivedOutputField in the jpmml-sparkml project (by jpmml).
Class ClusteringModelConverter, method registerOutputFields.
/**
 * Publishes the cluster assignment as an index feature keyed by the model's prediction column.
 *
 * <p>Two derived fields are added to {@code pmmlModel}, in this order:
 * <ol>
 *   <li>a non-final predicted field named {@code FieldNameUtil.create("pmml", predictionCol)},
 *       typed categorical/string;</li>
 *   <li>a transformed-value field named after the prediction column, typed categorical/integer,
 *       whose expression is a {@code FieldRef} to the first field.</li>
 * </ol>
 * The second field is then registered with the encoder as the only feature of the prediction column.
 * Its categories come from {@code LabelUtil.createTargetCategories(getNumberOfClusters())}.
 *
 * @param label unused by this implementation
 * @param pmmlModel the PMML model that receives both derived output fields
 * @param encoder the encoder that builds the derived fields and stores the resulting feature
 * @return always an empty list; the prediction is exposed through the encoder rather than returned
 */
@Override
public List<OutputField> registerOutputFields(Label label, org.dmg.pmml.Model pmmlModel, SparkMLEncoder encoder) {
    T clusterer = getTransformer();
    List<Integer> clusterIds = LabelUtil.createTargetCategories(getNumberOfClusters());
    String column = clusterer.getPredictionCol();

    // Step 1: the raw PMML prediction, kept out of the final result and typed as a categorical string.
    OutputField rawOutput = ModelUtil.createPredictedField(FieldNameUtil.create("pmml", column), OpType.CATEGORICAL, DataType.STRING)
        .setFinalResult(false);
    DerivedOutputField rawField = encoder.createDerivedField(pmmlModel, rawOutput, true);

    // Step 2: re-expose that raw value under the real column name, declared as an integer.
    OutputField typedOutput = new OutputField(FieldName.create(column), OpType.CATEGORICAL, DataType.INTEGER)
        .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
        .setExpression(new FieldRef(rawField.getName()));
    DerivedOutputField typedField = encoder.createDerivedField(pmmlModel, typedOutput, true);

    encoder.putOnlyFeature(column, new IndexFeature(encoder, typedField, clusterIds));

    return Collections.emptyList();
}
Usage of org.jpmml.converter.DerivedOutputField in the jpmml-sparkml project (by jpmml).
Class ClassificationModelConverter, method registerOutputFields.
/**
 * Publishes the predicted class index and, when available, per-class probabilities.
 *
 * <p>The prediction is built from two derived fields added to {@code pmmlModel}, in this order:
 * <ol>
 *   <li>a non-final predicted field named {@code FieldNameUtil.create("pmml", predictionCol)},
 *       typed categorical with the label's data type;</li>
 *   <li>a continuous/double transformed-value field named after the prediction column. Its
 *       {@code MapValues} expression maps each label value to its integer category index.</li>
 * </ol>
 * The second field is registered with the encoder as the only feature of the prediction column.
 *
 * <p>If the transformer implements {@code HasProbabilityCol}, one probability output field is
 * created per label value. Those fields are registered as continuous features of the probability
 * column and returned to the caller.
 *
 * @param label the model label; cast to {@code CategoricalLabel} (a non-categorical label fails the cast)
 * @param pmmlModel the PMML model that receives the derived prediction fields
 * @param encoder the encoder that builds derived fields and stores the resulting features
 * @return a mutable list of probability output fields; empty if the model has no probability column
 */
@Override
public List<OutputField> registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder) {
    T classifier = getTransformer();
    CategoricalLabel targetLabel = (CategoricalLabel) label;
    int classCount = targetLabel.size();

    List<Integer> indices = LabelUtil.createTargetCategories(classCount);
    String column = classifier.getPredictionCol();

    // NOTE(review): forwarded to createDerivedField; presumably controls whether the prediction
    // fields are retained in the final output. Confirm against SparkMLEncoder.
    Boolean keepColumn = (Boolean) getOption(HasPredictionModelOptions.OPTION_KEEP_PREDICTIONCOL, Boolean.TRUE);

    // Raw PMML prediction: non-final, carrying the label's original category values.
    OutputField rawOutput = ModelUtil.createPredictedField(FieldNameUtil.create("pmml", column), OpType.CATEGORICAL, targetLabel.getDataType())
        .setFinalResult(false);
    DerivedOutputField rawField = encoder.createDerivedField(pmmlModel, rawOutput, keepColumn);

    // Translate each label value into its index, emitted as a double
    // (presumably to match Spark's double-typed prediction column; confirm).
    MapValues valueToIndex = PMMLUtil.createMapValues(rawField.getName(), targetLabel.getValues(), indices)
        .setDataType(DataType.DOUBLE);
    OutputField indexOutput = new OutputField(FieldName.create(column), OpType.CONTINUOUS, DataType.DOUBLE)
        .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
        .setExpression(valueToIndex);
    DerivedOutputField indexField = encoder.createDerivedField(pmmlModel, indexOutput, keepColumn);

    encoder.putOnlyFeature(column, new IndexFeature(encoder, indexField, indices));

    List<OutputField> probabilityOutputs = new ArrayList<>();

    // Models without a probability column contribute no extra output fields.
    if (!(classifier instanceof HasProbabilityCol)) {
        return probabilityOutputs;
    }

    String probabilityColumn = ((HasProbabilityCol) classifier).getProbabilityCol();
    List<Feature> probabilityFeatures = new ArrayList<>();

    for (int index = 0; index < classCount; index++) {
        Object category = targetLabel.getValue(index);
        OutputField probabilityOutput = ModelUtil.createProbabilityField(FieldNameUtil.create(probabilityColumn, category), DataType.DOUBLE, category);

        probabilityOutputs.add(probabilityOutput);
        probabilityFeatures.add(new ContinuousFeature(encoder, probabilityOutput));
    }

    // XXX (marker kept from the original code; its intent is not documented here)
    encoder.putFeatures(probabilityColumn, probabilityFeatures);

    return probabilityOutputs;
}
Aggregations