Use of org.jpmml.converter.Feature in the jpmml-sparkml project by jpmml.
Class CountVectorizerModelConverter, method encodeFeatures.
/**
 * Encodes the fitted {@link CountVectorizerModel} as a PMML term-frequency function plus one
 * {@link TermFeature} per vocabulary entry.
 *
 * <p>A {@link DefineFunction} named {@code tf@<n>} (with {@code n} drawn from {@code SEQUENCE}) is
 * registered with the encoder. Its body is a {@link TextIndex} over the {@code document} parameter,
 * tokenized with the upstream document feature's word separator regex, whose expression is the
 * {@code term} parameter. Local term weights are {@code BINARY} when the model's {@code binary}
 * flag is set, otherwise left unset. Every non-empty stop word set of the input document feature
 * becomes a recursive {@link TextIndexNormalization} that rewrites stop word matches to a single
 * space.
 *
 * @param encoder the encoder holding the input column's {@link DocumentFeature}
 * @return one {@link TermFeature} per vocabulary term, in vocabulary order
 * @throws IllegalArgumentException if a non-empty stop word set must be encoded and the word
 *     separator regex is neither {@code \s+} nor {@code \W+}, or if a vocabulary term contains
 *     punctuation
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    CountVectorizerModel transformer = getTransformer();
    DocumentFeature documentFeature = (DocumentFeature) encoder.getOnlyFeature(transformer.getInputCol());
    String wordSeparatorRE = documentFeature.getWordSeparatorRE();

    ParameterField documentField = new ParameterField(FieldName.create("document"));
    ParameterField termField = new ParameterField(FieldName.create("term"));

    TextIndex textIndex = new TextIndex(documentField.getName())
        .setTokenize(Boolean.TRUE)
        .setWordSeparatorCharacterRE(wordSeparatorRE)
        .setLocalTermWeights(transformer.getBinary() ? TextIndex.LocalTermWeights.BINARY : null)
        .setExpression(new FieldRef(termField.getName()));

    // One builder is enough for all inline-table rows; create it only if a stop word table is needed.
    DocumentBuilder documentBuilder = null;
    Set<DocumentFeature.StopWordSet> stopWordSets = documentFeature.getStopWordSets();
    for (DocumentFeature.StopWordSet stopWordSet : stopWordSets) {
        if (stopWordSet.isEmpty()) {
            continue;
        }
        if (documentBuilder == null) {
            documentBuilder = DOMUtil.createDocumentBuilder();
        }

        String tokenRE = createStopWordRE(wordSeparatorRE, stopWordSet);

        // Single-row table: replace every regex match of a stop word with a single space.
        InlineTable inlineTable = new InlineTable()
            .addRows(DOMUtil.createRow(documentBuilder, Arrays.asList("string", "stem", "regex"), Arrays.asList(tokenRE, " ", "true")));

        TextIndexNormalization textIndexNormalization = new TextIndexNormalization()
            .setCaseSensitive(stopWordSet.isCaseSensitive())
            // Handles consecutive matches. See http://stackoverflow.com/a/25085385
            .setRecursive(Boolean.TRUE)
            .setInlineTable(inlineTable);

        textIndex.addTextIndexNormalizations(textIndexNormalization);
    }

    DefineFunction defineFunction = new DefineFunction("tf@" + CountVectorizerModelConverter.SEQUENCE.getAndIncrement(), OpType.CONTINUOUS, null)
        .setDataType(DataType.INTEGER)
        .addParameterFields(documentField, termField)
        .setExpression(textIndex);

    encoder.addDefineFunction(defineFunction);

    String[] vocabulary = transformer.vocabulary();
    List<Feature> result = new ArrayList<>(vocabulary.length);
    for (String term : vocabulary) {
        if (TermUtil.hasPunctuation(term)) {
            throw new IllegalArgumentException("Expected a vocabulary term without punctuation, got \"" + term + "\"");
        }

        result.add(new TermFeature(encoder, defineFunction, documentFeature, term));
    }

    return result;
}

/**
 * Builds the normalization regex that matches a stop word from {@code stopWordSet}, given the
 * tokenizer's word separator regex.
 *
 * @param wordSeparatorRE the upstream tokenizer's splitter pattern
 * @param stopWordSet a non-empty stop word set
 * @return the regex to be placed in the normalization table's {@code string} column
 * @throws IllegalArgumentException if {@code wordSeparatorRE} is neither {@code \s+} nor {@code \W+}
 */
