
Example 1 with DocumentFeature

Use of org.jpmml.sparkml.DocumentFeature in project jpmml-sparkml by jpmml: method encodeFeatures of class StopWordsRemoverConverter.

@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    StopWordsRemover transformer = getTransformer();
    DocumentFeature documentFeature = (DocumentFeature) encoder.getOnlyFeature(transformer.getInputCol());
    Pattern pattern = Pattern.compile(documentFeature.getWordSeparatorRE());
    DocumentFeature.StopWordSet stopWordSet = new DocumentFeature.StopWordSet(transformer.getCaseSensitive());
    String[] stopWords = transformer.getStopWords();
    for (String stopWord : stopWords) {
        String[] stopTokens = pattern.split(stopWord);
        // Skip multi-token stopwords. See https://issues.apache.org/jira/browse/SPARK-18374
        if (stopTokens.length > 1) {
            continue;
        }
        if (TermUtil.hasPunctuation(stopWord)) {
            throw new IllegalArgumentException(stopWord);
        }
        stopWordSet.add(stopWord);
    }
    documentFeature.addStopWordSet(stopWordSet);
    return Collections.<Feature>singletonList(documentFeature);
}
Also used: Pattern (java.util.regex.Pattern), StopWordsRemover (org.apache.spark.ml.feature.StopWordsRemover), DocumentFeature (org.jpmml.sparkml.DocumentFeature), Feature (org.jpmml.converter.Feature)
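
The stop-word loop above can be read on its own as a filter over Spark's stop-word list. The following is a minimal standalone sketch, not jpmml-sparkml code: `StopWordFilterSketch` and `collectStopWords` are hypothetical names, the case-sensitivity flag carried by `DocumentFeature.StopWordSet` is left out, and `hasPunctuation` is an assumed stand-in for `TermUtil.hasPunctuation` based on the `\p{Punct}` character class, since the real helper's definition is not shown here.

```java
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.regex.Pattern;

final class StopWordFilterSketch {

    // Hypothetical stand-in for TermUtil.hasPunctuation; using \p{Punct}
    // (ASCII punctuation) is an assumption, not the library's definition.
    private static final Pattern PUNCTUATION = Pattern.compile("\\p{Punct}");

    static boolean hasPunctuation(String word) {
        return PUNCTUATION.matcher(word).find();
    }

    // Mirrors the loop in encodeFeatures: keeps single-token stop words,
    // skips multi-token ones (SPARK-18374) and rejects punctuated ones.
    static Set<String> collectStopWords(String wordSeparatorRE, String[] stopWords) {
        Pattern pattern = Pattern.compile(wordSeparatorRE);
        Set<String> result = new LinkedHashSet<>();
        for (String stopWord : stopWords) {
            String[] stopTokens = pattern.split(stopWord);
            if (stopTokens.length > 1) {
                continue; // multi-token stop word, cannot be matched as a single token
            }
            if (hasPunctuation(stopWord)) {
                throw new IllegalArgumentException(stopWord);
            }
            result.add(stopWord);
        }
        return result;
    }
}
```

For example, with the `\s+` separator the list `a`, `the`, `of the` keeps `a` and `the` and silently drops `of the`, while `don't` is rejected with an `IllegalArgumentException`.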

Example 2 with DocumentFeature

Use of org.jpmml.sparkml.DocumentFeature in project jpmml-sparkml by jpmml: method encodeFeatures of class RegexTokenizerConverter.

@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    RegexTokenizer transformer = getTransformer();
    if (!transformer.getGaps()) {
        throw new IllegalArgumentException("Expected splitter mode, got token matching mode");
    }
    if (transformer.getMinTokenLength() != 1) {
        throw new IllegalArgumentException("Expected 1 as minimum token length, got " + transformer.getMinTokenLength() + " as minimum token length");
    }
    Feature feature = encoder.getOnlyFeature(transformer.getInputCol());
    Field<?> field = encoder.getField(feature.getName());
    if (transformer.getToLowercase()) {
        Apply apply = PMMLUtil.createApply("lowercase", feature.ref());
        field = encoder.createDerivedField(FeatureUtil.createName("lowercase", feature), OpType.CATEGORICAL, DataType.STRING, apply);
    }
    return Collections.<Feature>singletonList(new DocumentFeature(encoder, field, transformer.getPattern()));
}
Also used: Apply (org.dmg.pmml.Apply), RegexTokenizer (org.apache.spark.ml.feature.RegexTokenizer), DocumentFeature (org.jpmml.sparkml.DocumentFeature), Feature (org.jpmml.converter.Feature)
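
The converter accepts a `RegexTokenizer` only in split ("gaps") mode with a minimum token length of 1. The sketch below models, under those assumptions, the tokenization that the resulting `DocumentFeature` stands for: optional lowercasing followed by a regex split. It is plain Java rather than Spark or JPMML code; `RegexTokenizerSketch` and `tokenize` are hypothetical names, and lowercasing with `Locale.ROOT` is an assumption.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;

final class RegexTokenizerSketch {

    // Assumed model of RegexTokenizer in the only configuration the
    // converter accepts: gaps = true (split mode) and minTokenLength = 1.
    static List<String> tokenize(String text, String pattern, boolean gaps, int minTokenLength, boolean toLowercase) {
        if (!gaps) {
            throw new IllegalArgumentException("Expected splitter mode, got token matching mode");
        }
        if (minTokenLength != 1) {
            throw new IllegalArgumentException("Expected 1 as minimum token length, got " + minTokenLength + " as minimum token length");
        }
        String input = toLowercase ? text.toLowerCase(Locale.ROOT) : text;
        List<String> tokens = new ArrayList<>();
        for (String token : Pattern.compile(pattern).split(input)) {
            // With minTokenLength = 1 this only drops empty strings
            if (token.length() >= minTokenLength) {
                tokens.add(token);
            }
        }
        return tokens;
    }
}
```

With a minimum length of 1 the length filter only removes empty strings, such as the one `String.split` produces for leading whitespace.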

Example 3 with DocumentFeature

Use of org.jpmml.sparkml.DocumentFeature in project jpmml-sparkml by jpmml: method encodeFeatures of class NGramConverter.

@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    NGram transformer = getTransformer();
    DocumentFeature documentFeature = (DocumentFeature) encoder.getOnlyFeature(transformer.getInputCol());
    return Collections.<Feature>singletonList(documentFeature);
}
Also used: NGram (org.apache.spark.ml.feature.NGram), DocumentFeature (org.jpmml.sparkml.DocumentFeature), Feature (org.jpmml.converter.Feature)
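
Unlike the other converters, `NGramConverter` adds nothing to the PMML document: it hands the incoming `DocumentFeature` on unchanged. For reference, the sketch below shows what Spark's `NGram` transformer is assumed to compute on the Spark side, namely every run of `n` consecutive tokens joined by a single space, with no output when there are fewer than `n` tokens. `NGramSketch` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

final class NGramSketch {

    // Joins each window of n consecutive tokens with a single space;
    // inputs shorter than n yield an empty list.
    static List<String> ngrams(List<String> tokens, int n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be at least 1, got " + n);
        }
        List<String> result = new ArrayList<>();
        for (int i = 0; i + n <= tokens.size(); i++) {
            result.add(String.join(" ", tokens.subList(i, i + n)));
        }
        return result;
    }
}
```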

Example 4 with DocumentFeature

Use of org.jpmml.sparkml.DocumentFeature in project jpmml-sparkml by jpmml: method encodeFeatures of class CountVectorizerModelConverter.

@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    CountVectorizerModel transformer = getTransformer();
    DocumentFeature documentFeature = (DocumentFeature) encoder.getOnlyFeature(transformer.getInputCol());
    ParameterField documentField = new ParameterField(FieldName.create("document"));
    ParameterField termField = new ParameterField(FieldName.create("term"));
    TextIndex textIndex = new TextIndex(documentField.getName())
        .setTokenize(Boolean.TRUE)
        .setWordSeparatorCharacterRE(documentFeature.getWordSeparatorRE())
        .setLocalTermWeights(transformer.getBinary() ? TextIndex.LocalTermWeights.BINARY : null)
        .setExpression(new FieldRef(termField.getName()));
    Set<DocumentFeature.StopWordSet> stopWordSets = documentFeature.getStopWordSets();
    for (DocumentFeature.StopWordSet stopWordSet : stopWordSets) {
        if (stopWordSet.isEmpty()) {
            continue;
        }
        DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();
        String tokenRE;
        String wordSeparatorRE = documentFeature.getWordSeparatorRE();
        switch(wordSeparatorRE) {
            case "\\s+":
                tokenRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWordSet) + ")\\p{Punct}*(\\s+|$)";
                break;
            case "\\W+":
                tokenRE = "(\\W+)(" + JOINER.join(stopWordSet) + ")(\\W+)";
                break;
            default:
                throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + wordSeparatorRE + "\"");
        }
        InlineTable inlineTable = new InlineTable()
            .addRows(DOMUtil.createRow(documentBuilder, Arrays.asList("string", "stem", "regex"), Arrays.asList(tokenRE, " ", "true")));
        TextIndexNormalization textIndexNormalization = new TextIndexNormalization()
            .setCaseSensitive(stopWordSet.isCaseSensitive())
            // Handles consecutive matches. See http://stackoverflow.com/a/25085385
            .setRecursive(Boolean.TRUE)
            .setInlineTable(inlineTable);
        textIndex.addTextIndexNormalizations(textIndexNormalization);
    }
    DefineFunction defineFunction = new DefineFunction("tf" + "@" + String.valueOf(CountVectorizerModelConverter.SEQUENCE.getAndIncrement()), OpType.CONTINUOUS, null)
        .setDataType(DataType.INTEGER)
        .addParameterFields(documentField, termField)
        .setExpression(textIndex);
    encoder.addDefineFunction(defineFunction);
    List<Feature> result = new ArrayList<>();
    String[] vocabulary = transformer.vocabulary();
    for (int i = 0; i < vocabulary.length; i++) {
        String term = vocabulary[i];
        if (TermUtil.hasPunctuation(term)) {
            throw new IllegalArgumentException(term);
        }
        result.add(new TermFeature(encoder, defineFunction, documentFeature, term));
    }
    return result;
}
Also used: InlineTable (org.dmg.pmml.InlineTable), FieldRef (org.dmg.pmml.FieldRef), TextIndex (org.dmg.pmml.TextIndex), DocumentFeature (org.jpmml.sparkml.DocumentFeature), ArrayList (java.util.ArrayList), Feature (org.jpmml.converter.Feature), TermFeature (org.jpmml.sparkml.TermFeature), TextIndexNormalization (org.dmg.pmml.TextIndexNormalization), CountVectorizerModel (org.apache.spark.ml.feature.CountVectorizerModel), DocumentBuilder (javax.xml.parsers.DocumentBuilder), DefineFunction (org.dmg.pmml.DefineFunction), ParameterField (org.dmg.pmml.ParameterField)
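
The heart of this converter is the stop-word regex built for the `\s+` separator and the recursive `TextIndexNormalization` that applies it. The following plain-Java sketch emulates both steps, plus the term count that the generated `tf@N` function returns, without the JPMML object model. It assumes that a recursive normalization re-applies the replacement until the text stops changing, and that `TextIndex` term matching is case-insensitive by default. `TextIndexSketch` and its methods are hypothetical, and `String.join("|", stopWords)` stands in for `JOINER`, whose definition is not shown.

```java
import java.util.Collection;
import java.util.regex.Pattern;

final class TextIndexSketch {

    // Same shape as the "\\s+" branch in encodeFeatures. Stop words are
    // joined unquoted, which relies on punctuated words being rejected upstream.
    static String stopWordRE(Collection<String> stopWords) {
        return "(^|\\s+)\\p{Punct}*(" + String.join("|", stopWords) + ")\\p{Punct}*(\\s+|$)";
    }

    // Assumed recursive normalization: replace every match with " " and
    // repeat until a pass leaves the text unchanged.
    static String normalize(String text, Collection<String> stopWords, boolean caseSensitive) {
        if (stopWords.isEmpty()) {
            return text; // the converter also skips empty stop-word sets
        }
        int flags = caseSensitive ? 0 : (Pattern.CASE_INSENSITIVE | Pattern.UNICODE_CASE);
        Pattern pattern = Pattern.compile(stopWordRE(stopWords), flags);
        String current = text;
        String previous;
        do {
            previous = current;
            current = pattern.matcher(previous).replaceAll(" ");
        } while (!current.equals(previous));
        return current;
    }

    // Counts whitespace-separated tokens equal to term, ignoring case
    // (assumed TextIndex default); binary weighting caps the count at 1.
    static int termFrequency(String text, String term, boolean binary) {
        int count = 0;
        for (String token : text.split("\\s+")) {
            if (!token.isEmpty() && token.equalsIgnoreCase(term)) {
                count++;
            }
        }
        return binary ? Math.min(count, 1) : count;
    }
}
```

For the text `the cat and the dog` with stop words `the` and `and`, a single `replaceAll` pass yields ` cat the dog`: the space before the second `the` was already consumed by the match on ` and `, so that `the` is only removed on the next pass. Repeating until nothing changes gives ` cat dog`, which is the situation the recursive flag is meant to handle.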

Example 5 with DocumentFeature

Use of org.jpmml.sparkml.DocumentFeature in project jpmml-sparkml by jpmml: method encodeFeatures of class TokenizerConverter.

@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    Tokenizer transformer = getTransformer();
    Feature feature = encoder.getOnlyFeature(transformer.getInputCol());
    Apply apply = PMMLUtil.createApply("lowercase", feature.ref());
    DerivedField derivedField = encoder.createDerivedField(FeatureUtil.createName("lowercase", feature), OpType.CATEGORICAL, DataType.STRING, apply);
    return Collections.<Feature>singletonList(new DocumentFeature(encoder, derivedField, "\\s+"));
}
Also used: Apply (org.dmg.pmml.Apply), DocumentFeature (org.jpmml.sparkml.DocumentFeature), Tokenizer (org.apache.spark.ml.feature.Tokenizer), Feature (org.jpmml.converter.Feature), DerivedField (org.dmg.pmml.DerivedField)
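
`TokenizerConverter` encodes Spark's `Tokenizer` as a `lowercase` function followed by a `DocumentFeature` with the fixed separator `\s+`. The following is a minimal sketch of the tokenization this encoding describes, assuming `Locale.ROOT` lowercasing; `TokenizerSketch` is a hypothetical name and this is not Spark's own implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

final class TokenizerSketch {

    // Lowercases the input (the "lowercase" Apply) and splits on runs of
    // whitespace (the "\\s+" separator passed to DocumentFeature).
    static List<String> tokenize(String text) {
        List<String> tokens = new ArrayList<>();
        for (String token : text.toLowerCase(Locale.ROOT).split("\\s+")) {
            if (!token.isEmpty()) { // leading whitespace yields one empty string
                tokens.add(token);
            }
        }
        return tokens;
    }
}
```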

Aggregations

Feature (org.jpmml.converter.Feature): 5
DocumentFeature (org.jpmml.sparkml.DocumentFeature): 5
Apply (org.dmg.pmml.Apply): 2
ArrayList (java.util.ArrayList): 1
Pattern (java.util.regex.Pattern): 1
DocumentBuilder (javax.xml.parsers.DocumentBuilder): 1
CountVectorizerModel (org.apache.spark.ml.feature.CountVectorizerModel): 1
NGram (org.apache.spark.ml.feature.NGram): 1
RegexTokenizer (org.apache.spark.ml.feature.RegexTokenizer): 1
StopWordsRemover (org.apache.spark.ml.feature.StopWordsRemover): 1
Tokenizer (org.apache.spark.ml.feature.Tokenizer): 1
DefineFunction (org.dmg.pmml.DefineFunction): 1
DerivedField (org.dmg.pmml.DerivedField): 1
FieldRef (org.dmg.pmml.FieldRef): 1
InlineTable (org.dmg.pmml.InlineTable): 1
ParameterField (org.dmg.pmml.ParameterField): 1
TextIndex (org.dmg.pmml.TextIndex): 1
TextIndexNormalization (org.dmg.pmml.TextIndexNormalization): 1
TermFeature (org.jpmml.sparkml.TermFeature): 1