Search in sources :

Example 1 with MinMaxScalerModel

Use of org.apache.spark.ml.feature.MinMaxScalerModel in the project jpmml-sparkml by jpmml.

The encodeFeatures method of the class MinMaxScalerModelConverter.

/**
 * Encodes the min-max rescaling of every input feature as a PMML derived field.
 *
 * <p>For a feature value {@code x} the generated expression is
 * {@code (x - originalMin) / (originalMax - originalMin) * (max - min) + min}, where
 * {@code originalMin}/{@code originalMax} are the per-feature bounds stored in the transformer and
 * {@code min}/{@code max} are the transformer's target range. The multiplication is omitted when
 * the factor is one and the addition is omitted when the constant is zero.
 *
 * <p>A feature whose {@code originalMax} equals its {@code originalMin} (a constant column) is
 * encoded as the constant {@code (min + max) / 2} instead of a division by zero, following the
 * rule documented for Spark's {@code MinMaxScaler}.
 *
 * @param encoder the encoder that supplies the input features and registers the derived fields
 * @return one continuous feature per input feature, in input order
 * @throws IllegalArgumentException if the number of original max or original min values differs
 *     from the number of input features
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    MinMaxScalerModel transformer = getTransformer();
    double targetMin = transformer.getMin();
    double targetMax = transformer.getMax();
    double rescaleFactor = (targetMax - targetMin);
    double rescaleConstant = targetMin;
    List<Feature> features = encoder.getFeatures(transformer.getInputCol());
    int numFeatures = features.size();
    Vector originalMax = transformer.originalMax();
    if (originalMax.size() != numFeatures) {
        throw new IllegalArgumentException("Expected " + numFeatures + " original max value(s), got " + originalMax.size());
    }
    Vector originalMin = transformer.originalMin();
    if (originalMin.size() != numFeatures) {
        throw new IllegalArgumentException("Expected " + numFeatures + " original min value(s), got " + originalMin.size());
    }
    List<Feature> result = new ArrayList<>(numFeatures);
    for (int i = 0; i < numFeatures; i++) {
        Feature feature = features.get(i);
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        double max = originalMax.apply(i);
        double min = originalMin.apply(i);
        double range = (max - min);
        Expression expression;
        if (range == 0d) {
            // Constant column: dividing by (max - min) would divide by zero, so emit the midpoint
            // of the target range instead.
            // NOTE(review): unlike the arithmetic form, this constant does not reference the input
            // field, so a missing input no longer propagates as missing - confirm this is acceptable.
            expression = PMMLUtil.createConstant((targetMin + targetMax) / 2d);
        } else {
            expression = PMMLUtil.createApply("/", PMMLUtil.createApply("-", continuousFeature.ref(), PMMLUtil.createConstant(min)), PMMLUtil.createConstant(range));
            if (!ValueUtil.isOne(rescaleFactor)) {
                expression = PMMLUtil.createApply("*", expression, PMMLUtil.createConstant(rescaleFactor));
            }
            if (!ValueUtil.isZero(rescaleConstant)) {
                expression = PMMLUtil.createApply("+", expression, PMMLUtil.createConstant(rescaleConstant));
            }
        }
        DerivedField derivedField = encoder.createDerivedField(formatName(transformer, i), OpType.CONTINUOUS, DataType.DOUBLE, expression);
        result.add(new ContinuousFeature(encoder, derivedField));
    }
    return result;
}
Also used : MinMaxScalerModel(org.apache.spark.ml.feature.MinMaxScalerModel) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Expression(org.dmg.pmml.Expression) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Vector(org.apache.spark.ml.linalg.Vector) DerivedField(org.dmg.pmml.DerivedField)

Aggregations

ArrayList (java.util.ArrayList)1 MinMaxScalerModel (org.apache.spark.ml.feature.MinMaxScalerModel)1 Vector (org.apache.spark.ml.linalg.Vector)1 DerivedField (org.dmg.pmml.DerivedField)1 Expression (org.dmg.pmml.Expression)1 ContinuousFeature (org.jpmml.converter.ContinuousFeature)1 Feature (org.jpmml.converter.Feature)1