private String createStopWordRE(String wordSeparatorRE, DocumentFeature.StopWordSet stopWordSet) {
    // NOTE(review): JOINER is defined elsewhere (presumably a "|" joiner, making an alternation
    // group). Stop words are inserted unquoted, so regex metacharacters in a stop word are
    // interpreted as regex syntax. Confirm this is intended.
    switch (wordSeparatorRE) {
        case "\\s+":
            // Tolerates punctuation glued to either side of the stop word
            return "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWordSet) + ")\\p{Punct}*(\\s+|$)";
        case "\\W+":
            // NOTE(review): unlike the \s+ case there are no ^/$ alternatives, so a stop word at the
            // very start or end of the document is not matched. Confirm against Spark's behaviour.
            return "(\\W+)(" + JOINER.join(stopWordSet) + ")(\\W+)";
        default:
            throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + wordSeparatorRE + "\"");
    }
}
Use of org.jpmml.converter.Feature in the jpmml-sparkml project by jpmml.
Class IDFModelConverter, method encodeFeatures.
/**
 * Weights each upstream {@link TermFeature} with its fitted inverse document frequency value.
 *
 * <p>The i-th input feature is paired with the i-th entry of {@code transformer.idf()} and converted
 * via {@link TermFeature#toWeightedTermFeature}.
 *
 * @param encoder the encoder holding the features of the input column
 * @return the weighted term features, in input order
 * @throws IllegalArgumentException if the number of IDF weights differs from the number of input
 *     features
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    IDFModel transformer = getTransformer();
    List<Feature> features = encoder.getFeatures(transformer.getInputCol());
    Vector idf = transformer.idf();

    if (idf.size() != features.size()) {
        throw new IllegalArgumentException("Expected " + features.size() + " IDF weights, got " + idf.size() + " IDF weights");
    }

    List<Feature> result = new ArrayList<>(features.size());
    for (int i = 0; i < features.size(); i++) {
        // Inputs are expected to be TermFeatures (e.g. from CountVectorizerModelConverter); any other
        // feature type fails this cast with a ClassCastException.
        TermFeature termFeature = (TermFeature) features.get(i);

        result.add(termFeature.toWeightedTermFeature(idf.apply(i)));
    }

    return result;
}
Use of org.jpmml.converter.Feature in the jpmml-sparkml project by jpmml.
Class ClassificationModelConverter, method registerOutputFields.
/**
 * Declares the output fields of a classification model and registers matching features.
 *
 * <p>Output fields are emitted in this order:
 * <ol>
 *   <li>{@code pmml(<predictionCol>)}: the predicted label value, typed like the label;</li>
 *   <li>{@code <predictionCol>}: a DOUBLE/categorical transformed value that maps each label value
 *       to its zero-based position in the label through an inline {@link MapValues} table;</li>
 *   <li>if the model has a probability column, one {@code <probabilityCol>(<value>)} DOUBLE field
 *       per label value.</li>
 * </ol>
 * The index field is registered as the only feature of {@code predictionCol}. When probabilities
 * exist, their continuous features are registered under {@code probabilityCol}.
 *
 * @param label the model label; must be a {@link CategoricalLabel}
 * @param encoder the encoder receiving the derived features
 * @return the output fields, in the order described above
 */
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    T model = getTransformer();
    CategoricalLabel categoricalLabel = (CategoricalLabel) label;
    String predictionCol = model.getPredictionCol();

    List<OutputField> outputFields = new ArrayList<>();

    OutputField nativePredictedField = ModelUtil.createPredictedField(FieldName.create("pmml(" + predictionCol + ")"), categoricalLabel.getDataType(), OpType.CATEGORICAL);
    outputFields.add(nativePredictedField);

    // Lookup table columns: label value -> position of that value in the label
    List<String> tableColumns = Arrays.asList("input", "output");
    String inputColumn = tableColumns.get(0);
    String outputColumn = tableColumns.get(1);

    List<String> indexCategories = new ArrayList<>();
    DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();
    InlineTable lookupTable = new InlineTable();

    for (int index = 0; index < categoricalLabel.size(); index++) {
        String labelValue = categoricalLabel.getValue(index);
        String indexCategory = String.valueOf(index);

        indexCategories.add(indexCategory);
        lookupTable.addRows(DOMUtil.createRow(documentBuilder, tableColumns, Arrays.asList(labelValue, indexCategory)));
    }

    MapValues mapValues = new MapValues();
    mapValues.addFieldColumnPairs(new FieldColumnPair(nativePredictedField.getName(), inputColumn));
    mapValues.setOutputColumn(outputColumn);
    mapValues.setInlineTable(lookupTable);

    OutputField indexField = new OutputField(FieldName.create(predictionCol), DataType.DOUBLE);
    indexField.setOpType(OpType.CATEGORICAL);
    indexField.setResultFeature(ResultFeature.TRANSFORMED_VALUE);
    indexField.setExpression(mapValues);
    outputFields.add(indexField);

    Feature indexFeature = new CategoricalFeature(encoder, indexField.getName(), indexField.getDataType(), indexCategories) {

        @Override
        public ContinuousFeature toContinuousFeature() {
            return new ContinuousFeature(ensureEncoder(), getName(), getDataType());
        }
    };
    encoder.putOnlyFeature(predictionCol, indexFeature);

    if (!(model instanceof HasProbabilityCol)) {
        return outputFields;
    }

    String probabilityCol = ((HasProbabilityCol) model).getProbabilityCol();
    List<Feature> probabilityFeatures = new ArrayList<>();

    for (int index = 0; index < categoricalLabel.size(); index++) {
        String labelValue = categoricalLabel.getValue(index);

        OutputField probabilityField = ModelUtil.createProbabilityField(FieldName.create(probabilityCol + "(" + labelValue + ")"), DataType.DOUBLE, labelValue);
        outputFields.add(probabilityField);

        probabilityFeatures.add(new ContinuousFeature(encoder, probabilityField.getName(), probabilityField.getDataType()));
    }

    encoder.putFeatures(probabilityCol, probabilityFeatures);

    return outputFields;
}
Use of org.jpmml.converter.Feature in the jpmml-sparkml project by jpmml.
Class ConverterUtil, method toPMML.
/**
 * Converts a fitted Spark ML pipeline into a PMML document.
 *
 * <p>Stages are converted in pipeline order against one shared {@link SparkMLEncoder}:
 * <ul>
 *   <li>feature converters register their output features with the encoder;</li>
 *   <li>model converters produce a PMML model, which is collected.</li>
 * </ul>
 * A single collected model becomes the root model as-is. Several models are chained with
 * {@code MiningModelUtil.createModelChain}. The resulting {@link MiningModel} gets a mining schema
 * made of every PREDICTED or TARGET mining field of the member models.
 *
 * @param schema the input data schema
 * @param pipelineModel the fitted pipeline
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a stage's converter is neither a {@code FeatureConverter} nor a
 *     {@code ModelConverter}, or if the pipeline contains no models
 */
