Search in sources :

Example 46 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class CountVectorizerModelConverter method encodeFeatures.

/**
 * Encodes each vocabulary term of the fitted {@code CountVectorizerModel} as a term frequency feature.
 *
 * <p>A {@code TextIndex} expression over the parameters {@code document} and {@code term} is wrapped
 * into a {@code DefineFunction} named {@code tf@<sequence number>} and registered with the encoder.
 * Every non-empty stop word set of the input document feature contributes one
 * {@code TextIndexNormalization} whose single row maps a stop word regex to a single space.
 *
 * @param encoder the encoder that supplies the input document feature and receives the function
 * @return one {@link TermFeature} per vocabulary term, in vocabulary order
 * @throws IllegalArgumentException if a non-empty stop word set is combined with a word separator
 *         other than {@code \s+} or {@code \W+}, or if a vocabulary term contains punctuation
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    CountVectorizerModel transformer = getTransformer();
    DocumentFeature documentFeature = (DocumentFeature) encoder.getOnlyFeature(transformer.getInputCol());
    // Read once: drives both the tokenization and the choice of stop word regex below.
    String wordSeparatorRE = documentFeature.getWordSeparatorRE();
    ParameterField documentField = new ParameterField(FieldName.create("document"));
    ParameterField termField = new ParameterField(FieldName.create("term"));
    TextIndex textIndex = new TextIndex(documentField.getName())
        .setTokenize(Boolean.TRUE)
        .setWordSeparatorCharacterRE(wordSeparatorRE)
        .setLocalTermWeights(transformer.getBinary() ? TextIndex.LocalTermWeights.BINARY : null)
        .setExpression(new FieldRef(termField.getName()));
    // A single builder is shared by all rows instead of being re-created per stop word set.
    DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();
    Set<DocumentFeature.StopWordSet> stopWordSets = documentFeature.getStopWordSets();
    for (DocumentFeature.StopWordSet stopWordSet : stopWordSets) {
        if (stopWordSet.isEmpty()) {
            continue;
        }
        // The separator is only validated once a non-empty stop word set actually needs a regex.
        String tokenRE;
        switch (wordSeparatorRE) {
            case "\\s+":
                tokenRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWordSet) + ")\\p{Punct}*(\\s+|$)";
                break;
            case "\\W+":
                tokenRE = "(\\W+)(" + JOINER.join(stopWordSet) + ")(\\W+)";
                break;
            default:
                throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + wordSeparatorRE + "\"");
        }
        // Row columns: the regex to match, its replacement (a single space), and the regex flag.
        InlineTable inlineTable = new InlineTable()
            .addRows(DOMUtil.createRow(documentBuilder, Arrays.asList("string", "stem", "regex"), Arrays.asList(tokenRE, " ", "true")));
        TextIndexNormalization textIndexNormalization = new TextIndexNormalization()
            .setCaseSensitive(stopWordSet.isCaseSensitive())
            // Handles consecutive matches. See http://stackoverflow.com/a/25085385
            .setRecursive(Boolean.TRUE)
            .setInlineTable(inlineTable);
        textIndex.addTextIndexNormalizations(textIndexNormalization);
    }
    // The sequence number keeps function names distinct across converter invocations.
    DefineFunction defineFunction = new DefineFunction("tf@" + CountVectorizerModelConverter.SEQUENCE.getAndIncrement(), OpType.CONTINUOUS, null)
        .setDataType(DataType.INTEGER)
        .addParameterFields(documentField, termField)
        .setExpression(textIndex);
    encoder.addDefineFunction(defineFunction);
    String[] vocabulary = transformer.vocabulary();
    List<Feature> result = new ArrayList<>(vocabulary.length);
    for (String term : vocabulary) {
        if (TermUtil.hasPunctuation(term)) {
            throw new IllegalArgumentException("Expected a vocabulary term without punctuation, got \"" + term + "\"");
        }
        result.add(new TermFeature(encoder, defineFunction, documentFeature, term));
    }
    return result;
}
Also used : InlineTable(org.dmg.pmml.InlineTable) FieldRef(org.dmg.pmml.FieldRef) TextIndex(org.dmg.pmml.TextIndex) DocumentFeature(org.jpmml.sparkml.DocumentFeature) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) DocumentFeature(org.jpmml.sparkml.DocumentFeature) TermFeature(org.jpmml.sparkml.TermFeature) TermFeature(org.jpmml.sparkml.TermFeature) TextIndexNormalization(org.dmg.pmml.TextIndexNormalization) CountVectorizerModel(org.apache.spark.ml.feature.CountVectorizerModel) DocumentBuilder(javax.xml.parsers.DocumentBuilder) DefineFunction(org.dmg.pmml.DefineFunction) ParameterField(org.dmg.pmml.ParameterField)

Example 47 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class IDFModelConverter method encodeFeatures.

/**
 * Applies the inverse document frequency weights of the fitted {@code IDFModel} to the term
 * features registered under the transformer's input column.
 *
 * @param encoder the encoder that supplies the input features
 * @return the input features converted to weighted term features, where the feature at index
 *         {@code i} is weighted by the IDF vector element at index {@code i}
 * @throws IllegalArgumentException if the number of input features differs from the IDF vector size
 * @throws ClassCastException if an input feature is not a {@link TermFeature}
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    IDFModel transformer = getTransformer();
    List<Feature> features = encoder.getFeatures(transformer.getInputCol());
    Vector idf = transformer.idf();
    if (idf.size() != features.size()) {
        throw new IllegalArgumentException("Expected " + idf.size() + " features, got " + features.size() + " features");
    }
    List<Feature> result = new ArrayList<>(features.size());
    for (int i = 0; i < features.size(); i++) {
        // Only term features can be weighted; anything else fails the cast.
        TermFeature termFeature = (TermFeature) features.get(i);
        result.add(termFeature.toWeightedTermFeature(idf.apply(i)));
    }
    return result;
}
Also used : TermFeature(org.jpmml.sparkml.TermFeature) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) TermFeature(org.jpmml.sparkml.TermFeature) Vector(org.apache.spark.ml.linalg.Vector) IDFModel(org.apache.spark.ml.feature.IDFModel)

Example 48 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class ClassificationModelConverter method registerOutputFields.

/**
 * Creates the output fields of a classification model and registers the matching features.
 *
 * <p>The returned list holds, in order: the predicted field named {@code pmml(<predictionCol>)},
 * a transformed-value field named after the prediction column that maps each label value to its
 * index via a {@code MapValues} expression, and, when the model exposes a probability column, one
 * probability field per label value.
 *
 * @param label the model label; must be a {@link CategoricalLabel}
 * @param encoder the encoder that receives the prediction and probability features
 * @return the output fields in the order described above
 */
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    T model = getTransformer();
    CategoricalLabel categoricalLabel = (CategoricalLabel) label;
    List<OutputField> outputFields = new ArrayList<>();

    String predictionCol = model.getPredictionCol();
    OutputField pmmlPredictedField = ModelUtil.createPredictedField(
        FieldName.create("pmml(" + predictionCol + ")"), categoricalLabel.getDataType(), OpType.CATEGORICAL);
    outputFields.add(pmmlPredictedField);

    // Lookup table from label value ("input") to label index ("output").
    String inputColumn = "input";
    String outputColumn = "output";
    List<String> columns = Arrays.asList(inputColumn, outputColumn);
    List<String> categories = new ArrayList<>();
    DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();
    InlineTable inlineTable = new InlineTable();
    for (int index = 0; index < categoricalLabel.size(); index++) {
        String labelValue = categoricalLabel.getValue(index);
        String category = String.valueOf(index);
        categories.add(category);
        inlineTable.addRows(DOMUtil.createRow(documentBuilder, columns, Arrays.asList(labelValue, category)));
    }

    MapValues mapValues = new MapValues()
        .addFieldColumnPairs(new FieldColumnPair(pmmlPredictedField.getName(), inputColumn))
        .setOutputColumn(outputColumn)
        .setInlineTable(inlineTable);
    OutputField predictedField = new OutputField(FieldName.create(predictionCol), DataType.DOUBLE)
        .setOpType(OpType.CATEGORICAL)
        .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
        .setExpression(mapValues);
    outputFields.add(predictedField);

    // The prediction is categorical over the label indices, but converts to a continuous
    // feature of the same name and data type on request.
    Feature predictionFeature = new CategoricalFeature(encoder, predictedField.getName(), predictedField.getDataType(), categories) {

        @Override
        public ContinuousFeature toContinuousFeature() {
            PMMLEncoder pmmlEncoder = ensureEncoder();
            return new ContinuousFeature(pmmlEncoder, getName(), getDataType());
        }
    };
    encoder.putOnlyFeature(predictionCol, predictionFeature);

    if (!(model instanceof HasProbabilityCol)) {
        return outputFields;
    }

    // One probability field and one continuous feature per label value.
    String probabilityCol = ((HasProbabilityCol) model).getProbabilityCol();
    List<Feature> probabilityFeatures = new ArrayList<>();
    for (int index = 0; index < categoricalLabel.size(); index++) {
        String labelValue = categoricalLabel.getValue(index);
        OutputField probabilityField = ModelUtil.createProbabilityField(
            FieldName.create(probabilityCol + "(" + labelValue + ")"), DataType.DOUBLE, labelValue);
        outputFields.add(probabilityField);
        probabilityFeatures.add(new ContinuousFeature(encoder, probabilityField.getName(), probabilityField.getDataType()));
    }
    encoder.putFeatures(probabilityCol, probabilityFeatures);
    return outputFields;
}
Also used : InlineTable(org.dmg.pmml.InlineTable) HasProbabilityCol(org.apache.spark.ml.param.shared.HasProbabilityCol) PMMLEncoder(org.jpmml.converter.PMMLEncoder) ArrayList(java.util.ArrayList) FieldColumnPair(org.dmg.pmml.FieldColumnPair) ResultFeature(org.dmg.pmml.ResultFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DocumentBuilder(javax.xml.parsers.DocumentBuilder) MapValues(org.dmg.pmml.MapValues) CategoricalLabel(org.jpmml.converter.CategoricalLabel) OutputField(org.dmg.pmml.OutputField) Row(org.dmg.pmml.Row)

Example 49 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class ConverterUtil method toPMML.

/**
 * Converts a fitted Spark ML pipeline to a PMML document.
 *
 * <p>Each pipeline transformer is mapped to a converter. Feature converters register their
 * features with a shared encoder; model converters register a model. A single model becomes the
 * root model directly. Several models are combined into a model chain whose mining schema holds
 * the predicted/target mining fields of all member models.
 *
 * @param schema the schema used to initialize the encoder
 * @param pipelineModel the fitted pipeline to convert
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a converter is neither a feature nor a model converter, or
 *         if the pipeline yields no models
 */
public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
    checkVersion();
    SparkMLEncoder encoder = new SparkMLEncoder(schema);
    List<org.dmg.pmml.Model> models = new ArrayList<>();
    for (Transformer transformer : getTransformers(pipelineModel)) {
        TransformerConverter<?> converter = ConverterUtil.createConverter(transformer);
        if (converter instanceof FeatureConverter) {
            ((FeatureConverter<?>) converter).registerFeatures(encoder);
            continue;
        }
        if (converter instanceof ModelConverter) {
            models.add(((ModelConverter<?>) converter).registerModel(encoder));
            continue;
        }
        throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
    }

    if (models.isEmpty()) {
        throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
    }
    if (models.size() == 1) {
        // A lone model is the root model as-is.
        return encoder.encodePMML(models.get(0));
    }

    // Several models: gather the predicted/target fields of every member for the chain's schema.
    List<MiningField> targetMiningFields = new ArrayList<>();
    for (org.dmg.pmml.Model model : models) {
        for (MiningField miningField : model.getMiningSchema().getMiningFields()) {
            switch (miningField.getUsageType()) {
                case PREDICTED:
                case TARGET:
                    targetMiningFields.add(miningField);
                    break;
                default:
                    break;
            }
        }
    }
    MiningSchema chainMiningSchema = new MiningSchema(targetMiningFields);
    MiningModel modelChain = MiningModelUtil.createModelChain(models, new Schema(null, Collections.<Feature>emptyList()))
        .setMiningSchema(chainMiningSchema);
    return encoder.encodePMML(modelChain);
}
Also used : MiningField(org.dmg.pmml.MiningField) Transformer(org.apache.spark.ml.Transformer) MiningSchema(org.dmg.pmml.MiningSchema) Schema(org.jpmml.converter.Schema) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) MiningSchema(org.dmg.pmml.MiningSchema) MiningModel(org.dmg.pmml.mining.MiningModel) MiningModel(org.dmg.pmml.mining.MiningModel) PipelineModel(org.apache.spark.ml.PipelineModel) TrainValidationSplitModel(org.apache.spark.ml.tuning.TrainValidationSplitModel) CrossValidatorModel(org.apache.spark.ml.tuning.CrossValidatorModel) PMML(org.dmg.pmml.PMML) ArrayList(java.util.ArrayList) List(java.util.List)

