Search in sources:

Example 1 with BinaryComparison

Use of `org.apache.spark.sql.catalyst.expressions.BinaryComparison` in project jpmml-sparkml by jpmml.

Found in class `ExpressionTranslator`, method `translateInternal`:

/**
 * Recursively translates a Spark SQL Catalyst expression into an equivalent PMML expression.
 *
 * <p>Each supported Catalyst node is mapped onto a PMML {@code Apply} (built-in function call),
 * a {@code FieldRef}, a constant created via {@code PMMLUtil.createConstant}, or an
 * {@code AliasExpression}. Child expressions are translated by recursive calls to this method.
 *
 * <p>A {@link Cast} is handled in one of two ways. If the translated child implements
 * {@code HasDataType}, the cast is applied in place by overriding its data type. Otherwise a new
 * derived field of the target type is created through the encoder, and a {@code FieldRef} to it
 * is returned.
 *
 * @param expression the Catalyst expression to translate
 * @return the translated PMML expression
 * @throws IllegalArgumentException if the expression, or one of its sub-expressions, has an
 *     unsupported type, uses an unsupported operator symbol, or has arguments that cannot be
 *     encoded (for example a non-constant or relative substring start position)
 */
private org.dmg.pmml.Expression translateInternal(Expression expression) {
    SparkMLEncoder encoder = getEncoder();
    // Aliases are kept as AliasExpression wrappers so that an enclosing Cast can reuse the
    // alias name for the derived field it creates (see the Cast branch below)
    if (expression instanceof Alias) {
        Alias alias = (Alias) expression;
        String name = alias.name();
        Expression child = alias.child();
        org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
        return new AliasExpression(name, pmmlExpression);
    }
    if (expression instanceof AttributeReference) {
        AttributeReference attributeReference = (AttributeReference) expression;
        String name = attributeReference.name();
        return new FieldRef(FieldName.create(name));
    } else if (expression instanceof BinaryMathExpression) {
        BinaryMathExpression binaryMathExpression = (BinaryMathExpression) expression;
        Expression left = binaryMathExpression.left();
        Expression right = binaryMathExpression.right();
        String function;
        if (binaryMathExpression instanceof Hypot) {
            function = PMMLFunctions.HYPOT;
        } else if (binaryMathExpression instanceof Pow) {
            function = PMMLFunctions.POW;
        } else {
            throw new IllegalArgumentException(formatMessage(binaryMathExpression));
        }
        return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
    } else if (expression instanceof BinaryOperator) {
        BinaryOperator binaryOperator = (BinaryOperator) expression;
        String symbol = binaryOperator.symbol();
        Expression left = binaryOperator.left();
        Expression right = binaryOperator.right();
        // Operators are matched on their concrete Catalyst class first and on their SQL symbol
        // second; a class/symbol combination not listed here is rejected as unsupported
        String function;
        if (expression instanceof And || expression instanceof Or) {
            switch(symbol) {
                case "&&":
                    function = PMMLFunctions.AND;
                    break;
                case "||":
                    function = PMMLFunctions.OR;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryOperator));
            }
        } else if (expression instanceof Add || expression instanceof Divide || expression instanceof Multiply || expression instanceof Subtract) {
            BinaryArithmetic binaryArithmetic = (BinaryArithmetic) binaryOperator;
            switch(symbol) {
                case "+":
                    function = PMMLFunctions.ADD;
                    break;
                case "/":
                    function = PMMLFunctions.DIVIDE;
                    break;
                case "*":
                    function = PMMLFunctions.MULTIPLY;
                    break;
                case "-":
                    function = PMMLFunctions.SUBTRACT;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryArithmetic));
            }
        } else if (expression instanceof EqualTo || expression instanceof GreaterThan || expression instanceof GreaterThanOrEqual || expression instanceof LessThan || expression instanceof LessThanOrEqual) {
            BinaryComparison binaryComparison = (BinaryComparison) binaryOperator;
            switch(symbol) {
                case "=":
                    function = PMMLFunctions.EQUAL;
                    break;
                case ">":
                    function = PMMLFunctions.GREATERTHAN;
                    break;
                case ">=":
                    function = PMMLFunctions.GREATEROREQUAL;
                    break;
                case "<":
                    function = PMMLFunctions.LESSTHAN;
                    break;
                case "<=":
                    function = PMMLFunctions.LESSOREQUAL;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryComparison));
            }
        } else {
            throw new IllegalArgumentException(formatMessage(binaryOperator));
        }
        return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
    } else if (expression instanceof CaseWhen) {
        CaseWhen caseWhen = (CaseWhen) expression;
        List<Tuple2<Expression, Expression>> branches = JavaConversions.seqAsJavaList(caseWhen.branches());
        Option<Expression> elseValue = caseWhen.elseValue();
        // Without at least one WHEN branch there is no "if" application to build on; reject it
        // explicitly instead of failing with a NoSuchElementException
        if (branches.isEmpty()) {
            throw new IllegalArgumentException(formatMessage(caseWhen));
        }
        // "CASE WHEN p1 THEN v1 WHEN p2 THEN v2 ... ELSE e END" becomes the nested chain
        // if(p1, v1, if(p2, v2, ... e)): every branch after the first is appended as the third
        // ("else") argument of the preceding branch's if-application
        Apply apply = null;
        Apply prevBranchApply = null;
        for (Tuple2<Expression, Expression> branch : branches) {
            Expression predicate = branch._1();
            Expression value = branch._2();
            Apply branchApply = PMMLUtil.createApply(PMMLFunctions.IF, translateInternal(predicate), translateInternal(value));
            if (apply == null) {
                apply = branchApply;
            }
            if (prevBranchApply != null) {
                prevBranchApply.addExpressions(branchApply);
            }
            prevBranchApply = branchApply;
        }
        // The ELSE value, when present, becomes the "else" argument of the innermost if
        if (elseValue.isDefined()) {
            Expression value = elseValue.get();
            prevBranchApply.addExpressions(translateInternal(value));
        }
        return apply;
    } else if (expression instanceof Cast) {
        Cast cast = (Cast) expression;
        Expression child = cast.child();
        org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
        DataType dataType = DatasetUtil.translateDataType(cast.dataType());
        if (pmmlExpression instanceof HasDataType) {
            // The translated child can carry a data type itself, so cast it in place
            HasDataType<?> hasDataType = (HasDataType<?>) pmmlExpression;
            hasDataType.setDataType(dataType);
            return pmmlExpression;
        } else {
            // Otherwise materialize the cast as a derived field of the target type, named after
            // the alias when there is one, or after the target type and the child's formatted text
            FieldName name;
            if (pmmlExpression instanceof AliasExpression) {
                AliasExpression aliasExpression = (AliasExpression) pmmlExpression;
                name = FieldName.create(aliasExpression.getName());
            } else {
                name = FieldNameUtil.create(dataType, ExpressionUtil.format(child));
            }
            OpType opType = ExpressionUtil.getOpType(dataType);
            pmmlExpression = AliasExpression.unwrap(pmmlExpression);
            DerivedField derivedField = encoder.createDerivedField(name, opType, dataType, pmmlExpression);
            return new FieldRef(derivedField.getName());
        }
    } else if (expression instanceof Concat) {
        Concat concat = (Concat) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(concat.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.CONCAT);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof Greatest) {
        Greatest greatest = (Greatest) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(greatest.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.MAX);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof If) {
        If _if = (If) expression;
        Expression predicate = _if.predicate();
        Expression trueValue = _if.trueValue();
        Expression falseValue = _if.falseValue();
        return PMMLUtil.createApply(PMMLFunctions.IF, translateInternal(predicate), translateInternal(trueValue), translateInternal(falseValue));
    } else if (expression instanceof In) {
        In in = (In) expression;
        Expression value = in.value();
        List<Expression> elements = JavaConversions.seqAsJavaList(in.list());
        // The tested value is the first argument, followed by every candidate element
        Apply apply = PMMLUtil.createApply(PMMLFunctions.ISIN, translateInternal(value));
        for (Expression element : elements) {
            apply.addExpressions(translateInternal(element));
        }
        return apply;
    } else if (expression instanceof Least) {
        Least least = (Least) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(least.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.MIN);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof Length) {
        Length length = (Length) expression;
        Expression child = length.child();
        return PMMLUtil.createApply(PMMLFunctions.STRINGLENGTH, translateInternal(child));
    } else if (expression instanceof Literal) {
        Literal literal = (Literal) expression;
        Object value = literal.value();
        // A null literal maps to PMML's missing-value constant
        if (value == null) {
            return PMMLUtil.createMissingConstant();
        }
        DataType dataType;
        // XXX: Decimal literals are emitted as STRING constants holding their textual form.
        // NOTE(review): presumably a workaround for the lack of a direct decimal mapping - confirm
        if (value instanceof Decimal) {
            Decimal decimal = (Decimal) value;
            dataType = DataType.STRING;
            value = decimal.toString();
        } else {
            dataType = DatasetUtil.translateDataType(literal.dataType());
            value = toSimpleObject(value);
        }
        return PMMLUtil.createConstant(value, dataType);
    } else if (expression instanceof RegExpReplace) {
        RegExpReplace regexpReplace = (RegExpReplace) expression;
        Expression subject = regexpReplace.subject();
        Expression regexp = regexpReplace.regexp();
        Expression rep = regexpReplace.rep();
        return PMMLUtil.createApply(PMMLFunctions.REPLACE, translateInternal(subject), translateInternal(regexp), translateInternal(rep));
    } else if (expression instanceof RLike) {
        RLike rlike = (RLike) expression;
        Expression left = rlike.left();
        Expression right = rlike.right();
        return PMMLUtil.createApply(PMMLFunctions.MATCHES, translateInternal(left), translateInternal(right));
    } else if (expression instanceof StringTrim) {
        StringTrim stringTrim = (StringTrim) expression;
        Expression srcStr = stringTrim.srcStr();
        Option<Expression> trimStr = stringTrim.trimStr();
        // Only the default (blank-trimming) form is supported; an explicit trim-character
        // argument cannot be expressed with trimBlanks
        if (trimStr.isDefined()) {
            throw new IllegalArgumentException(formatMessage(stringTrim));
        }
        return PMMLUtil.createApply(PMMLFunctions.TRIMBLANKS, translateInternal(srcStr));
    } else if (expression instanceof Substring) {
        Substring substring = (Substring) expression;
        Expression str = substring.str();
        // Start position and length are encoded as PMML constants, so both must be
        // non-null numeric literals; anything else is rejected as unsupported
        if (!isNumberLiteral(substring.pos()) || !isNumberLiteral(substring.len())) {
            throw new IllegalArgumentException(formatMessage(substring));
        }
        Literal pos = (Literal) substring.pos();
        Literal len = (Literal) substring.len();
        int posValue = ValueUtil.asInt((Number) pos.value());
        // Zero and negative start positions are not supported
        if (posValue <= 0) {
            throw new IllegalArgumentException("Expected absolute start position, got relative start position " + (pos));
        }
        int lenValue = ValueUtil.asInt((Number) len.value());
        // XXX: the length is clamped to MAX_STRING_LENGTH.
        // NOTE(review): presumably because the two-argument form of Spark's substring uses a
        // very large default length - confirm
        lenValue = Math.min(lenValue, MAX_STRING_LENGTH);
        return PMMLUtil.createApply(PMMLFunctions.SUBSTRING, translateInternal(str), PMMLUtil.createConstant(posValue), PMMLUtil.createConstant(lenValue));
    } else if (expression instanceof UnaryExpression) {
        UnaryExpression unaryExpression = (UnaryExpression) expression;
        Expression child = unaryExpression.child();
        if (expression instanceof Abs) {
            return PMMLUtil.createApply(PMMLFunctions.ABS, translateInternal(child));
        } else if (expression instanceof Acos) {
            return PMMLUtil.createApply(PMMLFunctions.ACOS, translateInternal(child));
        } else if (expression instanceof Asin) {
            return PMMLUtil.createApply(PMMLFunctions.ASIN, translateInternal(child));
        } else if (expression instanceof Atan) {
            return PMMLUtil.createApply(PMMLFunctions.ATAN, translateInternal(child));
        } else if (expression instanceof Ceil) {
            return PMMLUtil.createApply(PMMLFunctions.CEIL, translateInternal(child));
        } else if (expression instanceof Cos) {
            return PMMLUtil.createApply(PMMLFunctions.COS, translateInternal(child));
        } else if (expression instanceof Cosh) {
            return PMMLUtil.createApply(PMMLFunctions.COSH, translateInternal(child));
        } else if (expression instanceof Exp) {
            return PMMLUtil.createApply(PMMLFunctions.EXP, translateInternal(child));
        } else if (expression instanceof Expm1) {
            return PMMLUtil.createApply(PMMLFunctions.EXPM1, translateInternal(child));
        } else if (expression instanceof Floor) {
            return PMMLUtil.createApply(PMMLFunctions.FLOOR, translateInternal(child));
        } else if (expression instanceof Log) {
            // Spark's single-argument log is the natural logarithm
            return PMMLUtil.createApply(PMMLFunctions.LN, translateInternal(child));
        } else if (expression instanceof Log10) {
            return PMMLUtil.createApply(PMMLFunctions.LOG10, translateInternal(child));
        } else if (expression instanceof Log1p) {
            return PMMLUtil.createApply(PMMLFunctions.LN1P, translateInternal(child));
        } else if (expression instanceof Lower) {
            return PMMLUtil.createApply(PMMLFunctions.LOWERCASE, translateInternal(child));
        } else if (expression instanceof IsNaN) {
            // XXX: IsNaN is mapped onto isNotValid.
            // NOTE(review): an approximation whose semantics may differ from Spark's - confirm
            return PMMLUtil.createApply(PMMLFunctions.ISNOTVALID, translateInternal(child));
        } else if (expression instanceof IsNotNull) {
            return PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, translateInternal(child));
        } else if (expression instanceof IsNull) {
            return PMMLUtil.createApply(PMMLFunctions.ISMISSING, translateInternal(child));
        } else if (expression instanceof Not) {
            return PMMLUtil.createApply(PMMLFunctions.NOT, translateInternal(child));
        } else if (expression instanceof Rint) {
            return PMMLUtil.createApply(PMMLFunctions.RINT, translateInternal(child));
        } else if (expression instanceof Sin) {
            return PMMLUtil.createApply(PMMLFunctions.SIN, translateInternal(child));
        } else if (expression instanceof Sinh) {
            return PMMLUtil.createApply(PMMLFunctions.SINH, translateInternal(child));
        } else if (expression instanceof Sqrt) {
            return PMMLUtil.createApply(PMMLFunctions.SQRT, translateInternal(child));
        } else if (expression instanceof Tan) {
            return PMMLUtil.createApply(PMMLFunctions.TAN, translateInternal(child));
        } else if (expression instanceof Tanh) {
            return PMMLUtil.createApply(PMMLFunctions.TANH, translateInternal(child));
        } else if (expression instanceof UnaryMinus) {
            return PMMLUtil.toNegative(translateInternal(child));
        } else if (expression instanceof UnaryPositive) {
            // Unary plus is an identity operation; translate the operand directly
            return translateInternal(child);
        } else if (expression instanceof Upper) {
            return PMMLUtil.createApply(PMMLFunctions.UPPERCASE, translateInternal(child));
        } else {
            throw new IllegalArgumentException(formatMessage(unaryExpression));
        }
    } else {
        throw new IllegalArgumentException(formatMessage(expression));
    }
}

