use of org.apache.spark.ml.Transformer in project jpmml-sparkml by jpmml.
the class RFormulaModelConverter method registerFeatures.
/**
 * Registers the features of the wrapped {@link RFormulaModel} with the given encoder.
 *
 * <p>If the label reported by the resolved formula is not equal to the transformer's label
 * column, the features currently registered under the formula label are additionally
 * registered under the label column. After that, every stage of the model's internal
 * pipeline is wrapped in a feature converter and asked to register its own features,
 * in stage order.
 *
 * @param encoder the encoder whose feature registry is read and updated
 */
@Override
public void registerFeatures(SparkMLEncoder encoder) {
    RFormulaModel formulaModel = getTransformer();

    String formulaLabel = formulaModel.resolvedFormula().label();
    String labelCol = formulaModel.getLabelCol();

    // Alias the formula's label features under the configured label column when the names differ.
    boolean sameColumn = formulaLabel.equals(labelCol);
    if (!sameColumn) {
        encoder.putFeatures(labelCol, encoder.getFeatures(formulaLabel));
    }

    // Delegate to the converters of the internal pipeline stages, one by one.
    for (Transformer stage : formulaModel.pipelineModel().stages()) {
        ConverterUtil.createFeatureConverter(stage).registerFeatures(encoder);
    }
}
use of org.apache.spark.ml.Transformer in project jpmml-sparkml by jpmml.
the class ConverterUtil method getTransformers.
/**
 * Flattens a pipeline model into the ordered sequence of its leaf transformers.
 *
 * <p>Container transformers are not part of the result: a {@link PipelineModel} is replaced
 * by its stages, and a {@link CrossValidatorModel} or {@link TrainValidationSplitModel} is
 * replaced by its best model. The replacement is applied repeatedly, so nested containers
 * are fully expanded. Leaves appear in the order in which they are encountered.
 *
 * @param pipelineModel the root pipeline model to flatten
 * @return the leaf transformers, in depth-first order
 */
private static Iterable<Transformer> getTransformers(PipelineModel pipelineModel) {
    List<Transformer> leaves = new ArrayList<>();
    collectLeafTransformers(pipelineModel, leaves);
    return leaves;
}

/**
 * Appends the leaf transformers reachable from {@code transformer} to {@code leaves},
 * descending into pipeline stages and into the best model of tuning wrappers.
 */
private static void collectLeafTransformers(Transformer transformer, List<Transformer> leaves) {
    if (transformer instanceof PipelineModel) {
        for (Transformer stage : ((PipelineModel) transformer).stages()) {
            collectLeafTransformers(stage, leaves);
        }
    } else if (transformer instanceof CrossValidatorModel) {
        collectLeafTransformers(((CrossValidatorModel) transformer).bestModel(), leaves);
    } else if (transformer instanceof TrainValidationSplitModel) {
        collectLeafTransformers(((TrainValidationSplitModel) transformer).bestModel(), leaves);
    } else {
        // Anything that is not a known container is kept as-is.
        leaves.add(transformer);
    }
}
use of org.apache.spark.ml.Transformer in project jpmml-sparkml by jpmml.
the class ConverterUtil method toPMML.
/**
 * Converts a pipeline model into a PMML document.
 *
 * <p>The pipeline is flattened into its individual transformers. Each one is wrapped in a
 * converter: feature converters register their features with the encoder, model converters
 * register and return a PMML model. A single resulting model is used directly as the root
 * model; several models are combined into a model chain whose mining schema consists of the
 * predicted/target mining fields of the individual models.
 *
 * @param schema the schema used to construct the encoder
 * @param pipelineModel the pipeline model to convert
 * @return the PMML document produced by the encoder for the root model
 * @throws IllegalArgumentException if a converter is neither a feature converter nor a model
 *         converter, or if the pipeline yields no models at all
 */
public static PMML toPMML(StructType schema, PipelineModel pipelineModel) {
    checkVersion();

    SparkMLEncoder encoder = new SparkMLEncoder(schema);
    List<org.dmg.pmml.Model> models = new ArrayList<>();

    for (Transformer stage : getTransformers(pipelineModel)) {
        TransformerConverter<?> converter = ConverterUtil.createConverter(stage);

        if (converter instanceof FeatureConverter) {
            ((FeatureConverter<?>) converter).registerFeatures(encoder);
            continue;
        }
        if (!(converter instanceof ModelConverter)) {
            throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
        }
        models.add(((ModelConverter<?>) converter).registerModel(encoder));
    }

    if (models.isEmpty()) {
        throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
    }

    org.dmg.pmml.Model rootModel;
    if (models.size() == 1) {
        rootModel = Iterables.getOnlyElement(models);
    } else {
        // Gather the predicted/target fields of every model before the chain is assembled.
        List<MiningField> targetFields = new ArrayList<>();
        for (org.dmg.pmml.Model model : models) {
            for (MiningField field : model.getMiningSchema().getMiningFields()) {
                switch (field.getUsageType()) {
                    case PREDICTED:
                    case TARGET:
                        targetFields.add(field);
                        break;
                    default:
                        break;
                }
            }
        }

        MiningSchema chainSchema = new MiningSchema(targetFields);
        Schema emptySchema = new Schema(null, Collections.<Feature>emptyList());
        rootModel = MiningModelUtil.createModelChain(models, emptySchema).setMiningSchema(chainSchema);
    }

    return encoder.encodePMML(rootModel);
}
use of org.apache.spark.ml.Transformer in project jpmml-sparkml by jpmml.
the class FeatureConverter method registerFeatures.
/**
 * Encodes the features of the wrapped transformer and registers them with the encoder
 * under the transformer's output column(s).
 *
 * <p>A transformer with a single output column gets all encoded features registered under
 * that column. A transformer with multiple output columns must yield exactly one encoded
 * feature per column; a {@link BinarizedCategoricalFeature} is registered as the list of
 * its binary features, any other feature is registered as the column's only feature.
 * Transformers exposing neither kind of output column are left untouched.
 *
 * @param encoder the encoder that receives the encoded features
 * @throws IllegalArgumentException if the number of encoded features does not match the
 *         number of output columns
 */
public void registerFeatures(SparkMLEncoder encoder) {
    Transformer transformer = getTransformer();

    if (transformer instanceof HasOutputCol) {
        String column = ((HasOutputCol) transformer).getOutputCol();
        encoder.putFeatures(column, encodeFeatures(encoder));
        return;
    }
    if (!(transformer instanceof HasOutputCols)) {
        return;
    }

    String[] columns = ((HasOutputCols) transformer).getOutputCols();
    List<Feature> encoded = encodeFeatures(encoder);
    if (encoded.size() != columns.length) {
        throw new IllegalArgumentException("Expected " + columns.length + " features, got " + encoded.size() + " features");
    }

    // Columns and encoded features are paired positionally.
    int index = 0;
    for (String column : columns) {
        Feature feature = encoded.get(index++);
        if (feature instanceof BinarizedCategoricalFeature) {
            BinarizedCategoricalFeature binarized = (BinarizedCategoricalFeature) feature;
            encoder.putFeatures(column, (List) binarized.getBinaryFeatures());
        } else {
            encoder.putOnlyFeature(column, feature);
        }
    }
}
Aggregations