Example 16 with CategoricalFeature

Use of org.jpmml.converter.CategoricalFeature in project jpmml-sparkml by jpmml.

From the class StringIndexerModelConverter, method encodeFeatures.

@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    StringIndexerModel transformer = getTransformer();
    Feature feature = encoder.getOnlyFeature(transformer.getInputCol());
    // Copy the labels into a mutable list; the "keep" branch below appends to it
    List<String> categories = new ArrayList<>(Arrays.asList(transformer.labels()));
    String handleInvalid = transformer.getHandleInvalid();
    Field<?> field = encoder.toCategorical(feature.getName(), categories);
    if (field instanceof DataField) {
        DataField dataField = (DataField) field;
        InvalidValueTreatmentMethod invalidValueTreatmentMethod;
        switch(handleInvalid) {
            case "keep":
                invalidValueTreatmentMethod = InvalidValueTreatmentMethod.AS_IS;
                break;
            case "error":
                invalidValueTreatmentMethod = InvalidValueTreatmentMethod.RETURN_INVALID;
                break;
            default:
                throw new IllegalArgumentException(handleInvalid);
        }
        InvalidValueDecorator invalidValueDecorator = new InvalidValueDecorator().setInvalidValueTreatment(invalidValueTreatmentMethod);
        encoder.addDecorator(dataField.getName(), invalidValueDecorator);
    } else if (field instanceof DerivedField) {
        // Ignored
    } else {
        throw new IllegalArgumentException();
    }
    switch(handleInvalid) {
        case "keep":
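            // Build isIn(x, label_1, ..., label_n) to test whether the value is a known label
            Apply setApply = PMMLUtil.createApply("isIn", feature.ref());
            for (String category : categories) {
                setApply.addExpressions(PMMLUtil.createConstant(category, feature.getDataType()));
            }
            // Unknown values get their own category, appended after the known labels
            categories.add(StringIndexerModelConverter.LABEL_UNKNOWN);
            // if(isIn(...), x, LABEL_UNKNOWN): pass known values through, replace all others
            Apply apply = PMMLUtil.createApply("if", setApply, feature.ref(), PMMLUtil.createConstant(StringIndexerModelConverter.LABEL_UNKNOWN, DataType.STRING));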
            field = encoder.createDerivedField(FeatureUtil.createName("handleInvalid", feature), OpType.CATEGORICAL, feature.getDataType(), apply);
            break;
        default:
            break;
    }
    return Collections.<Feature>singletonList(new CategoricalFeature(encoder, field, categories));
}
Also used : Apply(org.dmg.pmml.Apply) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) StringIndexerModel(org.apache.spark.ml.feature.StringIndexerModel) InvalidValueTreatmentMethod(org.dmg.pmml.InvalidValueTreatmentMethod) InvalidValueDecorator(org.jpmml.converter.InvalidValueDecorator) DataField(org.dmg.pmml.DataField) DerivedField(org.dmg.pmml.DerivedField)
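
The "keep" branch of the method above can be hard to follow through the PMML builder calls. The sketch below replays the same logic on plain Java strings, with no jpmml or Spark dependency: a value that is one of the known labels passes through, any other value is replaced by an "unknown" label, and that label is appended as the last category. The class name HandleInvalidSketch, the helper names, and the literal value of LABEL_UNKNOWN are assumptions made for illustration; the excerpt does not show the real constant.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class HandleInvalidSketch {

    // Assumed placeholder value; the excerpt does not show the real
    // StringIndexerModelConverter.LABEL_UNKNOWN constant.
    static final String LABEL_UNKNOWN = "__unknown";

    // Plain-Java analogue of the PMML expression if(isIn(x, labels...), x, LABEL_UNKNOWN)
    static String keep(String value, List<String> labels) {
        return labels.contains(value) ? value : LABEL_UNKNOWN;
    }

    // Category list as the method above builds it: the labels, plus the unknown label for "keep"
    static List<String> categoriesFor(String handleInvalid, List<String> labels) {
        List<String> categories = new ArrayList<>(labels);
        switch (handleInvalid) {
            case "keep":
                categories.add(LABEL_UNKNOWN);
                break;
            case "error":
                break;
            default:
                throw new IllegalArgumentException(handleInvalid);
        }
        return categories;
    }

    public static void main(String[] args) {
        List<String> labels = Arrays.asList("a", "b", "c");
        List<String> categories = categoriesFor("keep", labels);

        // A known label keeps its position in the category list
        System.out.println(categories.indexOf(keep("b", labels))); // prints 1
        // An unseen value lands in the extra category after the known labels
        System.out.println(categories.indexOf(keep("z", labels))); // prints 3
    }
}
```

With "error", the category list is left as the plain label list, which matches the default branch of the second switch in the method above.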

Aggregations

CategoricalFeature (org.jpmml.converter.CategoricalFeature) 16
Feature (org.jpmml.converter.Feature) 15
ContinuousFeature (org.jpmml.converter.ContinuousFeature) 11
ArrayList (java.util.ArrayList) 9
Predicate (org.dmg.pmml.Predicate) 5
SimplePredicate (org.dmg.pmml.SimplePredicate) 5
Node (org.dmg.pmml.tree.Node) 5
DataField (org.dmg.pmml.DataField) 4
DerivedField (org.dmg.pmml.DerivedField) 4
BinaryFeature (org.jpmml.converter.BinaryFeature) 4
BooleanFeature (org.jpmml.converter.BooleanFeature) 3
CategoricalLabel (org.jpmml.converter.CategoricalLabel) 3
List (java.util.List) 2
DocumentBuilder (javax.xml.parsers.DocumentBuilder) 2
Apply (org.dmg.pmml.Apply) 2
FieldColumnPair (org.dmg.pmml.FieldColumnPair) 2
FieldName (org.dmg.pmml.FieldName) 2
InlineTable (org.dmg.pmml.InlineTable) 2
MapValues (org.dmg.pmml.MapValues) 2
OutputField (org.dmg.pmml.OutputField) 2