Use of org.jpmml.converter.MissingValueDecorator in the jpmml-sparkml project by jpmml.
Class ImputerModelConverter, method encodeFeatures.
/**
 * Encodes the imputation described by the {@code ImputerModel} transformer as missing value
 * decorators attached to the fields behind its input columns.
 *
 * <p>For each input column, the surrogate value is read from the single row of the
 * transformer's surrogate data frame (looked up by input column name) and registered as the
 * missing value replacement of the backing field, together with the treatment method parsed
 * from the transformer's strategy. When the transformer's missing value is neither
 * {@code null} nor NaN, it is additionally registered on the decorator as an explicit value.
 *
 * @param encoder the encoder used to resolve features and fields, and to which the decorators
 *     are added
 * @return the features resolved for the input columns, in input column order
 * @throws IllegalArgumentException if the numbers of input and output columns differ, if the
 *     surrogate data frame does not contain exactly one row, or if an input column is not
 *     backed by a {@code DataField}
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    ImputerModel transformer = getTransformer();

    Double missingValue = transformer.getMissingValue();
    String strategy = transformer.getStrategy();
    Dataset<Row> surrogateDF = transformer.surrogateDF();
    String[] inputCols = transformer.getInputCols();
    String[] outputCols = transformer.getOutputCols();

    if (inputCols.length != outputCols.length) {
        throw new IllegalArgumentException("Expected the same number of input and output columns, got "
            + inputCols.length + " input column(s) and " + outputCols.length + " output column(s)");
    }

    MissingValueTreatmentMethod missingValueTreatmentMethod = parseStrategy(strategy);

    List<Row> surrogateRows = surrogateDF.collectAsList();
    if (surrogateRows.size() != 1) {
        throw new IllegalArgumentException("Expected exactly one surrogate row, got " + surrogateRows.size());
    }
    Row surrogateRow = surrogateRows.get(0);

    // Loop-invariant: a NaN (or absent) missing value is not registered as an explicit value.
    // NOTE(review): presumably because NaN is already treated as missing downstream - confirm.
    boolean hasExplicitMissingValue = (missingValue != null && !missingValue.isNaN());

    List<Feature> result = new ArrayList<>(inputCols.length);

    for (String inputCol : inputCols) {
        Feature feature = encoder.getOnlyFeature(inputCol);
        Field<?> field = encoder.getField(feature.getName());

        // The decorator is only attached when the column is backed by a DataField.
        if (!(field instanceof DataField)) {
            throw new IllegalArgumentException("Expected input column " + inputCol + " to be backed by a DataField, got "
                + (field != null ? field.getClass().getName() : "null"));
        }

        // The surrogate row is keyed by input column name.
        Object surrogate = surrogateRow.getAs(inputCol);

        MissingValueDecorator missingValueDecorator = new MissingValueDecorator()
            .setMissingValueReplacement(ValueUtil.formatValue(surrogate))
            .setMissingValueTreatment(missingValueTreatmentMethod);

        if (hasExplicitMissingValue) {
            missingValueDecorator.addValues(ValueUtil.formatValue(missingValue));
        }

        encoder.addDecorator(feature.getName(), missingValueDecorator);

        result.add(feature);
    }

    return result;
}
Aggregations