Usage of org.jpmml.converter.Feature in the jpmml-sparkml project (by jpmml).
Class OneHotEncoderModelConverter, method encodeFeatures:
/**
 * Encodes every input column of the one-hot encoder as a binarized categorical feature.
 *
 * <p>Each input column must already be registered with the encoder as a single
 * {@link CategoricalFeature}. One {@link BinaryFeature} is created per category value. When the
 * transformer's {@code dropLast} flag is set, the last category value is omitted.
 *
 * @param encoder the encoder holding the already-registered input features
 * @return one {@link BinarizedCategoricalFeature} per input column, in input-column order
 * @throws IllegalArgumentException if {@code dropLast} is set and an input column has no
 *     category values
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    OneHotEncoderModel transformer = getTransformer();

    String[] inputCols = transformer.getInputCols();
    boolean dropLast = transformer.getDropLast();

    List<Feature> result = new ArrayList<>(inputCols.length);

    for (String inputCol : inputCols) {
        // Throws ClassCastException if the column was not registered as a categorical feature.
        CategoricalFeature categoricalFeature = (CategoricalFeature) encoder.getOnlyFeature(inputCol);

        List<String> values = categoricalFeature.getValues();
        if (dropLast) {
            // Fail with a clear message instead of letting subList(0, -1) fail cryptically.
            if (values.isEmpty()) {
                throw new IllegalArgumentException("Expected at least one category value for column " + inputCol + " when dropLast is enabled, got none");
            }
            values = values.subList(0, values.size() - 1);
        }

        List<BinaryFeature> binaryFeatures = new ArrayList<>(values.size());
        for (String value : values) {
            // Use the categorical feature's own data type rather than assuming STRING, so the
            // indicators agree with the BinarizedCategoricalFeature built below.
            binaryFeatures.add(new BinaryFeature(encoder, categoricalFeature.getName(), categoricalFeature.getDataType(), value));
        }

        result.add(new BinarizedCategoricalFeature(encoder, categoricalFeature.getName(), categoricalFeature.getDataType(), binaryFeatures));
    }

    return result;
}
Usage of org.jpmml.converter.Feature in the jpmml-sparkml project (by jpmml).
Class RFormulaModelConverter, method registerFeatures:
/**
 * Registers the features produced by this R formula model with the encoder.
 *
 * <p>If the formula's label column differs from the model's configured label column, the features
 * already registered under the formula label are also registered under the model's label column.
 * Then each stage of the model's nested pipeline is converted and asked to register its features,
 * in stage order.
 *
 * @param encoder the encoder to register features with
 */