/**
 * Checks whether the given Catalyst expression is a literal holding a non-null numeric value.
 *
 * @param expression the expression to inspect
 * @return {@code true} if {@code expression} is a {@link Literal} whose value is a {@link Number}
 */
private static boolean isNumberLiteral(Expression expression) {
    return (expression instanceof Literal) && (((Literal) expression).value() instanceof Number);
}
Also used: Add(org.apache.spark.sql.catalyst.expressions.Add) Tan(org.apache.spark.sql.catalyst.expressions.Tan) RegExpReplace(org.apache.spark.sql.catalyst.expressions.RegExpReplace) Lower(org.apache.spark.sql.catalyst.expressions.Lower) Or(org.apache.spark.sql.catalyst.expressions.Or) Apply(org.dmg.pmml.Apply) CaseWhen(org.apache.spark.sql.catalyst.expressions.CaseWhen) Concat(org.apache.spark.sql.catalyst.expressions.Concat) Divide(org.apache.spark.sql.catalyst.expressions.Divide) LessThan(org.apache.spark.sql.catalyst.expressions.LessThan) IsNotNull(org.apache.spark.sql.catalyst.expressions.IsNotNull) Decimal(org.apache.spark.sql.types.Decimal) GreaterThan(org.apache.spark.sql.catalyst.expressions.GreaterThan) Multiply(org.apache.spark.sql.catalyst.expressions.Multiply) Least(org.apache.spark.sql.catalyst.expressions.Least) Literal(org.apache.spark.sql.catalyst.expressions.Literal) Sinh(org.apache.spark.sql.catalyst.expressions.Sinh) Greatest(org.apache.spark.sql.catalyst.expressions.Greatest) List(java.util.List) BinaryOperator(org.apache.spark.sql.catalyst.expressions.BinaryOperator) BinaryArithmetic(org.apache.spark.sql.catalyst.expressions.BinaryArithmetic) BinaryMathExpression(org.apache.spark.sql.catalyst.expressions.BinaryMathExpression) Substring(org.apache.spark.sql.catalyst.expressions.Substring) Tanh(org.apache.spark.sql.catalyst.expressions.Tanh) Log10(org.apache.spark.sql.catalyst.expressions.Log10) FieldRef(org.dmg.pmml.FieldRef) Expm1(org.apache.spark.sql.catalyst.expressions.Expm1) Log(org.apache.spark.sql.catalyst.expressions.Log) AttributeReference(org.apache.spark.sql.catalyst.expressions.AttributeReference) UnaryMinus(org.apache.spark.sql.catalyst.expressions.UnaryMinus) Hypot(org.apache.spark.sql.catalyst.expressions.Hypot) RLike(org.apache.spark.sql.catalyst.expressions.RLike) EqualTo(org.apache.spark.sql.catalyst.expressions.EqualTo) Not(org.apache.spark.sql.catalyst.expressions.Not) And(org.apache.spark.sql.catalyst.expressions.And) 
BinaryComparison(org.apache.spark.sql.catalyst.expressions.BinaryComparison) Pow(org.apache.spark.sql.catalyst.expressions.Pow) Subtract(org.apache.spark.sql.catalyst.expressions.Subtract) Acos(org.apache.spark.sql.catalyst.expressions.Acos) Sin(org.apache.spark.sql.catalyst.expressions.Sin) Option(scala.Option) Ceil(org.apache.spark.sql.catalyst.expressions.Ceil) If(org.apache.spark.sql.catalyst.expressions.If) Cast(org.apache.spark.sql.catalyst.expressions.Cast) Cosh(org.apache.spark.sql.catalyst.expressions.Cosh) LessThanOrEqual(org.apache.spark.sql.catalyst.expressions.LessThanOrEqual) In(org.apache.spark.sql.catalyst.expressions.In) UnaryExpression(org.apache.spark.sql.catalyst.expressions.UnaryExpression) GreaterThanOrEqual(org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual) Abs(org.apache.spark.sql.catalyst.expressions.Abs) Iterator(java.util.Iterator) HasDataType(org.dmg.pmml.HasDataType) DataType(org.dmg.pmml.DataType) FieldName(org.dmg.pmml.FieldName) Atan(org.apache.spark.sql.catalyst.expressions.Atan) UnaryPositive(org.apache.spark.sql.catalyst.expressions.UnaryPositive) Upper(org.apache.spark.sql.catalyst.expressions.Upper) Floor(org.apache.spark.sql.catalyst.expressions.Floor) Log1p(org.apache.spark.sql.catalyst.expressions.Log1p) Cos(org.apache.spark.sql.catalyst.expressions.Cos) Sqrt(org.apache.spark.sql.catalyst.expressions.Sqrt) Asin(org.apache.spark.sql.catalyst.expressions.Asin) IsNaN(org.apache.spark.sql.catalyst.expressions.IsNaN) Expression(org.apache.spark.sql.catalyst.expressions.Expression) Length(org.apache.spark.sql.catalyst.expressions.Length) Alias(org.apache.spark.sql.catalyst.expressions.Alias) Tuple2(scala.Tuple2) OpType(org.dmg.pmml.OpType) IsNull(org.apache.spark.sql.catalyst.expressions.IsNull) Rint(org.apache.spark.sql.catalyst.expressions.Rint) 
Exp(org.apache.spark.sql.catalyst.expressions.Exp) DerivedField(org.dmg.pmml.DerivedField) StringTrim(org.apache.spark.sql.catalyst.expressions.StringTrim)

