Example 1 with ImputerModel

Use of org.apache.spark.ml.feature.ImputerModel in project jpmml-sparkml by jpmml.

Class ImputerModelConverter, method registerFeatures.

@Override
public void registerFeatures(SparkMLEncoder encoder) {
    ImputerModel transformer = getTransformer();
    // Builds one feature per input column (see encodeFeatures).
    List<Feature> features = encodeFeatures(encoder);
    String[] outputCols = transformer.getOutputCols();
    // Output columns and features are matched by index, so their counts must agree.
    if (outputCols.length != features.size()) {
        throw new IllegalArgumentException();
    }
    // Register each feature under the name of its output column.
    for (int i = 0; i < features.size(); i++) {
        String outputCol = outputCols[i];
        Feature feature = features.get(i);
        encoder.putFeatures(outputCol, Collections.singletonList(feature));
    }
}
Also used: ImputerModel (org.apache.spark.ml.feature.ImputerModel), Feature (org.jpmml.converter.Feature)
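
The pairing step in registerFeatures can be tried in isolation. The following is a minimal sketch, not jpmml-sparkml code: the class `ColumnPairingSketch` and the method `pair` are hypothetical, plain strings stand in for `Feature` objects, and a `LinkedHashMap` stands in for the encoder's feature registry. It also shows how the length check could carry a descriptive message.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class ColumnPairingSketch {

    // Hypothetical helper: maps each output column to a single-element feature list,
    // mirroring the index-by-index loop in registerFeatures.
    static Map<String, List<String>> pair(String[] outputCols, List<String> features) {
        if (outputCols.length != features.size()) {
            throw new IllegalArgumentException(
                "Expected " + outputCols.length + " features, got " + features.size());
        }
        // LinkedHashMap keeps the registration order of the columns.
        Map<String, List<String>> registry = new LinkedHashMap<>();
        for (int i = 0; i < features.size(); i++) {
            registry.put(outputCols[i], Collections.singletonList(features.get(i)));
        }
        return registry;
    }

    public static void main(String[] args) {
        // Column names here are made up for illustration.
        Map<String, List<String>> registry = pair(
            new String[] {"age_imputed", "income_imputed"},
            Arrays.asList("age", "income"));
        System.out.println(registry);
    }
}
```

Passing arrays of different lengths makes `pair` throw, which matches the guard in the original method.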

Example 2 with ImputerModel

Use of org.apache.spark.ml.feature.ImputerModel in project jpmml-sparkml by jpmml.

Class ImputerModelConverter, method encodeFeatures.

@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    ImputerModel transformer = getTransformer();
    Double missingValue = transformer.getMissingValue();
    String strategy = transformer.getStrategy();
    Dataset<Row> surrogateDF = transformer.surrogateDF();
    String[] inputCols = transformer.getInputCols();
    String[] outputCols = transformer.getOutputCols();
    // Input and output columns are matched by index, so their counts must agree.
    if (inputCols.length != outputCols.length) {
        throw new IllegalArgumentException();
    }
    MissingValueTreatmentMethod missingValueTreatmentMethod = parseStrategy(strategy);
    // The surrogate DataFrame must hold exactly one row of replacement values.
    List<Row> surrogateRows = surrogateDF.collectAsList();
    if (surrogateRows.size() != 1) {
        throw new IllegalArgumentException();
    }
    Row surrogateRow = surrogateRows.get(0);
    List<Feature> result = new ArrayList<>();
    for (int i = 0; i < inputCols.length; i++) {
        String inputCol = inputCols[i];
        String outputCol = outputCols[i];
        Feature feature = encoder.getOnlyFeature(inputCol);
        Field<?> field = encoder.getField(feature.getName());
        // Only fields backed by a DataField can be decorated.
        if (field instanceof DataField) {
            DataField dataField = (DataField) field;
            // Look up the replacement value for this column by name.
            Object surrogate = surrogateRow.getAs(inputCol);
            MissingValueDecorator missingValueDecorator = new MissingValueDecorator()
                .setMissingValueReplacement(ValueUtil.formatValue(surrogate))
                .setMissingValueTreatment(missingValueTreatmentMethod);
            // A non-NaN missing value marker is declared explicitly as a missing value.
            if (missingValue != null && !missingValue.isNaN()) {
                missingValueDecorator.addValues(ValueUtil.formatValue(missingValue));
            }
            encoder.addDecorator(feature.getName(), missingValueDecorator);
        } else {
            throw new IllegalArgumentException();
        }
        result.add(feature);
    }
    return result;
}
Also used: ArrayList (java.util.ArrayList), MissingValueDecorator (org.jpmml.converter.MissingValueDecorator), Feature (org.jpmml.converter.Feature), DataField (org.dmg.pmml.DataField), ImputerModel (org.apache.spark.ml.feature.ImputerModel), Row (org.apache.spark.sql.Row), MissingValueTreatmentMethod (org.dmg.pmml.MissingValueTreatmentMethod)
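
The decorator built in encodeFeatures records two things per column: which values count as missing, and what replaces them. The sketch below illustrates one reading of that behavior in plain Java, without Spark or JPMML on the classpath. The class `ImputationSketch` and its methods are hypothetical, and the rule that null and NaN always count as missing, in addition to any configured non-NaN marker, is an assumption made for illustration.

```java
final class ImputationSketch {

    // Assumption: a value is missing when it is null, NaN, or equal to a
    // configured non-NaN marker (mirroring the missingValue check in the converter).
    static boolean isMissing(Double value, Double missingValue) {
        if (value == null || value.isNaN()) {
            return true;
        }
        return missingValue != null
            && !missingValue.isNaN()
            && value.doubleValue() == missingValue.doubleValue();
    }

    // Returns the surrogate for missing inputs and the input itself otherwise.
    static double impute(Double value, Double missingValue, double surrogate) {
        return isMissing(value, missingValue) ? surrogate : value;
    }

    public static void main(String[] args) {
        double surrogate = 3.5; // e.g. a column mean taken from the surrogate row
        System.out.println(impute(Double.NaN, Double.NaN, surrogate)); // NaN is replaced
        System.out.println(impute(-1.0, -1.0, surrogate));             // marker is replaced
        System.out.println(impute(2.0, -1.0, surrogate));              // regular value is kept
    }
}
```

In the converter, the surrogate comes from the single row of `surrogateDF()` and the marker from `getMissingValue()`; here both are passed in directly to keep the sketch self-contained.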

Aggregations

ImputerModel (org.apache.spark.ml.feature.ImputerModel): 2
Feature (org.jpmml.converter.Feature): 2
ArrayList (java.util.ArrayList): 1
Row (org.apache.spark.sql.Row): 1
DataField (org.dmg.pmml.DataField): 1
MissingValueTreatmentMethod (org.dmg.pmml.MissingValueTreatmentMethod): 1
MissingValueDecorator (org.jpmml.converter.MissingValueDecorator): 1