Use of org.apache.spark.ml.feature.VectorIndexerModel in project jpmml-sparkml by jpmml.
The encodeFeatures method of the class VectorIndexerModelConverter.
/**
 * Encodes the features produced by the wrapped {@code VectorIndexerModel}.
 *
 * <p>The input features are looked up from the encoder by the transformer's input column.
 * For every feature index that has an entry in the transformer's category maps, the input
 * field is marked as categorical and a {@code MapValues}-based derived field (categorical,
 * integer) is created that maps each input category to its index via an inline table.
 * Features without a category map are cast to {@code ContinuousFeature} and passed through.
 *
 * @param encoder the encoder that supplies the input features and receives the derived fields
 * @return one output feature per input feature, in input order
 * @throws IllegalArgumentException if the number of input features differs from
 *         {@code transformer.numFeatures()}
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    VectorIndexerModel transformer = getTransformer();

    List<Feature> inputFeatures = encoder.getFeatures(transformer.getInputCol());

    int expectedCount = transformer.numFeatures();
    int actualCount = inputFeatures.size();
    if (expectedCount != actualCount) {
        throw new IllegalArgumentException("Expected " + expectedCount + " features, got " + actualCount + " features");
    }

    Map<Integer, Map<Double, Integer>> categoryMaps = transformer.javaCategoryMaps();

    List<Feature> outputFeatures = new ArrayList<>();

    for (int index = 0; index < expectedCount; index++) {
        Feature inputFeature = inputFeatures.get(index);
        Map<Double, Integer> mapping = categoryMaps.get(index);

        // No category map for this index: the feature is passed through as continuous.
        if (mapping == null) {
            outputFeatures.add((ContinuousFeature) inputFeature);
            continue;
        }

        List<String> inputCategories = new ArrayList<>();
        List<String> outputValues = new ArrayList<>();

        DocumentBuilder rowBuilder = DOMUtil.createDocumentBuilder();
        InlineTable lookupTable = new InlineTable();

        List<String> columnNames = Arrays.asList("input", "output");
        String inputColumn = columnNames.get(0);
        String outputColumn = columnNames.get(1);

        // Sort a copy of the entries so that rows and category lists come out in COMPARATOR order.
        List<Map.Entry<Double, Integer>> sortedEntries = new ArrayList<>(mapping.entrySet());
        Collections.sort(sortedEntries, VectorIndexerModelConverter.COMPARATOR);

        for (Map.Entry<Double, Integer> mappingEntry : sortedEntries) {
            String inputCategory = ValueUtil.formatValue(mappingEntry.getKey());
            inputCategories.add(inputCategory);

            String outputValue = ValueUtil.formatValue(mappingEntry.getValue());
            outputValues.add(outputValue);

            List<String> cells = Arrays.asList(inputCategory, outputValue);
            lookupTable.addRows(DOMUtil.createRow(rowBuilder, columnNames, cells));
        }

        // The input field must be categorical over exactly the mapped input categories.
        encoder.toCategorical(inputFeature.getName(), inputCategories);

        MapValues mapValues = new MapValues();
        mapValues.addFieldColumnPairs(new FieldColumnPair(inputFeature.getName(), inputColumn));
        mapValues.setOutputColumn(outputColumn);
        mapValues.setInlineTable(lookupTable);

        DerivedField indexedField = encoder.createDerivedField(formatName(transformer, index), OpType.CATEGORICAL, DataType.INTEGER, mapValues);

        outputFeatures.add(new CategoricalFeature(encoder, indexedField, outputValues));
    }

    return outputFeatures;
}
Aggregations