Search in sources:

Example 1 with VectorIndexerModel

Use of org.apache.spark.ml.feature.VectorIndexerModel in the project jpmml-sparkml by jpmml.

Class VectorIndexerModelConverter, method encodeFeatures.

/**
 * Converts the input vector of a fitted {@link VectorIndexerModel} into one PMML feature per
 * vector slot.
 *
 * <p>The model's category maps are keyed by slot index. Per Spark's {@code VectorIndexerModel}
 * documentation, each map goes from an original value to its category index. Slots with such a
 * map become categorical features:
 * <ul>
 *   <li>their formatted original values are handed to {@code encoder.toCategorical};</li>
 *   <li>a {@code MapValues} derived field (INTEGER, categorical) looks each original value up in
 *       an inline "input"/"output" table and yields the formatted index.</li>
 * </ul>
 * Slots without a map are passed through as-is. They must already be {@link ContinuousFeature}
 * instances, otherwise the cast throws {@code ClassCastException}.
 *
 * @param encoder encoder that supplies the upstream features of the transformer's input column
 * @return the encoded features, in slot order
 * @throws IllegalArgumentException if the number of upstream features differs from
 *     {@code numFeatures()} of the model
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    VectorIndexerModel model = getTransformer();

    List<Feature> inputFeatures = encoder.getFeatures(model.getInputCol());
    int slotCount = model.numFeatures();

    // Guard: the upstream vector must have exactly as many slots as the model was fitted on.
    if (inputFeatures.size() != slotCount) {
        throw new IllegalArgumentException("Expected " + slotCount + " features, got " + inputFeatures.size() + " features");
    }

    Map<Integer, Map<Double, Integer>> mapsBySlot = model.javaCategoryMaps();

    List<Feature> encoded = new ArrayList<>();
    for (int slot = 0; slot < slotCount; slot++) {
        Feature inputFeature = inputFeatures.get(slot);
        Map<Double, Integer> valueToIndex = mapsBySlot.get(slot);

        // Slot left un-indexed by the model: forward it unchanged (asserting it is continuous).
        if (valueToIndex == null) {
            encoded.add((ContinuousFeature) inputFeature);
            continue;
        }

        String inputColumn = "input";
        String outputColumn = "output";
        List<String> columnNames = Arrays.asList(inputColumn, outputColumn);

        DocumentBuilder rowBuilder = DOMUtil.createDocumentBuilder();
        InlineTable lookupTable = new InlineTable();

        // Deterministic row order; the ordering itself is defined by COMPARATOR elsewhere in the class.
        List<Map.Entry<Double, Integer>> sortedEntries = new ArrayList<>(valueToIndex.entrySet());
        Collections.sort(sortedEntries, VectorIndexerModelConverter.COMPARATOR);

        List<String> originalValues = new ArrayList<>();
        List<String> indexValues = new ArrayList<>();
        for (Map.Entry<Double, Integer> mapping : sortedEntries) {
            String original = ValueUtil.formatValue(mapping.getKey());
            String index = ValueUtil.formatValue(mapping.getValue());

            originalValues.add(original);
            indexValues.add(index);

            lookupTable.addRows(DOMUtil.createRow(rowBuilder, columnNames, Arrays.asList(original, index)));
        }

        encoder.toCategorical(inputFeature.getName(), originalValues);

        MapValues lookup = new MapValues()
            .addFieldColumnPairs(new FieldColumnPair(inputFeature.getName(), inputColumn))
            .setOutputColumn(outputColumn)
            .setInlineTable(lookupTable);

        DerivedField indexField = encoder.createDerivedField(formatName(model, slot), OpType.CATEGORICAL, DataType.INTEGER, lookup);

        encoded.add(new CategoricalFeature(encoder, indexField, indexValues));
    }
    return encoded;
}
Also used: InlineTable(org.dmg.pmml.InlineTable) ArrayList(java.util.ArrayList) FieldColumnPair(org.dmg.pmml.FieldColumnPair) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) DocumentBuilder(javax.xml.parsers.DocumentBuilder) MapValues(org.dmg.pmml.MapValues) VectorIndexerModel(org.apache.spark.ml.feature.VectorIndexerModel) Row(org.dmg.pmml.Row) Map(java.util.Map) DerivedField(org.dmg.pmml.DerivedField)

Aggregations

ArrayList (java.util.ArrayList): 1, Map (java.util.Map): 1, DocumentBuilder (javax.xml.parsers.DocumentBuilder): 1, VectorIndexerModel (org.apache.spark.ml.feature.VectorIndexerModel): 1, DerivedField (org.dmg.pmml.DerivedField): 1, FieldColumnPair (org.dmg.pmml.FieldColumnPair): 1, InlineTable (org.dmg.pmml.InlineTable): 1, MapValues (org.dmg.pmml.MapValues): 1, Row (org.dmg.pmml.Row): 1, CategoricalFeature (org.jpmml.converter.CategoricalFeature): 1, ContinuousFeature (org.jpmml.converter.ContinuousFeature): 1, Feature (org.jpmml.converter.Feature): 1