Aggregations

Iterator (java.util.Iterator): 1; List (java.util.List): 1; Abs (org.apache.spark.sql.catalyst.expressions.Abs): 1; Acos (org.apache.spark.sql.catalyst.expressions.Acos): 1; Add (org.apache.spark.sql.catalyst.expressions.Add): 1; Alias (org.apache.spark.sql.catalyst.expressions.Alias): 1; And (org.apache.spark.sql.catalyst.expressions.And): 1; Asin (org.apache.spark.sql.catalyst.expressions.Asin): 1; Atan (org.apache.spark.sql.catalyst.expressions.Atan): 1; AttributeReference (org.apache.spark.sql.catalyst.expressions.AttributeReference): 1; BinaryArithmetic (org.apache.spark.sql.catalyst.expressions.BinaryArithmetic): 1; BinaryComparison (org.apache.spark.sql.catalyst.expressions.BinaryComparison): 1; BinaryMathExpression (org.apache.spark.sql.catalyst.expressions.BinaryMathExpression): 1; BinaryOperator (org.apache.spark.sql.catalyst.expressions.BinaryOperator): 1; CaseWhen (org.apache.spark.sql.catalyst.expressions.CaseWhen): 1; Cast (org.apache.spark.sql.catalyst.expressions.Cast): 1; Ceil (org.apache.spark.sql.catalyst.expressions.Ceil): 1; Concat (org.apache.spark.sql.catalyst.expressions.Concat): 1; Cos (org.apache.spark.sql.catalyst.expressions.Cos): 1; Cosh (org.apache.spark.sql.catalyst.expressions.Cosh): 1