public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
    checkVersion();

    SparkMLEncoder encoder = new SparkMLEncoder(schema);
    List<org.dmg.pmml.Model> models = new ArrayList<>();

    for (Transformer transformer : getTransformers(pipelineModel)) {
        TransformerConverter<?> converter = ConverterUtil.createConverter(transformer);

        if (converter instanceof FeatureConverter) {
            ((FeatureConverter<?>) converter).registerFeatures(encoder);
            continue;
        }

        if (converter instanceof ModelConverter) {
            models.add(((ModelConverter<?>) converter).registerModel(encoder));
            continue;
        }

        throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
    }

    if (models.isEmpty()) {
        throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
    }

    org.dmg.pmml.Model rootModel;
    if (models.size() == 1) {
        rootModel = Iterables.getOnlyElement(models);
    } else {
        // The chain's mining schema exposes the target-side fields of every member model
        List<MiningField> targetMiningFields = new ArrayList<>();
        for (org.dmg.pmml.Model memberModel : models) {
            for (MiningField miningField : memberModel.getMiningSchema().getMiningFields()) {
                switch (miningField.getUsageType()) {
                    case PREDICTED:
                    case TARGET:
                        targetMiningFields.add(miningField);
                        break;
                    default:
                        break;
                }
            }
        }

        MiningSchema chainMiningSchema = new MiningSchema(targetMiningFields);
        rootModel = MiningModelUtil.createModelChain(models, new Schema(null, Collections.<Feature>emptyList()))
            .setMiningSchema(chainMiningSchema);
    }

    return encoder.encodePMML(rootModel);
}
Use of org.jpmml.converter.Feature in the jpmml-sparkml project by jpmml.
Class FeatureConverter, method registerFeatures.
/**
 * Encodes this converter's features and registers them under the transformer's output column(s).
 *
 * <p>For a {@link HasOutputCol} transformer, the whole feature list is stored under the single output
 * column. For a {@link HasOutputCols} transformer, exactly one feature is expected per output column.
 * A {@link BinarizedCategoricalFeature} is expanded into its binary features; any other feature is
 * stored as that column's only feature.
 *
 * <p>NOTE(review): a transformer implementing neither interface is silently skipped. No features
 * are encoded or registered. Confirm this no-op is intended.
 *
 * @param encoder the encoder receiving the features
 * @throws IllegalArgumentException if a multi-output transformer yields a number of features
 *     different from its number of output columns
 */
