Search in sources :

Example 46 with DerivedField

Use of org.dmg.pmml.DerivedField in the project jpmml-sparkml by jpmml.

In the class ExpressionTranslator, method translateInternal.

/**
 * Recursively translates a Spark SQL Catalyst expression tree into a PMML expression.
 *
 * <p>
 * Dispatch is by the runtime class of {@code expression}; child expressions are translated by recursive calls.
 * A {@code Cast} whose translated child cannot carry a data type is registered with the encoder as a derived field,
 * and a reference to that field is returned instead.
 * </p>
 *
 * @param expression The Catalyst expression to translate.
 * @return The PMML expression.
 * @throws IllegalArgumentException If the expression, or one of its descendants, has no translation in this method.
 */
private org.dmg.pmml.Expression translateInternal(Expression expression) {
    SparkMLEncoder encoder = getEncoder();
    // Aliases are kept as a wrapper, so that an enclosing Cast can reuse the alias name
    if (expression instanceof Alias) {
        Alias alias = (Alias) expression;
        String name = alias.name();
        Expression child = alias.child();
        org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
        return new AliasExpression(name, pmmlExpression);
    }
    if (expression instanceof AttributeReference) {
        AttributeReference attributeReference = (AttributeReference) expression;
        String name = attributeReference.name();
        return new FieldRef(FieldName.create(name));
    } else if (expression instanceof BinaryMathExpression) {
        BinaryMathExpression binaryMathExpression = (BinaryMathExpression) expression;
        Expression left = binaryMathExpression.left();
        Expression right = binaryMathExpression.right();
        String function;
        if (binaryMathExpression instanceof Hypot) {
            function = PMMLFunctions.HYPOT;
        } else if (binaryMathExpression instanceof Pow) {
            function = PMMLFunctions.POW;
        } else {
            throw new IllegalArgumentException(formatMessage(binaryMathExpression));
        }
        return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
    } else if (expression instanceof BinaryOperator) {
        BinaryOperator binaryOperator = (BinaryOperator) expression;
        String symbol = binaryOperator.symbol();
        Expression left = binaryOperator.left();
        Expression right = binaryOperator.right();
        String function;
        // The operator class selects the family; the operator symbol selects the PMML function within it
        if (expression instanceof And || expression instanceof Or) {
            switch (symbol) {
                case "&&":
                    function = PMMLFunctions.AND;
                    break;
                case "||":
                    function = PMMLFunctions.OR;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryOperator));
            }
        } else if (expression instanceof Add || expression instanceof Divide || expression instanceof Multiply || expression instanceof Subtract) {
            BinaryArithmetic binaryArithmetic = (BinaryArithmetic) binaryOperator;
            switch (symbol) {
                case "+":
                    function = PMMLFunctions.ADD;
                    break;
                case "/":
                    function = PMMLFunctions.DIVIDE;
                    break;
                case "*":
                    function = PMMLFunctions.MULTIPLY;
                    break;
                case "-":
                    function = PMMLFunctions.SUBTRACT;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryArithmetic));
            }
        } else if (expression instanceof EqualTo || expression instanceof GreaterThan || expression instanceof GreaterThanOrEqual || expression instanceof LessThan || expression instanceof LessThanOrEqual) {
            BinaryComparison binaryComparison = (BinaryComparison) binaryOperator;
            switch (symbol) {
                case "=":
                    function = PMMLFunctions.EQUAL;
                    break;
                case ">":
                    function = PMMLFunctions.GREATERTHAN;
                    break;
                case ">=":
                    function = PMMLFunctions.GREATEROREQUAL;
                    break;
                case "<":
                    function = PMMLFunctions.LESSTHAN;
                    break;
                case "<=":
                    function = PMMLFunctions.LESSOREQUAL;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryComparison));
            }
        } else {
            throw new IllegalArgumentException(formatMessage(binaryOperator));
        }
        return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
    } else if (expression instanceof CaseWhen) {
        CaseWhen caseWhen = (CaseWhen) expression;
        List<Tuple2<Expression, Expression>> branches = JavaConversions.seqAsJavaList(caseWhen.branches());
        Option<Expression> elseValue = caseWhen.elseValue();
        // Head of the chain, which is what gets returned
        Apply apply = null;
        Iterator<Tuple2<Expression, Expression>> branchIt = branches.iterator();
        Apply prevBranchApply = null;
        // Each branch becomes "if(predicate, value)"; every subsequent branch is appended
        // to the previous one as its next (ie. else) argument, yielding a nested if-chain
        do {
            Tuple2<Expression, Expression> branch = branchIt.next();
            Expression predicate = branch._1();
            Expression value = branch._2();
            Apply branchApply = PMMLUtil.createApply(PMMLFunctions.IF, translateInternal(predicate), translateInternal(value));
            if (apply == null) {
                apply = branchApply;
            }
            if (prevBranchApply != null) {
                prevBranchApply.addExpressions(branchApply);
            }
            prevBranchApply = branchApply;
        } while (branchIt.hasNext());
        // The optional ELSE value is appended to the innermost (last) branch
        if (elseValue.isDefined()) {
            Expression value = elseValue.get();
            prevBranchApply.addExpressions(translateInternal(value));
        }
        return apply;
    } else if (expression instanceof Cast) {
        Cast cast = (Cast) expression;
        Expression child = cast.child();
        org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
        DataType dataType = DatasetUtil.translateDataType(cast.dataType());
        if (pmmlExpression instanceof HasDataType) {
            // The translated child can carry a data type itself; set it in place
            HasDataType<?> hasDataType = (HasDataType<?>) pmmlExpression;
            hasDataType.setDataType(dataType);
            return pmmlExpression;
        } else {
            // Otherwise, register a typed derived field and return a reference to it
            FieldName name;
            if (pmmlExpression instanceof AliasExpression) {
                AliasExpression aliasExpression = (AliasExpression) pmmlExpression;
                name = FieldName.create(aliasExpression.getName());
            } else {
                name = FieldNameUtil.create(dataType, ExpressionUtil.format(child));
            }
            OpType opType = ExpressionUtil.getOpType(dataType);
            pmmlExpression = AliasExpression.unwrap(pmmlExpression);
            DerivedField derivedField = encoder.createDerivedField(name, opType, dataType, pmmlExpression);
            return new FieldRef(derivedField.getName());
        }
    } else if (expression instanceof Concat) {
        Concat concat = (Concat) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(concat.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.CONCAT);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof Greatest) {
        Greatest greatest = (Greatest) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(greatest.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.MAX);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof If) {
        If _if = (If) expression;
        Expression predicate = _if.predicate();
        Expression trueValue = _if.trueValue();
        Expression falseValue = _if.falseValue();
        return PMMLUtil.createApply(PMMLFunctions.IF, translateInternal(predicate), translateInternal(trueValue), translateInternal(falseValue));
    } else if (expression instanceof In) {
        In in = (In) expression;
        Expression value = in.value();
        List<Expression> elements = JavaConversions.seqAsJavaList(in.list());
        // First argument is the tested value; the remaining arguments are the set elements
        Apply apply = PMMLUtil.createApply(PMMLFunctions.ISIN, translateInternal(value));
        for (Expression element : elements) {
            apply.addExpressions(translateInternal(element));
        }
        return apply;
    } else if (expression instanceof Least) {
        Least least = (Least) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(least.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.MIN);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof Length) {
        Length length = (Length) expression;
        Expression child = length.child();
        return PMMLUtil.createApply(PMMLFunctions.STRINGLENGTH, translateInternal(child));
    } else if (expression instanceof Literal) {
        Literal literal = (Literal) expression;
        Object value = literal.value();
        // A null literal becomes a missing constant
        if (value == null) {
            return PMMLUtil.createMissingConstant();
        }
        DataType dataType;
        // XXX: Decimal values are encoded as string constants
        if (value instanceof Decimal) {
            Decimal decimal = (Decimal) value;
            dataType = DataType.STRING;
            value = decimal.toString();
        } else {
            dataType = DatasetUtil.translateDataType(literal.dataType());
            value = toSimpleObject(value);
        }
        return PMMLUtil.createConstant(value, dataType);
    } else if (expression instanceof RegExpReplace) {
        RegExpReplace regexpReplace = (RegExpReplace) expression;
        Expression subject = regexpReplace.subject();
        Expression regexp = regexpReplace.regexp();
        Expression rep = regexpReplace.rep();
        return PMMLUtil.createApply(PMMLFunctions.REPLACE, translateInternal(subject), translateInternal(regexp), translateInternal(rep));
    } else if (expression instanceof RLike) {
        RLike rlike = (RLike) expression;
        Expression left = rlike.left();
        Expression right = rlike.right();
        return PMMLUtil.createApply(PMMLFunctions.MATCHES, translateInternal(left), translateInternal(right));
    } else if (expression instanceof StringTrim) {
        StringTrim stringTrim = (StringTrim) expression;
        Expression srcStr = stringTrim.srcStr();
        Option<Expression> trimStr = stringTrim.trimStr();
        // Only the plain form is translated; an explicit trim string is rejected with a diagnosable message
        if (trimStr.isDefined()) {
            throw new IllegalArgumentException("Expected no trim string, got trim string " + (trimStr.get()));
        }
        return PMMLUtil.createApply(PMMLFunctions.TRIMBLANKS, translateInternal(srcStr));
    } else if (expression instanceof Substring) {
        Substring substring = (Substring) expression;
        Expression str = substring.str();
        Expression pos = substring.pos();
        Expression len = substring.len();
        // The position and length are turned into PMML constants, so they must be literals;
        // fail with the method's usual exception type rather than a ClassCastException
        if (!(pos instanceof Literal) || !(len instanceof Literal)) {
            throw new IllegalArgumentException("Expected literal start position and length, got " + (pos) + " and " + (len));
        }
        int posValue = ValueUtil.asInt((Number) ((Literal) pos).value());
        if (posValue <= 0) {
            throw new IllegalArgumentException("Expected absolute start position, got relative start position " + (pos));
        }
        int lenValue = ValueUtil.asInt((Number) ((Literal) len).value());
        // XXX: the length is capped at MAX_STRING_LENGTH
        lenValue = Math.min(lenValue, MAX_STRING_LENGTH);
        return PMMLUtil.createApply(PMMLFunctions.SUBSTRING, translateInternal(str), PMMLUtil.createConstant(posValue), PMMLUtil.createConstant(lenValue));
    } else if (expression instanceof UnaryExpression) {
        // Checked last among the specific cases, after the dedicated branches above
        UnaryExpression unaryExpression = (UnaryExpression) expression;
        Expression child = unaryExpression.child();
        if (expression instanceof Abs) {
            return PMMLUtil.createApply(PMMLFunctions.ABS, translateInternal(child));
        } else if (expression instanceof Acos) {
            return PMMLUtil.createApply(PMMLFunctions.ACOS, translateInternal(child));
        } else if (expression instanceof Asin) {
            return PMMLUtil.createApply(PMMLFunctions.ASIN, translateInternal(child));
        } else if (expression instanceof Atan) {
            return PMMLUtil.createApply(PMMLFunctions.ATAN, translateInternal(child));
        } else if (expression instanceof Ceil) {
            return PMMLUtil.createApply(PMMLFunctions.CEIL, translateInternal(child));
        } else if (expression instanceof Cos) {
            return PMMLUtil.createApply(PMMLFunctions.COS, translateInternal(child));
        } else if (expression instanceof Cosh) {
            return PMMLUtil.createApply(PMMLFunctions.COSH, translateInternal(child));
        } else if (expression instanceof Exp) {
            return PMMLUtil.createApply(PMMLFunctions.EXP, translateInternal(child));
        } else if (expression instanceof Expm1) {
            return PMMLUtil.createApply(PMMLFunctions.EXPM1, translateInternal(child));
        } else if (expression instanceof Floor) {
            return PMMLUtil.createApply(PMMLFunctions.FLOOR, translateInternal(child));
        } else if (expression instanceof Log) {
            return PMMLUtil.createApply(PMMLFunctions.LN, translateInternal(child));
        } else if (expression instanceof Log10) {
            return PMMLUtil.createApply(PMMLFunctions.LOG10, translateInternal(child));
        } else if (expression instanceof Log1p) {
            return PMMLUtil.createApply(PMMLFunctions.LN1P, translateInternal(child));
        } else if (expression instanceof Lower) {
            return PMMLUtil.createApply(PMMLFunctions.LOWERCASE, translateInternal(child));
        } else if (expression instanceof IsNaN) {
            // XXX: the NaN check is mapped onto the "is not valid" function
            return PMMLUtil.createApply(PMMLFunctions.ISNOTVALID, translateInternal(child));
        } else if (expression instanceof IsNotNull) {
            return PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, translateInternal(child));
        } else if (expression instanceof IsNull) {
            return PMMLUtil.createApply(PMMLFunctions.ISMISSING, translateInternal(child));
        } else if (expression instanceof Not) {
            return PMMLUtil.createApply(PMMLFunctions.NOT, translateInternal(child));
        } else if (expression instanceof Rint) {
            return PMMLUtil.createApply(PMMLFunctions.RINT, translateInternal(child));
        } else if (expression instanceof Sin) {
            return PMMLUtil.createApply(PMMLFunctions.SIN, translateInternal(child));
        } else if (expression instanceof Sinh) {
            return PMMLUtil.createApply(PMMLFunctions.SINH, translateInternal(child));
        } else if (expression instanceof Sqrt) {
            return PMMLUtil.createApply(PMMLFunctions.SQRT, translateInternal(child));
        } else if (expression instanceof Tan) {
            return PMMLUtil.createApply(PMMLFunctions.TAN, translateInternal(child));
        } else if (expression instanceof Tanh) {
            return PMMLUtil.createApply(PMMLFunctions.TANH, translateInternal(child));
        } else if (expression instanceof UnaryMinus) {
            return PMMLUtil.toNegative(translateInternal(child));
        } else if (expression instanceof UnaryPositive) {
            // Unary plus is a no-op
            return translateInternal(child);
        } else if (expression instanceof Upper) {
            return PMMLUtil.createApply(PMMLFunctions.UPPERCASE, translateInternal(child));
        } else {
            throw new IllegalArgumentException(formatMessage(unaryExpression));
        }
    } else {
        throw new IllegalArgumentException(formatMessage(expression));
    }
}
Also used : Add(org.apache.spark.sql.catalyst.expressions.Add) Tan(org.apache.spark.sql.catalyst.expressions.Tan) RegExpReplace(org.apache.spark.sql.catalyst.expressions.RegExpReplace) Lower(org.apache.spark.sql.catalyst.expressions.Lower) Or(org.apache.spark.sql.catalyst.expressions.Or) Apply(org.dmg.pmml.Apply) CaseWhen(org.apache.spark.sql.catalyst.expressions.CaseWhen) Concat(org.apache.spark.sql.catalyst.expressions.Concat) Divide(org.apache.spark.sql.catalyst.expressions.Divide) LessThan(org.apache.spark.sql.catalyst.expressions.LessThan) IsNotNull(org.apache.spark.sql.catalyst.expressions.IsNotNull) Decimal(org.apache.spark.sql.types.Decimal) GreaterThan(org.apache.spark.sql.catalyst.expressions.GreaterThan) Multiply(org.apache.spark.sql.catalyst.expressions.Multiply) Least(org.apache.spark.sql.catalyst.expressions.Least) Literal(org.apache.spark.sql.catalyst.expressions.Literal) Sinh(org.apache.spark.sql.catalyst.expressions.Sinh) Greatest(org.apache.spark.sql.catalyst.expressions.Greatest) List(java.util.List) BinaryOperator(org.apache.spark.sql.catalyst.expressions.BinaryOperator) BinaryArithmetic(org.apache.spark.sql.catalyst.expressions.BinaryArithmetic) BinaryMathExpression(org.apache.spark.sql.catalyst.expressions.BinaryMathExpression) Substring(org.apache.spark.sql.catalyst.expressions.Substring) Tanh(org.apache.spark.sql.catalyst.expressions.Tanh) Log10(org.apache.spark.sql.catalyst.expressions.Log10) FieldRef(org.dmg.pmml.FieldRef) Expm1(org.apache.spark.sql.catalyst.expressions.Expm1) Log(org.apache.spark.sql.catalyst.expressions.Log) AttributeReference(org.apache.spark.sql.catalyst.expressions.AttributeReference) UnaryMinus(org.apache.spark.sql.catalyst.expressions.UnaryMinus) Hypot(org.apache.spark.sql.catalyst.expressions.Hypot) RLike(org.apache.spark.sql.catalyst.expressions.RLike) EqualTo(org.apache.spark.sql.catalyst.expressions.EqualTo) Not(org.apache.spark.sql.catalyst.expressions.Not) And(org.apache.spark.sql.catalyst.expressions.And) 
BinaryComparison(org.apache.spark.sql.catalyst.expressions.BinaryComparison) Pow(org.apache.spark.sql.catalyst.expressions.Pow) Subtract(org.apache.spark.sql.catalyst.expressions.Subtract) Acos(org.apache.spark.sql.catalyst.expressions.Acos) Sin(org.apache.spark.sql.catalyst.expressions.Sin) Option(scala.Option) Ceil(org.apache.spark.sql.catalyst.expressions.Ceil) If(org.apache.spark.sql.catalyst.expressions.If) Cast(org.apache.spark.sql.catalyst.expressions.Cast) Cosh(org.apache.spark.sql.catalyst.expressions.Cosh) LessThanOrEqual(org.apache.spark.sql.catalyst.expressions.LessThanOrEqual) In(org.apache.spark.sql.catalyst.expressions.In) UnaryExpression(org.apache.spark.sql.catalyst.expressions.UnaryExpression) GreaterThanOrEqual(org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual) Abs(org.apache.spark.sql.catalyst.expressions.Abs) Iterator(java.util.Iterator) HasDataType(org.dmg.pmml.HasDataType) DataType(org.dmg.pmml.DataType) FieldName(org.dmg.pmml.FieldName) Atan(org.apache.spark.sql.catalyst.expressions.Atan) UnaryPositive(org.apache.spark.sql.catalyst.expressions.UnaryPositive) Upper(org.apache.spark.sql.catalyst.expressions.Upper) Floor(org.apache.spark.sql.catalyst.expressions.Floor) Log1p(org.apache.spark.sql.catalyst.expressions.Log1p) Cos(org.apache.spark.sql.catalyst.expressions.Cos) Sqrt(org.apache.spark.sql.catalyst.expressions.Sqrt) Asin(org.apache.spark.sql.catalyst.expressions.Asin) IsNaN(org.apache.spark.sql.catalyst.expressions.IsNaN) Expression(org.apache.spark.sql.catalyst.expressions.Expression) UnaryExpression(org.apache.spark.sql.catalyst.expressions.UnaryExpression) BinaryMathExpression(org.apache.spark.sql.catalyst.expressions.BinaryMathExpression) Length(org.apache.spark.sql.catalyst.expressions.Length) Alias(org.apache.spark.sql.catalyst.expressions.Alias) Tuple2(scala.Tuple2) OpType(org.dmg.pmml.OpType) IsNull(org.apache.spark.sql.catalyst.expressions.IsNull) Rint(org.apache.spark.sql.catalyst.expressions.Rint) 
Exp(org.apache.spark.sql.catalyst.expressions.Exp) DerivedField(org.dmg.pmml.DerivedField) StringTrim(org.apache.spark.sql.catalyst.expressions.StringTrim) HasDataType(org.dmg.pmml.HasDataType)

