Search in sources:

Example 1 with ParameterField

Use of org.dmg.pmml.ParameterField in the project jpmml-sparkml by jpmml.

Class TermFeature, method toWeightedTermFeature:

/**
 * Builds a weighted variant of this term feature.
 *
 * <p>The weighted function's name is derived from the underlying function's name by replacing
 * {@code "tf@"} with {@code "tf-idf@"}. The encoder is first asked for a function with that name.
 * Only when none is found is a new function built and added to the encoder. The new function takes
 * the original parameters plus a trailing {@code weight} parameter, and evaluates the original
 * expression multiplied by {@code weight} as a continuous double.
 *
 * @param weight the weight to attach to the returned feature
 * @return a feature bound to the weighted function, this feature's document feature and term value
 */
public WeightedTermFeature toWeightedTermFeature(double weight) {
    PMMLEncoder pmmlEncoder = ensureEncoder();
    DefineFunction termFunction = getDefineFunction();

    String weightedName = termFunction.getName().replace("tf@", "tf-idf@");

    DefineFunction weightedFunction = pmmlEncoder.getDefineFunction(weightedName);
    if (weightedFunction == null) {
        ParameterField weightParameter = new ParameterField(FieldName.create("weight"));

        // Same parameters as the original function, with "weight" appended last
        List<ParameterField> weightedParameters = new ArrayList<>(termFunction.getParameterFields());
        weightedParameters.add(weightParameter);

        // original expression * weight
        FieldRef weightRef = new FieldRef(weightParameter.getName());
        Apply product = PMMLUtil.createApply("*", termFunction.getExpression(), weightRef);

        weightedFunction = new DefineFunction(weightedName, OpType.CONTINUOUS, weightedParameters)
            .setDataType(DataType.DOUBLE)
            .setExpression(product);

        pmmlEncoder.addDefineFunction(weightedFunction);
    }

    return new WeightedTermFeature(pmmlEncoder, weightedFunction, getFeature(), getValue(), weight);
}
Also used: FieldRef (org.dmg.pmml.FieldRef), Apply (org.dmg.pmml.Apply), PMMLEncoder (org.jpmml.converter.PMMLEncoder), ArrayList (java.util.ArrayList), DefineFunction (org.dmg.pmml.DefineFunction), ParameterField (org.dmg.pmml.ParameterField)

Example 2 with ParameterField

Use of org.dmg.pmml.ParameterField in the project jpmml-sparkml by jpmml.

Class CountVectorizerModelConverter, method encodeFeatures:

/**
 * Encodes a fitted {@link CountVectorizerModel} as one term-frequency feature per vocabulary term.
 *
 * <p>A single {@link DefineFunction} named {@code "tf@<sequence number>"} is registered with the
 * encoder. It takes {@code document} and {@code term} parameters and evaluates a {@link TextIndex}
 * over the document. Tokenization uses the document feature's word separator regex. Local term
 * weights are set to {@code BINARY} when the transformer is binary, and left unset otherwise. Each
 * non-empty stop-word set becomes a recursive {@link TextIndexNormalization} that replaces matching
 * stop words with a single space.
 *
 * @param encoder the encoder that supplies the input document feature and receives the function
 * @return one {@link TermFeature} per vocabulary entry, in vocabulary order
 * @throws IllegalArgumentException if a non-empty stop-word set is present and the word separator
 *     regex is neither {@code "\\s+"} nor {@code "\\W+"}, or if a vocabulary term contains punctuation
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    CountVectorizerModel transformer = getTransformer();
    DocumentFeature documentFeature = (DocumentFeature) encoder.getOnlyFeature(transformer.getInputCol());

    ParameterField documentField = new ParameterField(FieldName.create("document"));
    ParameterField termField = new ParameterField(FieldName.create("term"));

    String wordSeparatorRE = documentFeature.getWordSeparatorRE();

    TextIndex textIndex = new TextIndex(documentField.getName())
        .setTokenize(Boolean.TRUE)
        .setWordSeparatorCharacterRE(wordSeparatorRE)
        .setLocalTermWeights(transformer.getBinary() ? TextIndex.LocalTermWeights.BINARY : null)
        .setExpression(new FieldRef(termField.getName()));

    // Created once, on first use, and reused for every inline table row
    DocumentBuilder documentBuilder = null;

    for (DocumentFeature.StopWordSet stopWordSet : documentFeature.getStopWordSets()) {
        if (stopWordSet.isEmpty()) {
            continue;
        }

        String stopWords = JOINER.join(stopWordSet);

        // Both patterns accept a stop word at the document start (^) or end ($). Without these
        // anchors, a stop word in first or last position has no surrounding separator to match
        // and would never be removed.
        String tokenRE;
        switch(wordSeparatorRE) {
            case "\\s+":
                tokenRE = "(^|\\s+)\\p{Punct}*(" + stopWords + ")\\p{Punct}*(\\s+|$)";
                break;
            case "\\W+":
                tokenRE = "(^|\\W+)(" + stopWords + ")(\\W+|$)";
                break;
            default:
                throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + wordSeparatorRE + "\"");
        }

        if (documentBuilder == null) {
            documentBuilder = DOMUtil.createDocumentBuilder();
        }

        // Each match (stop word plus its surrounding separators) is replaced with a single space
        InlineTable inlineTable = new InlineTable()
            .addRows(DOMUtil.createRow(documentBuilder, Arrays.asList("string", "stem", "regex"), Arrays.asList(tokenRE, " ", "true")));

        TextIndexNormalization textIndexNormalization = new TextIndexNormalization()
            .setCaseSensitive(stopWordSet.isCaseSensitive())
            // Handles consecutive matches. See http://stackoverflow.com/a/25085385
            .setRecursive(Boolean.TRUE)
            .setInlineTable(inlineTable);

        textIndex.addTextIndexNormalizations(textIndexNormalization);
    }

    DefineFunction defineFunction = new DefineFunction("tf@" + CountVectorizerModelConverter.SEQUENCE.getAndIncrement(), OpType.CONTINUOUS, null)
        .setDataType(DataType.INTEGER)
        .addParameterFields(documentField, termField)
        .setExpression(textIndex);

    encoder.addDefineFunction(defineFunction);

    String[] vocabulary = transformer.vocabulary();

    List<Feature> result = new ArrayList<>(vocabulary.length);
    for (String term : vocabulary) {
        // Terms containing punctuation are rejected
        if (TermUtil.hasPunctuation(term)) {
            throw new IllegalArgumentException(term);
        }

        result.add(new TermFeature(encoder, defineFunction, documentFeature, term));
    }

    return result;
}
Also used: InlineTable (org.dmg.pmml.InlineTable), FieldRef (org.dmg.pmml.FieldRef), TextIndex (org.dmg.pmml.TextIndex), DocumentFeature (org.jpmml.sparkml.DocumentFeature), ArrayList (java.util.ArrayList), Feature (org.jpmml.converter.Feature), TermFeature (org.jpmml.sparkml.TermFeature), TextIndexNormalization (org.dmg.pmml.TextIndexNormalization), CountVectorizerModel (org.apache.spark.ml.feature.CountVectorizerModel), DocumentBuilder (javax.xml.parsers.DocumentBuilder), DefineFunction (org.dmg.pmml.DefineFunction), ParameterField (org.dmg.pmml.ParameterField)

Aggregations

ArrayList (java.util.ArrayList): 2, DefineFunction (org.dmg.pmml.DefineFunction): 2, FieldRef (org.dmg.pmml.FieldRef): 2, ParameterField (org.dmg.pmml.ParameterField): 2, DocumentBuilder (javax.xml.parsers.DocumentBuilder): 1, CountVectorizerModel (org.apache.spark.ml.feature.CountVectorizerModel): 1, Apply (org.dmg.pmml.Apply): 1, InlineTable (org.dmg.pmml.InlineTable): 1, TextIndex (org.dmg.pmml.TextIndex): 1, TextIndexNormalization (org.dmg.pmml.TextIndexNormalization): 1, Feature (org.jpmml.converter.Feature): 1, PMMLEncoder (org.jpmml.converter.PMMLEncoder): 1, DocumentFeature (org.jpmml.sparkml.DocumentFeature): 1, TermFeature (org.jpmml.sparkml.TermFeature): 1