public void registerFeatures(SparkMLEncoder encoder) {
    Transformer transformer = getTransformer();

    if (transformer instanceof HasOutputCol) {
        HasOutputCol hasOutputCol = (HasOutputCol) transformer;

        String outputCol = hasOutputCol.getOutputCol();
        List<Feature> features = encodeFeatures(encoder);

        encoder.putFeatures(outputCol, features);
    } else if (transformer instanceof HasOutputCols) {
        HasOutputCols hasOutputCols = (HasOutputCols) transformer;

        String[] outputCols = hasOutputCols.getOutputCols();
        List<Feature> features = encodeFeatures(encoder);

        if (outputCols.length != features.size()) {
            throw new IllegalArgumentException("Expected " + outputCols.length + " features, got " + features.size() + " features");
        }

        for (int i = 0; i < outputCols.length; i++) {
            String outputCol = outputCols[i];
            Feature feature = features.get(i);

            if (feature instanceof BinarizedCategoricalFeature) {
                BinarizedCategoricalFeature binarizedCategoricalFeature = (BinarizedCategoricalFeature) feature;

                // The elements are Feature subtypes, so reading them through a List<Feature> view is safe.
                // The widening-then-narrowing cast keeps the unchecked conversion explicit and locally
                // suppressed instead of relying on a raw List.
                @SuppressWarnings("unchecked")
                List<Feature> binaryFeatures = (List<Feature>) (List<? extends Feature>) binarizedCategoricalFeature.getBinaryFeatures();

                encoder.putFeatures(outputCol, binaryFeatures);
            } else {
                encoder.putOnlyFeature(outputCol, feature);
            }
        }
    }
}
Aggregations