Example 47 with DerivedField

Use of org.dmg.pmml.DerivedField in the project jpmml-sparkml by jpmml.

In the class PMMLBuilder, method build.

/**
 * Converts the configured pipeline model into a PMML document.
 *
 * <p>
 * The pipeline's transformers are visited in order: feature converters register their features with the encoder,
 * model converters register a model. Derived fields that appear after the last model are moved into the output of the final model.
 * If verification data is configured, and at least one prediction or probability column was collected,
 * a model verification element is attached to the model.
 * </p>
 *
 * @return The encoded PMML document.
 * @throws IllegalArgumentException If a transformer's converter is neither a feature converter nor a model converter.
 */
public PMML build() {
    StructType schema = getSchema();
    PipelineModel pipelineModel = getPipelineModel();
    Map<RegexKey, ? extends Map<String, ?>> options = getOptions();
    Verification verification = getVerification();
    ConverterFactory converterFactory = new ConverterFactory(options);
    SparkMLEncoder encoder = new SparkMLEncoder(schema, converterFactory);
    // NOTE(review): presumably a live view of the encoder's derived fields - the key set snapshots taken below
    // only make sense if this map reflects fields registered later; confirm against SparkMLEncoder
    Map<FieldName, DerivedField> derivedFields = encoder.getDerivedFields();
    List<org.dmg.pmml.Model> models = new ArrayList<>();
    List<String> predictionColumns = new ArrayList<>();
    List<String> probabilityColumns = new ArrayList<>();
    // Transformations preceding the last model
    List<FieldName> preProcessorNames = Collections.emptyList();
    Iterable<Transformer> transformers = getTransformers(pipelineModel);
    for (Transformer transformer : transformers) {
        TransformerConverter<?> converter = converterFactory.newConverter(transformer);
        if (converter instanceof FeatureConverter) {
            FeatureConverter<?> featureConverter = (FeatureConverter<?>) converter;
            featureConverter.registerFeatures(encoder);
        } else if (converter instanceof ModelConverter) {
            ModelConverter<?> modelConverter = (ModelConverter<?>) converter;
            org.dmg.pmml.Model model = modelConverter.registerModel(encoder);
            models.add(model);
            // Labeled block: "break featureImportances" skips the remainder when the option is disabled
            featureImportances: if (modelConverter instanceof HasFeatureImportances) {
                HasFeatureImportances hasFeatureImportances = (HasFeatureImportances) modelConverter;
                Boolean estimateFeatureImportances = (Boolean) modelConverter.getOption(HasTreeOptions.OPTION_ESTIMATE_FEATURE_IMPORTANCES, Boolean.FALSE);
                if (!estimateFeatureImportances) {
                    break featureImportances;
                }
                List<Double> featureImportances = VectorUtil.toList(hasFeatureImportances.getFeatureImportances());
                List<Feature> features = modelConverter.getFeatures(encoder);
                SchemaUtil.checkSize(featureImportances.size(), features);
                // Importances and features are paired up by position
                for (int i = 0; i < featureImportances.size(); i++) {
                    Double featureImportance = featureImportances.get(i);
                    Feature feature = features.get(i);
                    encoder.addFeatureImportance(model, feature, featureImportance);
                }
            }
            // Labeled block: "break hasPredictionCol" skips recording the prediction column
            hasPredictionCol: if (transformer instanceof HasPredictionCol) {
                HasPredictionCol hasPredictionCol = (HasPredictionCol) transformer;
                // XXX: a generalized linear regression model that was encoded as a classification model
                // does not contribute its prediction column
                if ((transformer instanceof GeneralizedLinearRegressionModel) && (MiningFunction.CLASSIFICATION).equals(model.getMiningFunction())) {
                    break hasPredictionCol;
                }
                predictionColumns.add(hasPredictionCol.getPredictionCol());
            }
            if (transformer instanceof HasProbabilityCol) {
                HasProbabilityCol hasProbabilityCol = (HasProbabilityCol) transformer;
                probabilityColumns.add(hasProbabilityCol.getProbabilityCol());
            }
            // Snapshot of the derived field names known right after this model;
            // overwritten on every model, so it ends up describing the state after the last one
            preProcessorNames = new ArrayList<>(derivedFields.keySet());
        } else {
            throw new IllegalArgumentException("Expected a subclass of " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + ", got " + (converter != null ? ("class " + (converter.getClass()).getName()) : null));
        }
    }
    // Transformations following the last model
    List<FieldName> postProcessorNames = new ArrayList<>(derivedFields.keySet());
    postProcessorNames.removeAll(preProcessorNames);
    // Zero models: no model; one model: that model; several models: a model chain
    org.dmg.pmml.Model model;
    if (models.size() == 0) {
        model = null;
    } else if (models.size() == 1) {
        model = Iterables.getOnlyElement(models);
    } else {
        model = MiningModelUtil.createModelChain(models);
    }
    if ((model != null) && (postProcessorNames.size() > 0)) {
        org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(model);
        Output output = ModelUtil.ensureOutput(finalModel);
        // Each post-processing derived field is removed from the encoder
        // and re-expressed as a "transformed value" output field of the final model
        for (FieldName postProcessorName : postProcessorNames) {
            DerivedField derivedField = derivedFields.get(postProcessorName);
            encoder.removeDerivedField(postProcessorName);
            OutputField outputField = new OutputField(derivedField.getName(), derivedField.getOpType(), derivedField.getDataType()).setResultFeature(ResultFeature.TRANSFORMED_VALUE).setExpression(derivedField.getExpression());
            output.addOutputFields(outputField);
        }
    }
    PMML pmml = encoder.encodePMML(model);
    // Verification data is attached only if there is a model, something to verify, and a verification config
    if ((model != null) && (predictionColumns.size() > 0 || probabilityColumns.size() > 0) && (verification != null)) {
        Dataset<Row> dataset = verification.getDataset();
        Dataset<Row> transformedDataset = verification.getTransformedDataset();
        Double precision = verification.getPrecision();
        Double zeroThreshold = verification.getZeroThreshold();
        // The input columns are the names of the model's active mining fields
        List<String> inputColumns = new ArrayList<>();
        MiningSchema miningSchema = model.getMiningSchema();
        List<MiningField> miningFields = miningSchema.getMiningFields();
        for (MiningField miningField : miningFields) {
            MiningField.UsageType usageType = miningField.getUsageType();
            switch(usageType) {
                case ACTIVE:
                    FieldName name = miningField.getName();
                    inputColumns.add(name.getValue());
                    break;
                default:
                    break;
            }
        }
        // Insertion-ordered: input columns first, then prediction columns, then probability columns
        Map<VerificationField, List<?>> data = new LinkedHashMap<>();
        for (String inputColumn : inputColumns) {
            VerificationField verificationField = ModelUtil.createVerificationField(FieldName.create(inputColumn));
            data.put(verificationField, getColumn(dataset, inputColumn));
        }
        // Inputs are read from the original dataset; expected outputs from the transformed dataset
        for (String predictionColumn : predictionColumns) {
            Feature feature = encoder.getOnlyFeature(predictionColumn);
            VerificationField verificationField = ModelUtil.createVerificationField(feature.getName()).setPrecision(precision).setZeroThreshold(zeroThreshold);
            data.put(verificationField, getColumn(transformedDataset, predictionColumn));
        }
        for (String probabilityColumn : probabilityColumns) {
            List<Feature> features = encoder.getFeatures(probabilityColumn);
            // One verification field per element of the probability vector column
            for (int i = 0; i < features.size(); i++) {
                Feature feature = features.get(i);
                VerificationField verificationField = ModelUtil.createVerificationField(feature.getName()).setPrecision(precision).setZeroThreshold(zeroThreshold);
                data.put(verificationField, getVectorColumn(transformedDataset, probabilityColumn, i));
            }
        }
        model.setModelVerification(ModelUtil.createModelVerification(data));
    }
    return pmml;
}
Also used : MiningField(org.dmg.pmml.MiningField) StructType(org.apache.spark.sql.types.StructType) HasProbabilityCol(org.apache.spark.ml.param.shared.HasProbabilityCol) GeneralizedLinearRegressionModel(org.apache.spark.ml.regression.GeneralizedLinearRegressionModel) ArrayList(java.util.ArrayList) ResultFeature(org.dmg.pmml.ResultFeature) Feature(org.jpmml.converter.Feature) LinkedHashMap(java.util.LinkedHashMap) HasFeatureImportances(org.jpmml.sparkml.model.HasFeatureImportances) ArrayList(java.util.ArrayList) List(java.util.List) HasPredictionCol(org.apache.spark.ml.param.shared.HasPredictionCol) MiningSchema(org.dmg.pmml.MiningSchema) OutputField(org.dmg.pmml.OutputField) Row(org.apache.spark.sql.Row) Transformer(org.apache.spark.ml.Transformer) PipelineModel(org.apache.spark.ml.PipelineModel) Output(org.dmg.pmml.Output) FieldName(org.dmg.pmml.FieldName) VerificationField(org.dmg.pmml.VerificationField) GeneralizedLinearRegressionModel(org.apache.spark.ml.regression.GeneralizedLinearRegressionModel) PipelineModel(org.apache.spark.ml.PipelineModel) TrainValidationSplitModel(org.apache.spark.ml.tuning.TrainValidationSplitModel) CrossValidatorModel(org.apache.spark.ml.tuning.CrossValidatorModel) PMML(org.dmg.pmml.PMML) DerivedField(org.dmg.pmml.DerivedField)

