Use of org.dmg.pmml.TextIndex in the jpmml-sparkml project by jpmml.
The example is the encodeFeatures method of the class CountVectorizerModelConverter.
/**
 * Encodes the vocabulary of the wrapped CountVectorizerModel as a list of term features.
 *
 * <p>
 * A function named "tf@&lt;sequence number&gt;" with the parameters "document" and "term" is built
 * around a TextIndex expression and registered with the encoder. Every non-empty stop word set of
 * the input document feature contributes one TextIndexNormalization to that TextIndex. Finally,
 * one TermFeature is created per vocabulary term.
 * </p>
 *
 * @param encoder The encoder that provides the input document feature and receives the function definition.
 * @return One TermFeature per vocabulary term, in vocabulary order.
 * @throws IllegalArgumentException If a non-empty stop word set is present while the word separator
 * pattern is neither "\s+" nor "\W+", or if a vocabulary term is rejected by TermUtil.hasPunctuation.
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    CountVectorizerModel model = getTransformer();

    DocumentFeature inputFeature = (DocumentFeature) encoder.getOnlyFeature(model.getInputCol());

    // Formal parameters of the term frequency function
    ParameterField documentParam = new ParameterField(FieldName.create("document"));
    ParameterField termParam = new ParameterField(FieldName.create("term"));

    TextIndex termFrequency = new TextIndex(documentParam.getName())
        .setTokenize(Boolean.TRUE)
        .setWordSeparatorCharacterRE(inputFeature.getWordSeparatorRE())
        // A null value leaves the local term weights attribute unset
        .setLocalTermWeights(model.getBinary() ? TextIndex.LocalTermWeights.BINARY : null)
        .setExpression(new FieldRef(termParam.getName()));

    Set<DocumentFeature.StopWordSet> allStopWords = inputFeature.getStopWordSets();
    for (DocumentFeature.StopWordSet stopWords : allStopWords) {

        // Empty stop word sets do not contribute a normalization step
        if (!stopWords.isEmpty()) {
            DocumentBuilder rowBuilder = DOMUtil.createDocumentBuilder();

            String separatorRE = inputFeature.getWordSeparatorRE();

            // The stop word matching pattern depends on how the document is split into tokens
            String stopWordRE;
            switch (separatorRE) {
                case "\\s+":
                    stopWordRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWords) + ")\\p{Punct}*(\\s+|$)";
                    break;
                case "\\W+":
                    stopWordRE = "(\\W+)(" + JOINER.join(stopWords) + ")(\\W+)";
                    break;
                default:
                    throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + separatorRE + "\"");
            }

            // Single row that replaces every stop word match with a space character
            List<String> columns = Arrays.asList("string", "stem", "regex");
            List<String> cells = Arrays.asList(stopWordRE, " ", "true");

            InlineTable replacementTable = new InlineTable()
                .addRows(DOMUtil.createRow(rowBuilder, columns, cells));

            TextIndexNormalization stopWordRemoval = new TextIndexNormalization()
                .setCaseSensitive(stopWords.isCaseSensitive())
                // Handles consecutive matches. See http://stackoverflow.com/a/25085385
                .setRecursive(Boolean.TRUE)
                .setInlineTable(replacementTable);

            termFrequency.addTextIndexNormalizations(stopWordRemoval);
        }
    }

    String functionName = "tf@" + CountVectorizerModelConverter.SEQUENCE.getAndIncrement();

    DefineFunction termFrequencyFunction = new DefineFunction(functionName, OpType.CONTINUOUS, null)
        .setDataType(DataType.INTEGER)
        .addParameterFields(documentParam, termParam)
        .setExpression(termFrequency);

    encoder.addDefineFunction(termFrequencyFunction);

    List<Feature> termFeatures = new ArrayList<>();

    for (String vocabularyTerm : model.vocabulary()) {

        if (TermUtil.hasPunctuation(vocabularyTerm)) {
            throw new IllegalArgumentException(vocabularyTerm);
        }

        termFeatures.add(new TermFeature(encoder, termFrequencyFunction, inputFeature, vocabularyTerm));
    }

    return termFeatures;
}
Aggregations