Use of org.apache.spark.sql.catalyst.expressions.CaseWhen in the project jpmml-sparkml by jpmml.
In class ExpressionTranslator, method translateInternal.
/**
 * Translates a Spark Catalyst expression into the corresponding PMML expression.
 *
 * <p>The input is dispatched by {@code instanceof} checks and sub-expressions are
 * translated recursively. Any expression type, operator symbol or argument shape that
 * has no mapping here is rejected with an {@link IllegalArgumentException}.
 *
 * <p>Side effect: when a {@link Cast} wraps a child whose translation does not implement
 * {@code HasDataType}, this method asks the encoder to create a {@code DerivedField}
 * and returns a {@link FieldRef} to that field.
 *
 * @param expression the Catalyst expression to translate
 * @return the PMML expression; an {@link AliasExpression} when the input is an {@link Alias}
 * @throws IllegalArgumentException if the expression, or any sub-expression, is not supported
 */
private org.dmg.pmml.Expression translateInternal(Expression expression) {
    SparkMLEncoder encoder = getEncoder();

    if (expression instanceof Alias) {
        Alias alias = (Alias) expression;
        String name = alias.name();
        Expression child = alias.child();
        org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
        return new AliasExpression(name, pmmlExpression);
    }

    if (expression instanceof AttributeReference) {
        AttributeReference attributeReference = (AttributeReference) expression;
        String name = attributeReference.name();
        return new FieldRef(FieldName.create(name));
    } else if (expression instanceof BinaryMathExpression) {
        BinaryMathExpression binaryMathExpression = (BinaryMathExpression) expression;
        Expression left = binaryMathExpression.left();
        Expression right = binaryMathExpression.right();
        String function;
        if (binaryMathExpression instanceof Hypot) {
            function = PMMLFunctions.HYPOT;
        } else if (binaryMathExpression instanceof Pow) {
            function = PMMLFunctions.POW;
        } else {
            throw new IllegalArgumentException(formatMessage(binaryMathExpression));
        }
        return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
    } else if (expression instanceof BinaryOperator) {
        BinaryOperator binaryOperator = (BinaryOperator) expression;
        String symbol = binaryOperator.symbol();
        Expression left = binaryOperator.left();
        Expression right = binaryOperator.right();
        String function;
        // The operator class is checked first, then its symbol is mapped to the PMML function name
        if (expression instanceof And || expression instanceof Or) {
            switch (symbol) {
                case "&&":
                    function = PMMLFunctions.AND;
                    break;
                case "||":
                    function = PMMLFunctions.OR;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryOperator));
            }
        } else if (expression instanceof Add || expression instanceof Divide || expression instanceof Multiply || expression instanceof Subtract) {
            BinaryArithmetic binaryArithmetic = (BinaryArithmetic) binaryOperator;
            switch (symbol) {
                case "+":
                    function = PMMLFunctions.ADD;
                    break;
                case "/":
                    function = PMMLFunctions.DIVIDE;
                    break;
                case "*":
                    function = PMMLFunctions.MULTIPLY;
                    break;
                case "-":
                    function = PMMLFunctions.SUBTRACT;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryArithmetic));
            }
        } else if (expression instanceof EqualTo || expression instanceof GreaterThan || expression instanceof GreaterThanOrEqual || expression instanceof LessThan || expression instanceof LessThanOrEqual) {
            BinaryComparison binaryComparison = (BinaryComparison) binaryOperator;
            switch (symbol) {
                case "=":
                    function = PMMLFunctions.EQUAL;
                    break;
                case ">":
                    function = PMMLFunctions.GREATERTHAN;
                    break;
                case ">=":
                    function = PMMLFunctions.GREATEROREQUAL;
                    break;
                case "<":
                    function = PMMLFunctions.LESSTHAN;
                    break;
                case "<=":
                    function = PMMLFunctions.LESSOREQUAL;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryComparison));
            }
        } else {
            throw new IllegalArgumentException(formatMessage(binaryOperator));
        }
        return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
    } else if (expression instanceof CaseWhen) {
        CaseWhen caseWhen = (CaseWhen) expression;
        List<Tuple2<Expression, Expression>> branches = JavaConversions.seqAsJavaList(caseWhen.branches());
        Option<Expression> elseValue = caseWhen.elseValue();

        // At least one WHEN branch is needed to form the outermost IF
        if (branches.isEmpty()) {
            throw new IllegalArgumentException(formatMessage(caseWhen));
        }

        // Build a right-nested chain: if(p1, v1, if(p2, v2, ...)).
        // Each branch's IF is appended as the third argument of the previous branch's IF.
        Apply apply = null;
        Apply prevBranchApply = null;
        for (Tuple2<Expression, Expression> branch : branches) {
            Expression predicate = branch._1();
            Expression value = branch._2();
            Apply branchApply = PMMLUtil.createApply(PMMLFunctions.IF, translateInternal(predicate), translateInternal(value));
            if (apply == null) {
                apply = branchApply;
            } else {
                prevBranchApply.addExpressions(branchApply);
            }
            prevBranchApply = branchApply;
        }

        // The ELSE value, when present, becomes the third argument of the innermost IF;
        // otherwise the innermost IF keeps only its condition and its true value
        if (elseValue.isDefined()) {
            Expression value = elseValue.get();
            prevBranchApply.addExpressions(translateInternal(value));
        }
        return apply;
    } else if (expression instanceof Cast) {
        Cast cast = (Cast) expression;
        Expression child = cast.child();
        org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
        DataType dataType = DatasetUtil.translateDataType(cast.dataType());
        if (pmmlExpression instanceof HasDataType) {
            // The translated expression can carry a data type itself, so retype it in place
            HasDataType<?> hasDataType = (HasDataType<?>) pmmlExpression;
            hasDataType.setDataType(dataType);
            return pmmlExpression;
        } else {
            // Otherwise materialize the cast as a DerivedField of the target type and reference it;
            // an alias name is reused as the field name when one is available
            FieldName name;
            if (pmmlExpression instanceof AliasExpression) {
                AliasExpression aliasExpression = (AliasExpression) pmmlExpression;
                name = FieldName.create(aliasExpression.getName());
            } else {
                name = FieldNameUtil.create(dataType, ExpressionUtil.format(child));
            }
            OpType opType = ExpressionUtil.getOpType(dataType);
            pmmlExpression = AliasExpression.unwrap(pmmlExpression);
            DerivedField derivedField = encoder.createDerivedField(name, opType, dataType, pmmlExpression);
            return new FieldRef(derivedField.getName());
        }
    } else if (expression instanceof Concat) {
        Concat concat = (Concat) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(concat.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.CONCAT);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof Greatest) {
        Greatest greatest = (Greatest) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(greatest.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.MAX);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof If) {
        If _if = (If) expression;
        Expression predicate = _if.predicate();
        Expression trueValue = _if.trueValue();
        Expression falseValue = _if.falseValue();
        return PMMLUtil.createApply(PMMLFunctions.IF, translateInternal(predicate), translateInternal(trueValue), translateInternal(falseValue));
    } else if (expression instanceof In) {
        In in = (In) expression;
        Expression value = in.value();
        List<Expression> elements = JavaConversions.seqAsJavaList(in.list());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.ISIN, translateInternal(value));
        for (Expression element : elements) {
            apply.addExpressions(translateInternal(element));
        }
        return apply;
    } else if (expression instanceof Least) {
        Least least = (Least) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(least.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.MIN);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof Length) {
        Length length = (Length) expression;
        Expression child = length.child();
        return PMMLUtil.createApply(PMMLFunctions.STRINGLENGTH, translateInternal(child));
    } else if (expression instanceof Literal) {
        Literal literal = (Literal) expression;
        Object value = literal.value();
        if (value == null) {
            return PMMLUtil.createMissingConstant();
        }
        DataType dataType;
        // XXX: Decimal literals are emitted as string constants holding Decimal.toString()
        if (value instanceof Decimal) {
            Decimal decimal = (Decimal) value;
            dataType = DataType.STRING;
            value = decimal.toString();
        } else {
            dataType = DatasetUtil.translateDataType(literal.dataType());
            value = toSimpleObject(value);
        }
        return PMMLUtil.createConstant(value, dataType);
    } else if (expression instanceof RegExpReplace) {
        RegExpReplace regexpReplace = (RegExpReplace) expression;
        Expression subject = regexpReplace.subject();
        Expression regexp = regexpReplace.regexp();
        Expression rep = regexpReplace.rep();
        return PMMLUtil.createApply(PMMLFunctions.REPLACE, translateInternal(subject), translateInternal(regexp), translateInternal(rep));
    } else if (expression instanceof RLike) {
        RLike rlike = (RLike) expression;
        Expression left = rlike.left();
        Expression right = rlike.right();
        return PMMLUtil.createApply(PMMLFunctions.MATCHES, translateInternal(left), translateInternal(right));
    } else if (expression instanceof StringTrim) {
        StringTrim stringTrim = (StringTrim) expression;
        Expression srcStr = stringTrim.srcStr();
        Option<Expression> trimStr = stringTrim.trimStr();
        // Only the default form is mapped to trimBlanks; an explicit trim-character string is unsupported
        if (trimStr.isDefined()) {
            throw new IllegalArgumentException(formatMessage(stringTrim));
        }
        return PMMLUtil.createApply(PMMLFunctions.TRIMBLANKS, translateInternal(srcStr));
    } else if (expression instanceof Substring) {
        Substring substring = (Substring) expression;
        Expression str = substring.str();
        // Only constant (literal) start positions and lengths can be encoded
        if (!(substring.pos() instanceof Literal) || !(substring.len() instanceof Literal)) {
            throw new IllegalArgumentException(formatMessage(substring));
        }
        Literal pos = (Literal) substring.pos();
        Literal len = (Literal) substring.len();
        int posValue = ValueUtil.asInt((Number) pos.value());
        // Spark's 1-based substring treats start position 0 the same as 1 (the first character)
        if (posValue == 0) {
            posValue = 1;
        }
        // Negative positions count from the end of the string in Spark and are not supported here
        if (posValue < 0) {
            throw new IllegalArgumentException("Expected absolute start position, got relative start position " + (pos));
        }
        int lenValue = ValueUtil.asInt((Number) len.value());
        // XXX: the requested length is capped at MAX_STRING_LENGTH
        lenValue = Math.min(lenValue, MAX_STRING_LENGTH);
        return PMMLUtil.createApply(PMMLFunctions.SUBSTRING, translateInternal(str), PMMLUtil.createConstant(posValue), PMMLUtil.createConstant(lenValue));
    } else if (expression instanceof UnaryExpression) {
        UnaryExpression unaryExpression = (UnaryExpression) expression;
        Expression child = unaryExpression.child();
        if (expression instanceof Abs) {
            return PMMLUtil.createApply(PMMLFunctions.ABS, translateInternal(child));
        } else if (expression instanceof Acos) {
            return PMMLUtil.createApply(PMMLFunctions.ACOS, translateInternal(child));
        } else if (expression instanceof Asin) {
            return PMMLUtil.createApply(PMMLFunctions.ASIN, translateInternal(child));
        } else if (expression instanceof Atan) {
            return PMMLUtil.createApply(PMMLFunctions.ATAN, translateInternal(child));
        } else if (expression instanceof Ceil) {
            return PMMLUtil.createApply(PMMLFunctions.CEIL, translateInternal(child));
        } else if (expression instanceof Cos) {
            return PMMLUtil.createApply(PMMLFunctions.COS, translateInternal(child));
        } else if (expression instanceof Cosh) {
            return PMMLUtil.createApply(PMMLFunctions.COSH, translateInternal(child));
        } else if (expression instanceof Exp) {
            return PMMLUtil.createApply(PMMLFunctions.EXP, translateInternal(child));
        } else if (expression instanceof Expm1) {
            return PMMLUtil.createApply(PMMLFunctions.EXPM1, translateInternal(child));
        } else if (expression instanceof Floor) {
            return PMMLUtil.createApply(PMMLFunctions.FLOOR, translateInternal(child));
        } else if (expression instanceof Log) {
            return PMMLUtil.createApply(PMMLFunctions.LN, translateInternal(child));
        } else if (expression instanceof Log10) {
            return PMMLUtil.createApply(PMMLFunctions.LOG10, translateInternal(child));
        } else if (expression instanceof Log1p) {
            return PMMLUtil.createApply(PMMLFunctions.LN1P, translateInternal(child));
        } else if (expression instanceof Lower) {
            return PMMLUtil.createApply(PMMLFunctions.LOWERCASE, translateInternal(child));
        } else if (expression instanceof IsNaN) {
            // XXX: Spark's isnan is mapped onto PMML isNotValid
            return PMMLUtil.createApply(PMMLFunctions.ISNOTVALID, translateInternal(child));
        } else if (expression instanceof IsNotNull) {
            return PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, translateInternal(child));
        } else if (expression instanceof IsNull) {
            return PMMLUtil.createApply(PMMLFunctions.ISMISSING, translateInternal(child));
        } else if (expression instanceof Not) {
            return PMMLUtil.createApply(PMMLFunctions.NOT, translateInternal(child));
        } else if (expression instanceof Rint) {
            return PMMLUtil.createApply(PMMLFunctions.RINT, translateInternal(child));
        } else if (expression instanceof Sin) {
            return PMMLUtil.createApply(PMMLFunctions.SIN, translateInternal(child));
        } else if (expression instanceof Sinh) {
            return PMMLUtil.createApply(PMMLFunctions.SINH, translateInternal(child));
        } else if (expression instanceof Sqrt) {
            return PMMLUtil.createApply(PMMLFunctions.SQRT, translateInternal(child));
        } else if (expression instanceof Tan) {
            return PMMLUtil.createApply(PMMLFunctions.TAN, translateInternal(child));
        } else if (expression instanceof Tanh) {
            return PMMLUtil.createApply(PMMLFunctions.TANH, translateInternal(child));
        } else if (expression instanceof UnaryMinus) {
            return PMMLUtil.toNegative(translateInternal(child));
        } else if (expression instanceof UnaryPositive) {
            // Unary plus is the identity; translate the operand directly
            return translateInternal(child);
        } else if (expression instanceof Upper) {
            return PMMLUtil.createApply(PMMLFunctions.UPPERCASE, translateInternal(child));
        } else {
            throw new IllegalArgumentException(formatMessage(unaryExpression));
        }
    } else {
        throw new IllegalArgumentException(formatMessage(expression));
    }
}
Aggregations