Example 48 with DerivedField

Use of org.dmg.pmml.DerivedField in the project jpmml-sparkml by jpmml.

In the class PCAModelConverter, method encodeFeatures.

/**
 * Encodes every principal component as a derived field that sums the input features,
 * each multiplied by its coefficient from the principal components matrix.
 *
 * @param encoder The encoder that supplies the input features and receives the derived fields.
 * @return One continuous feature per principal component.
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    PCAModel pcaModel = getTransformer();
    DenseMatrix components = pcaModel.pc();
    List<Feature> inputFeatures = encoder.getFeatures(pcaModel.getInputCol());
    MatrixUtil.checkRows(inputFeatures.size(), components);
    List<Feature> outputFeatures = new ArrayList<>();
    int numComponents = pcaModel.getK();
    for (int column = 0; column < numComponents; column++) {
        Apply sum = PMMLUtil.createApply(PMMLFunctions.SUM);
        int row = 0;
        for (Feature inputFeature : inputFeatures) {
            Expression term = (inputFeature.toContinuousFeature()).ref();
            Double weight = components.apply(row, column);
            // A coefficient of one leaves the plain field reference as the term
            if (!ValueUtil.isOne(weight)) {
                term = PMMLUtil.createApply(PMMLFunctions.MULTIPLY, term, PMMLUtil.createConstant(weight));
            }
            sum.addExpressions(term);
            row++;
        }
        DerivedField componentField = encoder.createDerivedField(formatName(pcaModel, column, numComponents), OpType.CONTINUOUS, DataType.DOUBLE, sum);
        outputFeatures.add(new ContinuousFeature(encoder, componentField));
    }
    return outputFeatures;
}
Also used : PCAModel(org.apache.spark.ml.feature.PCAModel) ContinuousFeature(org.jpmml.converter.ContinuousFeature) Expression(org.dmg.pmml.Expression) Apply(org.dmg.pmml.Apply) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DerivedField(org.dmg.pmml.DerivedField) DenseMatrix(org.apache.spark.ml.linalg.DenseMatrix)

Example 49 with DerivedField

Use of org.dmg.pmml.DerivedField in the project jpmml-sparkml by jpmml.

In the class StringIndexerModelConverter, method encodeFeatures.

/**
 * Encodes every input column of the string indexer as a categorical feature over the fitted labels.
 *
 * <p>
 * The "keep" strategy adds an extra category for invalid values ("-999" for integer, float and double columns, "__unknown" otherwise);
 * the "error" strategy makes data fields return invalid, and is a no-op for derived fields.
 * Any other strategy is rejected.
 * </p>
 *
 * @param encoder The encoder that supplies the input features and receives decorators and derived fields.
 * @return One categorical feature per input column.
 * @throws IllegalArgumentException If the invalid value handling strategy is not supported,
 * or if the categorical field is neither a data field nor a derived field.
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    StringIndexerModel transformer = getTransformer();
    String[][] labelsArray = transformer.labelsArray();
    InOutMode inputMode = getInputMode();
    List<Feature> result = new ArrayList<>();
    String[] inputCols = inputMode.getInputCols(transformer);
    for (int i = 0; i < inputCols.length; i++) {
        String inputCol = inputCols[i];
        String[] labels = labelsArray[i];
        Feature feature = encoder.getOnlyFeature(inputCol);
        // Mutable copy, because the "keep" strategy appends the invalid category
        List<String> categories = new ArrayList<>(Arrays.asList(labels));
        String invalidCategory;
        DataType dataType = feature.getDataType();
        switch (dataType) {
            case INTEGER:
            case FLOAT:
            case DOUBLE:
                invalidCategory = "-999";
                break;
            default:
                invalidCategory = "__unknown";
                break;
        }
        String handleInvalid = transformer.getHandleInvalid();
        Field<?> field = encoder.toCategorical(feature.getName(), categories);
        if (field instanceof DataField) {
            DataField dataField = (DataField) field;
            InvalidValueDecorator invalidValueDecorator;
            switch (handleInvalid) {
                case "keep":
                    {
                        invalidValueDecorator = new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_VALUE, invalidCategory);
                        categories.add(invalidCategory);
                    }
                    break;
                case "error":
                    {
                        invalidValueDecorator = new InvalidValueDecorator(InvalidValueTreatmentMethod.RETURN_INVALID, null);
                    }
                    break;
                default:
                    throw new IllegalArgumentException("Invalid value handling strategy " + handleInvalid + " is not supported");
            }
            encoder.addDecorator(dataField, invalidValueDecorator);
        } else if (field instanceof DerivedField) {
            switch (handleInvalid) {
                case "keep":
                    {
                        // if(isIn(x, known categories...), x, invalidCategory)
                        Apply setApply = PMMLUtil.createApply(PMMLFunctions.ISIN, feature.ref());
                        for (String category : categories) {
                            setApply.addExpressions(PMMLUtil.createConstant(category, dataType));
                        }
                        // Appended only after the membership test has been built from the known categories
                        categories.add(invalidCategory);
                        Apply apply = PMMLUtil.createApply(PMMLFunctions.IF).addExpressions(setApply).addExpressions(feature.ref(), PMMLUtil.createConstant(invalidCategory, dataType));
                        field = encoder.createDerivedField(FieldNameUtil.create("handleInvalid", feature), OpType.CATEGORICAL, dataType, apply);
                    }
                    break;
                case "error":
                    {
                    // Ignored: Assume that a DerivedField element can never return an erroneous field value
                    }
                    break;
                default:
                    // Same wording as the data field branch above
                    throw new IllegalArgumentException("Invalid value handling strategy " + handleInvalid + " is not supported");
            }
        } else {
            throw new IllegalArgumentException("Expected " + DataField.class.getName() + " or " + DerivedField.class.getName() + ", got " + (field != null ? ("class " + (field.getClass()).getName()) : null));
        }
        result.add(new CategoricalFeature(encoder, field, categories));
    }
    return result;
}
Also used : Apply(org.dmg.pmml.Apply) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) CategoricalFeature(org.jpmml.converter.CategoricalFeature) StringIndexerModel(org.apache.spark.ml.feature.StringIndexerModel) CategoricalFeature(org.jpmml.converter.CategoricalFeature) InvalidValueDecorator(org.jpmml.converter.InvalidValueDecorator) DataField(org.dmg.pmml.DataField) DataType(org.dmg.pmml.DataType) DerivedField(org.dmg.pmml.DerivedField)

Example 50 with DerivedField

use of org.dmg.pmml.DerivedField in project jpmml-sparkml by jpmml.

the class MaxAbsScalerModelConverter method encodeFeatures.

/**
 * Encodes the max-abs scaling step by dividing each input feature by the
 * matching entry of the transformer's {@code maxAbs} vector.
 *
 * <p>An entry equal to zero is replaced by one before use. When the resulting
 * divisor satisfies {@code ValueUtil.isOne}, the input feature is passed
 * through untouched; otherwise a new continuous, double-typed derived field
 * holding the division expression is created via the encoder.
 *
 * @param encoder the encoder that supplies the input features and registers derived fields
 * @return one feature per input feature, in the same order
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    MaxAbsScalerModel model = getTransformer();
    Vector scales = model.maxAbs();
    List<Feature> inputFeatures = encoder.getFeatures(model.getInputCol());
    // The scale vector must have the same number of entries as the input column.
    SchemaUtil.checkSize(scales.size(), inputFeatures);
    List<Feature> scaledFeatures = new ArrayList<>();
    final int count = inputFeatures.size();
    for (int index = 0; index < count; index++) {
        Feature current = inputFeatures.get(index);
        double rawScale = scales.apply(index);
        // A zero scale is swapped for one so the division below is never by zero.
        double divisor = (rawScale == 0d) ? 1d : rawScale;
        if (ValueUtil.isOne(divisor)) {
            // Dividing by one changes nothing; reuse the input feature as is.
            scaledFeatures.add(current);
            continue;
        }
        ContinuousFeature numeric = current.toContinuousFeature();
        Expression quotient = PMMLUtil.createApply(PMMLFunctions.DIVIDE, numeric.ref(), PMMLUtil.createConstant(divisor));
        DerivedField scaledField = encoder.createDerivedField(formatName(model, index, count), OpType.CONTINUOUS, DataType.DOUBLE, quotient);
        scaledFeatures.add(new ContinuousFeature(encoder, scaledField));
    }
    return scaledFeatures;
}
Also used : ContinuousFeature(org.jpmml.converter.ContinuousFeature) Expression(org.dmg.pmml.Expression) MaxAbsScalerModel(org.apache.spark.ml.feature.MaxAbsScalerModel) ArrayList(java.util.ArrayList) Vector(org.apache.spark.ml.linalg.Vector) Feature(org.jpmml.converter.Feature) ContinuousFeature(org.jpmml.converter.ContinuousFeature) DerivedField(org.dmg.pmml.DerivedField)

Aggregations

DerivedField (org.dmg.pmml.DerivedField) 54, ArrayList (java.util.ArrayList) 21, Feature (org.jpmml.converter.Feature) 18, ContinuousFeature (org.jpmml.converter.ContinuousFeature) 17, FieldName (org.dmg.pmml.FieldName) 15, Apply (org.dmg.pmml.Apply) 11, Expression (org.dmg.pmml.Expression) 10, DataField (org.dmg.pmml.DataField) 8, Test (org.junit.Test) 8, KiePMMLDerivedField (org.kie.pmml.commons.transformations.KiePMMLDerivedField) 8, List (java.util.List) 7, Constant (org.dmg.pmml.Constant) 7, FieldRef (org.dmg.pmml.FieldRef) 6, MapValues (org.dmg.pmml.MapValues) 6, NormContinuous (org.dmg.pmml.NormContinuous) 6, BlockStmt (com.github.javaparser.ast.stmt.BlockStmt) 5, PMML (org.dmg.pmml.PMML) 5, DataType (org.dmg.pmml.DataType) 4, Discretize (org.dmg.pmml.Discretize) 4, Statement (com.github.javaparser.ast.stmt.Statement) 3