@Override
public void registerFeatures(SparkMLEncoder encoder) {
    RFormulaModel rFormulaModel = getTransformer();
    ResolvedRFormula formula = rFormulaModel.resolvedFormula();

    String formulaLabel = formula.label();
    String modelLabel = rFormulaModel.getLabelCol();

    // Alias the formula's label features under the model's label column name.
    boolean sameColumn = formulaLabel.equals(modelLabel);
    if (!sameColumn) {
        encoder.putFeatures(modelLabel, encoder.getFeatures(formulaLabel));
    }

    for (Transformer nestedStage : rFormulaModel.pipelineModel().stages()) {
        ConverterUtil.createFeatureConverter(nestedStage).registerFeatures(encoder);
    }
}
Usage of org.jpmml.converter.Feature in the jpmml-sparkml project (by jpmml).
Class RegexTokenizerConverter, method encodeFeatures:
/**
 * Encodes the tokenizer's input column as a document feature split by the tokenizer's pattern.
 *
 * <p>Only splitter mode ({@code gaps == true}) with a minimum token length of 1 is supported. When
 * lower-casing is enabled, the document is taken from a derived field that applies the PMML
 * {@code lowercase} function to the input.
 *
 * @param encoder the encoder holding the input feature
 * @return a singleton list with the document feature
 * @throws IllegalArgumentException if the tokenizer is in token-matching mode or its minimum
 *     token length is not 1
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    RegexTokenizer tokenizer = getTransformer();

    boolean gaps = tokenizer.getGaps();
    if (!gaps) {
        throw new IllegalArgumentException("Expected splitter mode, got token matching mode");
    }

    int minTokenLength = tokenizer.getMinTokenLength();
    if (minTokenLength != 1) {
        throw new IllegalArgumentException("Expected 1 as minimum token length, got " + minTokenLength + " as minimum token length");
    }

    Feature inputFeature = encoder.getOnlyFeature(tokenizer.getInputCol());
    Field<?> textField = encoder.getField(inputFeature.getName());

    // When lower-casing is requested, tokenize a lower-cased derived copy of the input instead.
    if (tokenizer.getToLowercase()) {
        Apply lowercaseApply = PMMLUtil.createApply("lowercase", inputFeature.ref());
        textField = encoder.createDerivedField(FeatureUtil.createName("lowercase", inputFeature), OpType.CATEGORICAL, DataType.STRING, lowercaseApply);
    }

    Feature documentFeature = new DocumentFeature(encoder, textField, tokenizer.getPattern());
    return Collections.singletonList(documentFeature);
}
Usage of org.jpmml.converter.Feature in the jpmml-sparkml project (by jpmml).
Class BinarizerConverter, method encodeFeatures:
/**
 * Encodes the binarizer's input column as a two-valued categorical feature.
 *
 * <p>A derived field evaluates {@code if(x <= threshold) 0.0 else 1.0} over the continuous view of
 * the input feature. It is then exposed as a categorical feature with the values "0" and "1".
 *
 * @param encoder the encoder holding the input feature and receiving the derived field
 * @return a singleton list with the binarized categorical feature
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    Binarizer binarizer = getTransformer();

    ContinuousFeature input = encoder.getOnlyFeature(binarizer.getInputCol()).toContinuousFeature();

    // Values at or below the threshold map to 0, everything else to 1.
    Apply atOrBelowThreshold = PMMLUtil.createApply("lessOrEqual", input.ref(), PMMLUtil.createConstant(binarizer.getThreshold()));
    Apply binarize = new Apply("if")
        .addExpressions(atOrBelowThreshold)
        .addExpressions(PMMLUtil.createConstant(0d), PMMLUtil.createConstant(1d));

    DerivedField binarizedField = encoder.createDerivedField(formatName(binarizer), OpType.CATEGORICAL, DataType.DOUBLE, binarize);

    Feature binarizedFeature = new CategoricalFeature(encoder, binarizedField, Arrays.asList("0", "1"));
    return Collections.singletonList(binarizedFeature);
}
Usage of org.jpmml.converter.Feature in the jpmml-sparkml project (by jpmml).
Class ClusteringModelConverter, method registerOutputFields:
/**
 * Registers the model's prediction column as a categorical output field.
 *
 * <p>The output field is created with {@code DataType.STRING} and {@code OpType.CATEGORICAL}. It is
 * also registered with the encoder as the only feature of the prediction column. That feature has
 * no continuous representation.
 *
 * @param label not used by this implementation
 * @param encoder the encoder to register the prediction feature with
 * @return a singleton list holding the predicted output field
 */
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    T model = getTransformer();

    // final: captured by the anonymous Feature below.
    final String predictionCol = model.getPredictionCol();

    OutputField predictedField = ModelUtil.createPredictedField(FieldName.create(predictionCol), DataType.STRING, OpType.CATEGORICAL);

    Feature feature = new Feature(encoder, predictedField.getName(), predictedField.getDataType()) {
        @Override
        public ContinuousFeature toContinuousFeature() {
            // The prediction is categorical; name the column so the failure is diagnosable.
            throw new UnsupportedOperationException("Prediction column " + predictionCol + " is categorical and cannot be converted to a continuous feature");
        }
    };

    encoder.putOnlyFeature(predictionCol, feature);

    return Collections.singletonList(predictedField);
}
Aggregations