Example 50 with Feature

use of org.jpmml.converter.Feature in project jpmml-sparkml by jpmml.

the class FeatureConverter method registerFeatures.

/**
 * Encodes the features of this converter's transformer and registers them with the encoder
 * under the transformer's output column(s).
 *
 * <p>A single-output transformer gets all encoded features registered under its one output
 * column. A multi-output transformer must produce exactly one feature per output column; a
 * {@code BinarizedCategoricalFeature} is registered as its list of binary features, any other
 * feature as the column's only feature. Transformers with neither kind of output column are
 * left untouched.
 *
 * @param encoder the encoder that receives the features
 * @throws IllegalArgumentException if the number of encoded features differs from the number of
 *         output columns of a multi-output transformer
 */
public void registerFeatures(SparkMLEncoder encoder) {
    Transformer transformer = getTransformer();

    if (transformer instanceof HasOutputCol) {
        String outputCol = ((HasOutputCol) transformer).getOutputCol();
        encoder.putFeatures(outputCol, encodeFeatures(encoder));
        return;
    }
    if (!(transformer instanceof HasOutputCols)) {
        return;
    }

    String[] outputCols = ((HasOutputCols) transformer).getOutputCols();
    List<Feature> encodedFeatures = encodeFeatures(encoder);
    if (outputCols.length != encodedFeatures.size()) {
        throw new IllegalArgumentException("Expected " + outputCols.length + " features, got " + encodedFeatures.size() + " features");
    }

    int position = 0;
    for (String outputCol : outputCols) {
        Feature feature = encodedFeatures.get(position++);
        if (feature instanceof BinarizedCategoricalFeature) {
            // Registered as the underlying binary features rather than as one feature.
            BinarizedCategoricalFeature binarized = (BinarizedCategoricalFeature) feature;
            encoder.putFeatures(outputCol, (List) binarized.getBinaryFeatures());
        } else {
            encoder.putOnlyFeature(outputCol, feature);
        }
    }
}
Also used : Transformer(org.apache.spark.ml.Transformer) List(java.util.List) HasOutputCols(org.apache.spark.ml.param.shared.HasOutputCols) Feature(org.jpmml.converter.Feature) HasOutputCol(org.apache.spark.ml.param.shared.HasOutputCol)

Aggregations

Feature (org.jpmml.converter.Feature): 53, ContinuousFeature (org.jpmml.converter.ContinuousFeature): 30, ArrayList (java.util.ArrayList): 27, CategoricalFeature (org.jpmml.converter.CategoricalFeature): 19, DerivedField (org.dmg.pmml.DerivedField): 14, DataField (org.dmg.pmml.DataField): 13, FieldName (org.dmg.pmml.FieldName): 10, Apply (org.dmg.pmml.Apply): 9, BooleanFeature (org.jpmml.converter.BooleanFeature): 9, BinaryFeature (org.jpmml.converter.BinaryFeature): 7, List (java.util.List): 6, Expression (org.dmg.pmml.Expression): 6, SimplePredicate (org.dmg.pmml.SimplePredicate): 6, Vector (org.apache.spark.ml.linalg.Vector): 5, Predicate (org.dmg.pmml.Predicate): 5, Node (org.dmg.pmml.tree.Node): 5, DocumentFeature (org.jpmml.sparkml.DocumentFeature): 5, InteractionFeature (org.jpmml.converter.InteractionFeature): 4, DocumentBuilder (javax.xml.parsers.DocumentBuilder): 3, Transformer (org.apache.spark.ml.Transformer): 3