Usage of org.apache.spark.ml.tuning.CrossValidatorModel in the jpmml-sparkml project (by jpmml):
class ConverterUtil, method getTransformers.
/**
 * Flattens a fitted pipeline into the ordered sequence of its leaf transformers.
 *
 * <p>Container models are expanded recursively until only non-container transformers remain:
 * <ul>
 *   <li>a {@link PipelineModel} is replaced by its stages, in stage order;</li>
 *   <li>a {@link CrossValidatorModel} is replaced by its best model;</li>
 *   <li>a {@link TrainValidationSplitModel} is replaced by its best model.</li>
 * </ul>
 * Any other transformer is emitted unchanged. Leaves come out in depth-first order, i.e. the
 * order in which they appear inside the original pipeline. A container that expands to nothing
 * (for example, a pipeline with zero stages) contributes no elements.
 *
 * @param pipelineModel the root pipeline to flatten
 * @return the leaf transformers, in pipeline order
 */
private static Iterable<Transformer> getTransformers(PipelineModel pipelineModel) {
    List<Transformer> leaves = new ArrayList<>();

    // Explicit LIFO work stack (top = last element). Children are pushed in reverse so that
    // the element on top is always the next one in pipeline order.
    List<Transformer> pending = new ArrayList<>();
    pending.add(pipelineModel);

    while (!pending.isEmpty()) {
        Transformer current = pending.remove(pending.size() - 1);

        if (current instanceof PipelineModel) {
            Transformer[] stages = ((PipelineModel) current).stages();
            for (int i = stages.length - 1; i >= 0; i--) {
                pending.add(stages[i]);
            }
        } else if (current instanceof CrossValidatorModel) {
            pending.add(((CrossValidatorModel) current).bestModel());
        } else if (current instanceof TrainValidationSplitModel) {
            pending.add(((TrainValidationSplitModel) current).bestModel());
        } else {
            // Not a container: it is a leaf of the flattened pipeline.
            leaves.add(current);
        }
    }

    return leaves;
}
Aggregations