Use of org.jpmml.sparkml.DocumentFeature in project jpmml-sparkml by jpmml.
The class StopWordsRemoverConverter, method encodeFeatures:
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    StopWordsRemover transformer = getTransformer();

    DocumentFeature documentFeature = (DocumentFeature) encoder.getOnlyFeature(transformer.getInputCol());

    Pattern pattern = Pattern.compile(documentFeature.getWordSeparatorRE());

    DocumentFeature.StopWordSet stopWordSet = new DocumentFeature.StopWordSet(transformer.getCaseSensitive());

    String[] stopWords = transformer.getStopWords();
    for (String stopWord : stopWords) {
        String[] stopTokens = pattern.split(stopWord);

        // Skip multi-token stopwords. See https://issues.apache.org/jira/browse/SPARK-18374
        if (stopTokens.length > 1) {
            continue;
        }

        // Stop words containing punctuation characters are not supported
        if (TermUtil.hasPunctuation(stopWord)) {
            throw new IllegalArgumentException(stopWord);
        }

        stopWordSet.add(stopWord);
    }

    documentFeature.addStopWordSet(stopWordSet);

    return Collections.<Feature>singletonList(documentFeature);
}
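For orientation, a minimal sketch of the Spark ML transformer that this converter consumes. The column names ("tokens", "filteredTokens") and the bundled English stop word list are illustrative assumptions, not taken from the project:

import org.apache.spark.ml.feature.StopWordsRemover;

// Illustrative configuration; "tokens" and "filteredTokens" are assumed column names
static StopWordsRemover makeStopWordsRemover() {
    return new StopWordsRemover()
        .setInputCol("tokens")
        .setOutputCol("filteredTokens")
        .setCaseSensitive(false)
        // The converter above skips multi-token stop words and rejects
        // stop words that contain punctuation characters
        .setStopWords(StopWordsRemover.loadDefaultStopWords("english"));
}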
Use of org.jpmml.sparkml.DocumentFeature in project jpmml-sparkml by jpmml.
The class RegexTokenizerConverter, method encodeFeatures:
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    RegexTokenizer transformer = getTransformer();

    // Only splitter mode (gaps = true) can be expressed as a PMML word separator regex
    if (!transformer.getGaps()) {
        throw new IllegalArgumentException("Expected splitter mode, got token matching mode");
    }

    if (transformer.getMinTokenLength() != 1) {
        throw new IllegalArgumentException("Expected 1 as minimum token length, got " + transformer.getMinTokenLength() + " as minimum token length");
    }

    Feature feature = encoder.getOnlyFeature(transformer.getInputCol());

    Field<?> field = encoder.getField(feature.getName());

    // Replace the input field with a lowercased derived field
    if (transformer.getToLowercase()) {
        Apply apply = PMMLUtil.createApply("lowercase", feature.ref());

        field = encoder.createDerivedField(FeatureUtil.createName("lowercase", feature), OpType.CATEGORICAL, DataType.STRING, apply);
    }

    return Collections.<Feature>singletonList(new DocumentFeature(encoder, field, transformer.getPattern()));
}
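A minimal sketch of a RegexTokenizer configuration that satisfies the converter's preconditions (splitter mode, minimum token length of 1); the column names and pattern are assumptions for illustration:

import org.apache.spark.ml.feature.RegexTokenizer;

// Illustrative configuration; "text" and "tokens" are assumed column names
static RegexTokenizer makeRegexTokenizer() {
    return new RegexTokenizer()
        .setInputCol("text")
        .setOutputCol("tokens")
        .setPattern("\\W+")        // becomes the DocumentFeature word separator regex
        .setGaps(true)             // splitter mode, as required by the converter
        .setMinTokenLength(1)      // the only supported minimum token length
        .setToLowercase(true);     // triggers the "lowercase" derived field
}

Choosing "\s+" or "\W+" as the pattern also keeps the stop word normalization of the downstream CountVectorizerModelConverter (shown further below) applicable.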
Use of org.jpmml.sparkml.DocumentFeature in project jpmml-sparkml by jpmml.
The class NGramConverter, method encodeFeatures:
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    NGram transformer = getTransformer();

    DocumentFeature documentFeature = (DocumentFeature) encoder.getOnlyFeature(transformer.getInputCol());

    return Collections.<Feature>singletonList(documentFeature);
}
Use of org.jpmml.sparkml.DocumentFeature in project jpmml-sparkml by jpmml.
The class CountVectorizerModelConverter, method encodeFeatures:
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    CountVectorizerModel transformer = getTransformer();

    DocumentFeature documentFeature = (DocumentFeature) encoder.getOnlyFeature(transformer.getInputCol());

    ParameterField documentField = new ParameterField(FieldName.create("document"));
    ParameterField termField = new ParameterField(FieldName.create("term"));

    // Tokenize the document and count (or, in binary mode, flag) occurrences of the term
    TextIndex textIndex = new TextIndex(documentField.getName())
        .setTokenize(Boolean.TRUE)
        .setWordSeparatorCharacterRE(documentFeature.getWordSeparatorRE())
        .setLocalTermWeights(transformer.getBinary() ? TextIndex.LocalTermWeights.BINARY : null)
        .setExpression(new FieldRef(termField.getName()));

    // Translate every non-empty stop word set into a TextIndexNormalization step
    Set<DocumentFeature.StopWordSet> stopWordSets = documentFeature.getStopWordSets();
    for (DocumentFeature.StopWordSet stopWordSet : stopWordSets) {
        if (stopWordSet.isEmpty()) {
            continue;
        }

        DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();

        String tokenRE;

        String wordSeparatorRE = documentFeature.getWordSeparatorRE();
        switch (wordSeparatorRE) {
            case "\\s+":
                tokenRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWordSet) + ")\\p{Punct}*(\\s+|$)";
                break;
            case "\\W+":
                tokenRE = "(\\W+)(" + JOINER.join(stopWordSet) + ")(\\W+)";
                break;
            default:
                throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + wordSeparatorRE + "\"");
        }

        // Replace stop word matches (as defined by the regex) with a single space
        InlineTable inlineTable = new InlineTable()
            .addRows(DOMUtil.createRow(documentBuilder, Arrays.asList("string", "stem", "regex"), Arrays.asList(tokenRE, " ", "true")));

        TextIndexNormalization textIndexNormalization = new TextIndexNormalization()
            .setCaseSensitive(stopWordSet.isCaseSensitive())
            // Handles consecutive matches. See http://stackoverflow.com/a/25085385
            .setRecursive(Boolean.TRUE)
            .setInlineTable(inlineTable);

        textIndex.addTextIndexNormalizations(textIndexNormalization);
    }

    // Register the term frequency lookup as a reusable "tf@<n>" function with (document, term) parameters
    DefineFunction defineFunction = new DefineFunction("tf" + "@" + String.valueOf(CountVectorizerModelConverter.SEQUENCE.getAndIncrement()), OpType.CONTINUOUS, null)
        .setDataType(DataType.INTEGER)
        .addParameterFields(documentField, termField)
        .setExpression(textIndex);

    encoder.addDefineFunction(defineFunction);

    List<Feature> result = new ArrayList<>();

    // Emit one TermFeature per vocabulary term
    String[] vocabulary = transformer.vocabulary();
    for (int i = 0; i < vocabulary.length; i++) {
        String term = vocabulary[i];

        // Vocabulary terms containing punctuation characters are not supported
        if (TermUtil.hasPunctuation(term)) {
            throw new IllegalArgumentException(term);
        }

        result.add(new TermFeature(encoder, defineFunction, documentFeature, term));
    }

    return result;
}
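A minimal sketch of the Spark ML side, assuming a hypothetical tokenized Dataset<Row> and the column names shown; fitting a CountVectorizer yields the CountVectorizerModel that this converter translates into a "tf" DefineFunction:

import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Illustrative fitting step; "tokens" and "termFrequencies" are assumed column names
static CountVectorizerModel fitCountVectorizer(Dataset<Row> tokenizedDocuments) {
    CountVectorizer countVectorizer = new CountVectorizer()
        .setInputCol("tokens")
        .setOutputCol("termFrequencies")
        .setBinary(false)      // true would map to BINARY local term weights above
        .setVocabSize(1000);

    // Vocabulary terms containing punctuation would be rejected by the converter above
    return countVectorizer.fit(tokenizedDocuments);
}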
Use of org.jpmml.sparkml.DocumentFeature in project jpmml-sparkml by jpmml.
The class TokenizerConverter, method encodeFeatures:
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    Tokenizer transformer = getTransformer();

    Feature feature = encoder.getOnlyFeature(transformer.getInputCol());

    // Spark's Tokenizer always lowercases its input and splits on whitespace
    Apply apply = PMMLUtil.createApply("lowercase", feature.ref());

    DerivedField derivedField = encoder.createDerivedField(FeatureUtil.createName("lowercase", feature), OpType.CATEGORICAL, DataType.STRING, apply);

    return Collections.<Feature>singletonList(new DocumentFeature(encoder, derivedField, "\\s+"));
}
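Putting the pieces together, a hedged end-to-end sketch: a Tokenizer feeding a CountVectorizer in a Pipeline, converted to PMML. Column names are assumptions, and PMMLBuilder is the conversion entry point of newer jpmml-sparkml releases (older releases exposed a different entry point):

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.dmg.pmml.PMML;
import org.jpmml.sparkml.PMMLBuilder;

// Illustrative end-to-end sketch; "text" is an assumed input column
static PMML convertTextPipeline(Dataset<Row> documents) {
    Tokenizer tokenizer = new Tokenizer()
        .setInputCol("text")
        .setOutputCol("tokens");

    CountVectorizer countVectorizer = new CountVectorizer()
        .setInputCol("tokens")
        .setOutputCol("termFrequencies");

    Pipeline pipeline = new Pipeline()
        .setStages(new PipelineStage[]{ tokenizer, countVectorizer });
    PipelineModel pipelineModel = pipeline.fit(documents);

    // PMMLBuilder invokes the converters shown above for each pipeline stage
    return new PMMLBuilder(documents.schema(), pipelineModel).build();
}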