
Example 1 with TermFeature

Use of org.jpmml.sparkml.TermFeature in project jpmml-sparkml by jpmml.

Source: method encodeFeatures of class CountVectorizerModelConverter.

@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    CountVectorizerModel transformer = getTransformer();
    DocumentFeature documentFeature = (DocumentFeature) encoder.getOnlyFeature(transformer.getInputCol());
    ParameterField documentField = new ParameterField(FieldName.create("document"));
    ParameterField termField = new ParameterField(FieldName.create("term"));
    TextIndex textIndex = new TextIndex(documentField.getName())
        .setTokenize(Boolean.TRUE)
        .setWordSeparatorCharacterRE(documentFeature.getWordSeparatorRE())
        .setLocalTermWeights(transformer.getBinary() ? TextIndex.LocalTermWeights.BINARY : null)
        .setExpression(new FieldRef(termField.getName()));
    Set<DocumentFeature.StopWordSet> stopWordSets = documentFeature.getStopWordSets();
    for (DocumentFeature.StopWordSet stopWordSet : stopWordSets) {
        if (stopWordSet.isEmpty()) {
            continue;
        }
        DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();
        String tokenRE;
        String wordSeparatorRE = documentFeature.getWordSeparatorRE();
        switch(wordSeparatorRE) {
            case "\\s+":
                tokenRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWordSet) + ")\\p{Punct}*(\\s+|$)";
                break;
            case "\\W+":
                tokenRE = "(\\W+)(" + JOINER.join(stopWordSet) + ")(\\W+)";
                break;
            default:
                throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + wordSeparatorRE + "\"");
        }
        // A single normalization row: every regex match of tokenRE is replaced with one space.
        InlineTable inlineTable = new InlineTable()
            .addRows(DOMUtil.createRow(documentBuilder, Arrays.asList("string", "stem", "regex"), Arrays.asList(tokenRE, " ", "true")));
        TextIndexNormalization textIndexNormalization = new TextIndexNormalization()
            .setCaseSensitive(stopWordSet.isCaseSensitive())
            // Handles consecutive matches. See http://stackoverflow.com/a/25085385
            .setRecursive(Boolean.TRUE)
            .setInlineTable(inlineTable);
        textIndex.addTextIndexNormalizations(textIndexNormalization);
    }
    DefineFunction defineFunction = new DefineFunction("tf@" + CountVectorizerModelConverter.SEQUENCE.getAndIncrement(), OpType.CONTINUOUS, null)
        .setDataType(DataType.INTEGER)
        .addParameterFields(documentField, termField)
        .setExpression(textIndex);
    encoder.addDefineFunction(defineFunction);
    List<Feature> result = new ArrayList<>();
    String[] vocabulary = transformer.vocabulary();
    for (int i = 0; i < vocabulary.length; i++) {
        String term = vocabulary[i];
        if (TermUtil.hasPunctuation(term)) {
            throw new IllegalArgumentException(term);
        }
        result.add(new TermFeature(encoder, defineFunction, documentFeature, term));
    }
    return result;
}
Also used: InlineTable(org.dmg.pmml.InlineTable) FieldRef(org.dmg.pmml.FieldRef) TextIndex(org.dmg.pmml.TextIndex) DocumentFeature(org.jpmml.sparkml.DocumentFeature) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) TermFeature(org.jpmml.sparkml.TermFeature) TextIndexNormalization(org.dmg.pmml.TextIndexNormalization) CountVectorizerModel(org.apache.spark.ml.feature.CountVectorizerModel) DocumentBuilder(javax.xml.parsers.DocumentBuilder) DefineFunction(org.dmg.pmml.DefineFunction) ParameterField(org.dmg.pmml.ParameterField)
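The stop-word normalization and the per-term "tf@N" function in Example 1 can be approximated outside PMML. The following is a minimal, hypothetical sketch in plain Java: the class TermFrequencySketch and its methods are illustrative and not part of jpmml-sparkml. It rebuilds tokenRE the way encodeFeatures does (assuming JOINER joins the stop words with "|"), applies it repeatedly as a recursive TextIndexNormalization would, and then counts a term's occurrences. It uses java.util.regex, does not regex-quote the stop words, and does not model every detail of how a PMML engine evaluates TextIndex (for example, punctuation stripping around tokens).

```java
import java.util.List;
import java.util.regex.Pattern;

// Hypothetical sketch: java.util.regex stands in for the PMML TextIndex and
// TextIndexNormalization elements emitted by the converter.
final class TermFrequencySketch {

    // Builds the stop-word regex like encodeFeatures, assuming JOINER joins with "|".
    static Pattern stopWordPattern(List<String> stopWords, String wordSeparatorRE, boolean caseSensitive) {
        String alternatives = String.join("|", stopWords);
        String tokenRE;
        switch (wordSeparatorRE) {
            case "\\s+":
                tokenRE = "(^|\\s+)\\p{Punct}*(" + alternatives + ")\\p{Punct}*(\\s+|$)";
                break;
            case "\\W+":
                tokenRE = "(\\W+)(" + alternatives + ")(\\W+)";
                break;
            default:
                throw new IllegalArgumentException("Unsupported separator: " + wordSeparatorRE);
        }
        return caseSensitive ? Pattern.compile(tokenRE) : Pattern.compile(tokenRE, Pattern.CASE_INSENSITIVE);
    }

    // Replaces each match with a single space and repeats until the text stops changing.
    // A separator consumed by one match cannot start the next one, so consecutive
    // stop words need more than one pass (the reason for setRecursive(Boolean.TRUE)).
    static String removeStopWords(String text, Pattern stopWords) {
        String previous;
        do {
            previous = text;
            text = stopWords.matcher(text).replaceAll(" ");
        } while (!text.equals(previous));
        return text;
    }

    // Counts exact token matches; binary mode caps the count at 1, like
    // LocalTermWeights.BINARY when CountVectorizerModel.getBinary() is true.
    static int termFrequency(String document, String term, String wordSeparatorRE, boolean binary) {
        int count = 0;
        for (String token : document.split(wordSeparatorRE)) {
            if (token.equals(term)) {
                count++;
            }
        }
        return binary ? Math.min(count, 1) : count;
    }

    public static void main(String[] args) {
        Pattern stopWords = stopWordPattern(List.of("the", "and"), "\\s+", false);
        String cleaned = removeStopWords("The cat and the hat", stopWords);
        System.out.println("[" + cleaned + "]");
        System.out.println(termFrequency(cleaned, "cat", "\\s+", false));
    }
}
```

Running main prints "[ cat hat]" followed by 1. A single replaceAll pass would have left " cat the hat", because the space before "the" was already consumed by the match for " and ".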

Example 2 with TermFeature

Use of org.jpmml.sparkml.TermFeature in project jpmml-sparkml by jpmml.

Source: method encodeFeatures of class IDFModelConverter.

@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    IDFModel transformer = getTransformer();
    List<Feature> features = encoder.getFeatures(transformer.getInputCol());
    Vector idf = transformer.idf();
    if (idf.size() != features.size()) {
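        // Each input term feature must be paired with exactly one IDF weight.
        throw new IllegalArgumentException("Expected " + features.size() + " IDF weights, got " + idf.size());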
    }
    List<Feature> result = new ArrayList<>();
    for (int i = 0; i < features.size(); i++) {
        Feature feature = features.get(i);
        TermFeature termFeature = (TermFeature) feature;
        result.add(termFeature.toWeightedTermFeature(idf.apply(i)));
    }
    return result;
}
Also used: TermFeature(org.jpmml.sparkml.TermFeature) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) Vector(org.apache.spark.ml.linalg.Vector) IDFModel(org.apache.spark.ml.feature.IDFModel)
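Example 2 depends on Spark's IDFModel holding exactly one weight per input term feature. The arithmetic can be sketched as follows, under two stated assumptions: the class TfIdfSketch and its methods are hypothetical, and the weighted term feature is assumed to evaluate to the term frequency multiplied by its IDF weight, which is how Spark's IDFModel transforms a term-frequency vector. The idf formula is the smoothed one used by Spark ML's IDF estimator, ln((m + 1) / (df(t) + 1)).

```java
import java.util.Arrays;

// Hypothetical sketch of TF-IDF weighting as Spark ML computes it; not jpmml-sparkml code.
final class TfIdfSketch {

    // idf(t) = ln((m + 1) / (df(t) + 1)), where m is the number of documents
    // and df(t) the number of documents containing term t.
    static double[] idf(long numDocs, long[] docFreq) {
        double[] result = new double[docFreq.length];
        for (int i = 0; i < docFreq.length; i++) {
            result[i] = Math.log((numDocs + 1.0) / (docFreq[i] + 1.0));
        }
        return result;
    }

    // Pairs feature i with idf[i], as the converter does, after checking that the sizes match.
    static double[] weight(int[] termFrequencies, double[] idf) {
        if (termFrequencies.length != idf.length) {
            throw new IllegalArgumentException(
                "Expected " + idf.length + " term frequencies, got " + termFrequencies.length);
        }
        double[] result = new double[idf.length];
        for (int i = 0; i < idf.length; i++) {
            result[i] = termFrequencies[i] * idf[i];
        }
        return result;
    }

    public static void main(String[] args) {
        // Three training documents: term 0 occurs in all of them, term 1 in one.
        double[] idf = idf(3, new long[] {3, 1});
        // Term frequencies of a new document to be scored.
        System.out.println(Arrays.toString(weight(new int[] {2, 1}, idf)));
    }
}
```

A term that occurs in every training document receives weight ln(1) = 0, so it contributes nothing regardless of its frequency; the size check mirrors the guard in IDFModelConverter.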

Aggregations

ArrayList (java.util.ArrayList): 2
Feature (org.jpmml.converter.Feature): 2
TermFeature (org.jpmml.sparkml.TermFeature): 2
DocumentBuilder (javax.xml.parsers.DocumentBuilder): 1
CountVectorizerModel (org.apache.spark.ml.feature.CountVectorizerModel): 1
IDFModel (org.apache.spark.ml.feature.IDFModel): 1
Vector (org.apache.spark.ml.linalg.Vector): 1
DefineFunction (org.dmg.pmml.DefineFunction): 1
FieldRef (org.dmg.pmml.FieldRef): 1
InlineTable (org.dmg.pmml.InlineTable): 1
ParameterField (org.dmg.pmml.ParameterField): 1
TextIndex (org.dmg.pmml.TextIndex): 1
TextIndexNormalization (org.dmg.pmml.TextIndexNormalization): 1
DocumentFeature (org.jpmml.sparkml.DocumentFeature): 1