use of org.dmg.pmml.FieldRef in project jpmml-sparkml by jpmml.
the class ExpressionTranslator method translateInternal.
/**
 * Recursively translates a Spark SQL Catalyst {@code Expression} into a PMML expression.
 *
 * <p>Supported node types are mapped onto PMML built-in function applications, constants,
 * field references, or {@code AliasExpression} wrappers. A {@code Cast} whose translated
 * child cannot carry a data type is materialized as a {@code DerivedField} that is
 * registered with the encoder and referenced through a {@code FieldRef}.
 *
 * @param expression the Catalyst expression to translate
 * @return the translated PMML expression
 * @throws IllegalArgumentException if the expression, or any of its sub-expressions, is of an
 *         unsupported type or has an unsupported shape (e.g. a non-constant substring position)
 */
private org.dmg.pmml.Expression translateInternal(Expression expression) {
    // Only the Cast branch needs the encoder (to register a DerivedField).
    SparkMLEncoder encoder = getEncoder();

    // Aliases are kept as AliasExpression wrappers so that the alias name can later be used
    // as a field name (see the Cast branch below).
    if (expression instanceof Alias) {
        Alias alias = (Alias) expression;
        String name = alias.name();
        Expression child = alias.child();
        org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
        return new AliasExpression(name, pmmlExpression);
    }

    if (expression instanceof AttributeReference) {
        // Column references become PMML field references by column name.
        AttributeReference attributeReference = (AttributeReference) expression;
        String name = attributeReference.name();
        return new FieldRef(FieldName.create(name));
    } else if (expression instanceof BinaryMathExpression) {
        BinaryMathExpression binaryMathExpression = (BinaryMathExpression) expression;
        Expression left = binaryMathExpression.left();
        Expression right = binaryMathExpression.right();
        String function;
        if (binaryMathExpression instanceof Hypot) {
            function = PMMLFunctions.HYPOT;
        } else if (binaryMathExpression instanceof Pow) {
            function = PMMLFunctions.POW;
        } else {
            throw new IllegalArgumentException(formatMessage(binaryMathExpression));
        }
        return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
    } else if (expression instanceof BinaryOperator) {
        // Binary operators are dispatched first by node class, then by operator symbol;
        // a class/symbol mismatch is rejected.
        BinaryOperator binaryOperator = (BinaryOperator) expression;
        String symbol = binaryOperator.symbol();
        Expression left = binaryOperator.left();
        Expression right = binaryOperator.right();
        String function;
        if (expression instanceof And || expression instanceof Or) {
            switch (symbol) {
                case "&&":
                    function = PMMLFunctions.AND;
                    break;
                case "||":
                    function = PMMLFunctions.OR;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryOperator));
            }
        } else if (expression instanceof Add || expression instanceof Divide || expression instanceof Multiply || expression instanceof Subtract) {
            BinaryArithmetic binaryArithmetic = (BinaryArithmetic) binaryOperator;
            switch (symbol) {
                case "+":
                    function = PMMLFunctions.ADD;
                    break;
                case "/":
                    function = PMMLFunctions.DIVIDE;
                    break;
                case "*":
                    function = PMMLFunctions.MULTIPLY;
                    break;
                case "-":
                    function = PMMLFunctions.SUBTRACT;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryArithmetic));
            }
        } else if (expression instanceof EqualTo || expression instanceof GreaterThan || expression instanceof GreaterThanOrEqual || expression instanceof LessThan || expression instanceof LessThanOrEqual) {
            BinaryComparison binaryComparison = (BinaryComparison) binaryOperator;
            switch (symbol) {
                case "=":
                    function = PMMLFunctions.EQUAL;
                    break;
                case ">":
                    function = PMMLFunctions.GREATERTHAN;
                    break;
                case ">=":
                    function = PMMLFunctions.GREATEROREQUAL;
                    break;
                case "<":
                    function = PMMLFunctions.LESSTHAN;
                    break;
                case "<=":
                    function = PMMLFunctions.LESSOREQUAL;
                    break;
                default:
                    throw new IllegalArgumentException(formatMessage(binaryComparison));
            }
        } else {
            throw new IllegalArgumentException(formatMessage(binaryOperator));
        }
        return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
    } else if (expression instanceof CaseWhen) {
        CaseWhen caseWhen = (CaseWhen) expression;
        List<Tuple2<Expression, Expression>> branches = JavaConversions.seqAsJavaList(caseWhen.branches());
        Option<Expression> elseValue = caseWhen.elseValue();

        // Without at least one WHEN branch the loop below would fail with an uninformative
        // NoSuchElementException on its first branchIt.next() call.
        if (branches.isEmpty()) {
            throw new IllegalArgumentException(formatMessage(caseWhen));
        }

        // Branches are chained into nested "if" applies: each branch's apply receives the
        // next branch's apply as an additional (else) argument. The first apply is returned.
        Apply apply = null;
        Iterator<Tuple2<Expression, Expression>> branchIt = branches.iterator();
        Apply prevBranchApply = null;
        do {
            Tuple2<Expression, Expression> branch = branchIt.next();
            Expression predicate = branch._1();
            Expression value = branch._2();
            Apply branchApply = PMMLUtil.createApply(PMMLFunctions.IF, translateInternal(predicate), translateInternal(value));
            if (apply == null) {
                apply = branchApply;
            }
            if (prevBranchApply != null) {
                prevBranchApply.addExpressions(branchApply);
            }
            prevBranchApply = branchApply;
        } while (branchIt.hasNext());

        // The optional ELSE value becomes the else argument of the innermost "if".
        if (elseValue.isDefined()) {
            Expression value = elseValue.get();
            prevBranchApply.addExpressions(translateInternal(value));
        }
        return apply;
    } else if (expression instanceof Cast) {
        Cast cast = (Cast) expression;
        Expression child = cast.child();
        org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
        DataType dataType = DatasetUtil.translateDataType(cast.dataType());

        // If the translated child can carry a data type, the cast is applied in place.
        if (pmmlExpression instanceof HasDataType) {
            HasDataType<?> hasDataType = (HasDataType<?>) pmmlExpression;
            hasDataType.setDataType(dataType);
            return pmmlExpression;
        } else {
            // Otherwise the child is materialized as a typed DerivedField, named after the
            // alias when there is one, and referenced through a FieldRef.
            FieldName name;
            if (pmmlExpression instanceof AliasExpression) {
                AliasExpression aliasExpression = (AliasExpression) pmmlExpression;
                name = FieldName.create(aliasExpression.getName());
            } else {
                name = FieldNameUtil.create(dataType, ExpressionUtil.format(child));
            }
            OpType opType = ExpressionUtil.getOpType(dataType);
            pmmlExpression = AliasExpression.unwrap(pmmlExpression);
            DerivedField derivedField = encoder.createDerivedField(name, opType, dataType, pmmlExpression);
            return new FieldRef(derivedField.getName());
        }
    } else if (expression instanceof Concat) {
        Concat concat = (Concat) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(concat.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.CONCAT);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof Greatest) {
        Greatest greatest = (Greatest) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(greatest.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.MAX);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof If) {
        If _if = (If) expression;
        Expression predicate = _if.predicate();
        Expression trueValue = _if.trueValue();
        Expression falseValue = _if.falseValue();
        return PMMLUtil.createApply(PMMLFunctions.IF, translateInternal(predicate), translateInternal(trueValue), translateInternal(falseValue));
    } else if (expression instanceof In) {
        // The tested value comes first, followed by every candidate element.
        In in = (In) expression;
        Expression value = in.value();
        List<Expression> elements = JavaConversions.seqAsJavaList(in.list());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.ISIN, translateInternal(value));
        for (Expression element : elements) {
            apply.addExpressions(translateInternal(element));
        }
        return apply;
    } else if (expression instanceof Least) {
        Least least = (Least) expression;
        List<Expression> children = JavaConversions.seqAsJavaList(least.children());
        Apply apply = PMMLUtil.createApply(PMMLFunctions.MIN);
        for (Expression child : children) {
            apply.addExpressions(translateInternal(child));
        }
        return apply;
    } else if (expression instanceof Length) {
        Length length = (Length) expression;
        Expression child = length.child();
        return PMMLUtil.createApply(PMMLFunctions.STRINGLENGTH, translateInternal(child));
    } else if (expression instanceof Literal) {
        Literal literal = (Literal) expression;
        Object value = literal.value();
        // A null literal becomes a PMML missing-value constant.
        if (value == null) {
            return PMMLUtil.createMissingConstant();
        }
        DataType dataType;
        // XXX: Decimal literals are emitted as STRING constants holding the decimal's text form.
        if (value instanceof Decimal) {
            Decimal decimal = (Decimal) value;
            dataType = DataType.STRING;
            value = decimal.toString();
        } else {
            dataType = DatasetUtil.translateDataType(literal.dataType());
            value = toSimpleObject(value);
        }
        return PMMLUtil.createConstant(value, dataType);
    } else if (expression instanceof RegExpReplace) {
        RegExpReplace regexpReplace = (RegExpReplace) expression;
        Expression subject = regexpReplace.subject();
        Expression regexp = regexpReplace.regexp();
        Expression rep = regexpReplace.rep();
        return PMMLUtil.createApply(PMMLFunctions.REPLACE, translateInternal(subject), translateInternal(regexp), translateInternal(rep));
    } else if (expression instanceof RLike) {
        RLike rlike = (RLike) expression;
        Expression left = rlike.left();
        Expression right = rlike.right();
        return PMMLUtil.createApply(PMMLFunctions.MATCHES, translateInternal(left), translateInternal(right));
    } else if (expression instanceof StringTrim) {
        StringTrim stringTrim = (StringTrim) expression;
        Expression srcStr = stringTrim.srcStr();
        Option<Expression> trimStr = stringTrim.trimStr();
        // Only the default trim (mapped to trimBlanks) is supported; an explicit set of
        // trim characters is rejected with the same descriptive message as other branches.
        if (trimStr.isDefined()) {
            throw new IllegalArgumentException(formatMessage(stringTrim));
        }
        return PMMLUtil.createApply(PMMLFunctions.TRIMBLANKS, translateInternal(srcStr));
    } else if (expression instanceof Substring) {
        Substring substring = (Substring) expression;
        Expression str = substring.str();
        // Only constant start position and length are supported. Reject anything else with a
        // descriptive message instead of letting the casts below fail with ClassCastException.
        if (!(substring.pos() instanceof Literal) || !(substring.len() instanceof Literal)) {
            throw new IllegalArgumentException(formatMessage(substring));
        }
        Literal pos = (Literal) substring.pos();
        Literal len = (Literal) substring.len();
        int posValue = ValueUtil.asInt((Number) pos.value());
        // Non-positive start positions are rejected.
        if (posValue <= 0) {
            throw new IllegalArgumentException("Expected absolute start position, got relative start position " + (pos));
        }
        int lenValue = ValueUtil.asInt((Number) len.value());
        // XXX: the length is capped at MAX_STRING_LENGTH.
        lenValue = Math.min(lenValue, MAX_STRING_LENGTH);
        return PMMLUtil.createApply(PMMLFunctions.SUBSTRING, translateInternal(str), PMMLUtil.createConstant(posValue), PMMLUtil.createConstant(lenValue));
    } else if (expression instanceof UnaryExpression) {
        // Most unary expressions map one-to-one onto a single-argument PMML function.
        UnaryExpression unaryExpression = (UnaryExpression) expression;
        Expression child = unaryExpression.child();
        if (expression instanceof Abs) {
            return PMMLUtil.createApply(PMMLFunctions.ABS, translateInternal(child));
        } else if (expression instanceof Acos) {
            return PMMLUtil.createApply(PMMLFunctions.ACOS, translateInternal(child));
        } else if (expression instanceof Asin) {
            return PMMLUtil.createApply(PMMLFunctions.ASIN, translateInternal(child));
        } else if (expression instanceof Atan) {
            return PMMLUtil.createApply(PMMLFunctions.ATAN, translateInternal(child));
        } else if (expression instanceof Ceil) {
            return PMMLUtil.createApply(PMMLFunctions.CEIL, translateInternal(child));
        } else if (expression instanceof Cos) {
            return PMMLUtil.createApply(PMMLFunctions.COS, translateInternal(child));
        } else if (expression instanceof Cosh) {
            return PMMLUtil.createApply(PMMLFunctions.COSH, translateInternal(child));
        } else if (expression instanceof Exp) {
            return PMMLUtil.createApply(PMMLFunctions.EXP, translateInternal(child));
        } else if (expression instanceof Expm1) {
            return PMMLUtil.createApply(PMMLFunctions.EXPM1, translateInternal(child));
        } else if (expression instanceof Floor) {
            return PMMLUtil.createApply(PMMLFunctions.FLOOR, translateInternal(child));
        } else if (expression instanceof Log) {
            return PMMLUtil.createApply(PMMLFunctions.LN, translateInternal(child));
        } else if (expression instanceof Log10) {
            return PMMLUtil.createApply(PMMLFunctions.LOG10, translateInternal(child));
        } else if (expression instanceof Log1p) {
            return PMMLUtil.createApply(PMMLFunctions.LN1P, translateInternal(child));
        } else if (expression instanceof Lower) {
            return PMMLUtil.createApply(PMMLFunctions.LOWERCASE, translateInternal(child));
        } else if (expression instanceof IsNaN) {
            // XXX: approximated with isNotValid. NOTE(review): confirm this matches Spark's
            // isnan semantics for the value types that reach here.
            return PMMLUtil.createApply(PMMLFunctions.ISNOTVALID, translateInternal(child));
        } else if (expression instanceof IsNotNull) {
            return PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, translateInternal(child));
        } else if (expression instanceof IsNull) {
            return PMMLUtil.createApply(PMMLFunctions.ISMISSING, translateInternal(child));
        } else if (expression instanceof Not) {
            return PMMLUtil.createApply(PMMLFunctions.NOT, translateInternal(child));
        } else if (expression instanceof Rint) {
            return PMMLUtil.createApply(PMMLFunctions.RINT, translateInternal(child));
        } else if (expression instanceof Sin) {
            return PMMLUtil.createApply(PMMLFunctions.SIN, translateInternal(child));
        } else if (expression instanceof Sinh) {
            return PMMLUtil.createApply(PMMLFunctions.SINH, translateInternal(child));
        } else if (expression instanceof Sqrt) {
            return PMMLUtil.createApply(PMMLFunctions.SQRT, translateInternal(child));
        } else if (expression instanceof Tan) {
            return PMMLUtil.createApply(PMMLFunctions.TAN, translateInternal(child));
        } else if (expression instanceof Tanh) {
            return PMMLUtil.createApply(PMMLFunctions.TANH, translateInternal(child));
        } else if (expression instanceof UnaryMinus) {
            return PMMLUtil.toNegative(translateInternal(child));
        } else if (expression instanceof UnaryPositive) {
            // Unary plus is a no-op; the child is translated directly.
            return translateInternal(child);
        } else if (expression instanceof Upper) {
            return PMMLUtil.createApply(PMMLFunctions.UPPERCASE, translateInternal(child));
        } else {
            throw new IllegalArgumentException(formatMessage(unaryExpression));
        }
    } else {
        throw new IllegalArgumentException(formatMessage(expression));
    }
}
use of org.dmg.pmml.FieldRef in project jpmml-sparkml by jpmml.
the class LinearSVCModelConverter method encodeModel.
/**
 * Encodes the converted {@code LinearSVCModel} as a PMML mining model.
 *
 * <p>The linear part is built by {@code LinearModelUtil.createRegression} over an anonymous
 * DOUBLE regressor schema derived from {@code schema}. Its predicted value is exposed as the
 * continuous "margin" output field. That output carries an extra transformation that applies
 * PMML {@code threshold} to the margin, using the model's configured threshold. The result is
 * then wrapped via {@code MiningModelUtil.createBinaryLogisticClassification} with
 * {@code NormalizationMethod.NONE}.
 *
 * @param schema the label/feature schema of the model
 * @return the encoded mining model
 */
@Override
public MiningModel encodeModel(Schema schema) {
    LinearSVCModel svcModel = getTransformer();

    // Derived output: threshold(margin, svcModel.getThreshold()), named "threshold(...)".
    Transformation thresholdTransformation = new AbstractTransformation() {

        @Override
        public FieldName getName(FieldName name) {
            return FieldNameUtil.create("threshold", name);
        }

        @Override
        public Expression createExpression(FieldRef fieldRef) {
            return PMMLUtil.createApply(PMMLFunctions.THRESHOLD)
                .addExpressions(fieldRef, PMMLUtil.createConstant(svcModel.getThreshold()));
        }
    };

    Schema regressorSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

    Model regressionModel = LinearModelUtil.createRegression(this, svcModel.coefficients(), svcModel.intercept(), regressorSchema);
    regressionModel.setOutput(
        ModelUtil.createPredictedOutput(FieldName.create("margin"), OpType.CONTINUOUS, DataType.DOUBLE, thresholdTransformation)
    );

    return MiningModelUtil.createBinaryLogisticClassification(
        regressionModel, 1d, 0d, RegressionModel.NormalizationMethod.NONE, false, schema
    );
}
use of org.dmg.pmml.FieldRef in project jpmml-sparkml by jpmml.
the class CountVectorizerModelConverter method encodeFeatures.
/**
 * Encodes the fitted {@code CountVectorizerModel} as one {@code TermFeature} per vocabulary term.
 *
 * <p>A term-frequency {@code DefineFunction} named {@code "tf@<sequence>"} is registered with
 * the encoder. It wraps a {@code TextIndex} that tokenizes the document with the document
 * feature's word separator regex and uses BINARY local term weights when the model is binary.
 * For every non-empty stop word set, the text index also gets a regex-based, recursive
 * normalization that replaces stop words with a single space.
 *
 * @param encoder the encoder holding the input document feature
 * @return one feature per vocabulary term, in vocabulary order
 * @throws IllegalArgumentException if the word separator regex is neither {@code \s+} nor
 *         {@code \W+} (checked only when a non-empty stop word set exists), or if a vocabulary
 *         term contains punctuation
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    CountVectorizerModel countVectorizer = getTransformer();

    DocumentFeature document = (DocumentFeature) encoder.getOnlyFeature(countVectorizer.getInputCol());
    String wordSeparatorRE = document.getWordSeparatorRE();

    ParameterField documentField = new ParameterField(FieldName.create("document"));
    ParameterField termField = new ParameterField(FieldName.create("term"));

    TextIndex textIndex = new TextIndex(documentField.getName(), new FieldRef(termField.getName()));
    textIndex.setTokenize(Boolean.TRUE);
    textIndex.setWordSeparatorCharacterRE(wordSeparatorRE);
    textIndex.setLocalTermWeights(countVectorizer.getBinary() ? TextIndex.LocalTermWeights.BINARY : null);

    for (DocumentFeature.StopWordSet stopWordSet : document.getStopWordSets()) {
        if (!stopWordSet.isEmpty()) {
            // The stop word pattern must mirror the tokenizer's notion of a word boundary.
            String tokenRE;
            switch (wordSeparatorRE) {
                case "\\s+":
                    tokenRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWordSet) + ")\\p{Punct}*(\\s+|$)";
                    break;
                case "\\W+":
                    tokenRE = "(\\W+)(" + JOINER.join(stopWordSet) + ")(\\W+)";
                    break;
                default:
                    throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + wordSeparatorRE + "\"");
            }

            // Single-row inline table: regex "string" column replaced by a single-space "stem".
            Map<String, List<String>> row = new LinkedHashMap<>();
            row.put("string", Collections.singletonList(tokenRE));
            row.put("stem", Collections.singletonList(" "));
            row.put("regex", Collections.singletonList("true"));

            TextIndexNormalization normalization = new TextIndexNormalization(null, PMMLUtil.createInlineTable(row));
            normalization.setCaseSensitive(stopWordSet.isCaseSensitive());
            // Handles consecutive matches. See http://stackoverflow.com/a/25085385
            normalization.setRecursive(Boolean.TRUE);

            textIndex.addTextIndexNormalizations(normalization);
        }
    }

    String functionName = "tf@" + CountVectorizerModelConverter.SEQUENCE.getAndIncrement();
    DefineFunction termFrequencyFunction = new DefineFunction(functionName, OpType.CONTINUOUS, DataType.INTEGER, null, textIndex)
        .addParameterFields(documentField, termField);
    encoder.addDefineFunction(termFrequencyFunction);

    List<Feature> termFeatures = new ArrayList<>();
    for (String term : countVectorizer.vocabulary()) {
        if (TermUtil.hasPunctuation(term)) {
            throw new IllegalArgumentException("Punctuated vocabulary terms (" + term + ") are not supported");
        }
        termFeatures.add(new TermFeature(encoder, termFrequencyFunction, document, term));
    }
    return termFeatures;
}
use of org.dmg.pmml.FieldRef in project jpmml-sparkml by jpmml.
the class TermFeature method toWeightedTermFeature.
/**
 * Returns a weighted variant of this term feature.
 *
 * <p>The weighted function's name is derived by replacing {@code "tf@"} with {@code "tf-idf@"}
 * in this feature's function name. If the encoder does not yet hold a function under that
 * name, a new one is created and registered. It takes the original parameters plus a
 * {@code weight} parameter and multiplies the original function expression by that weight.
 *
 * @param weight the weight attached to the returned feature
 * @return a {@code WeightedTermFeature} over the (possibly newly registered) weighted function
 */
public WeightedTermFeature toWeightedTermFeature(Number weight) {
    PMMLEncoder encoder = getEncoder();
    DefineFunction termFunction = getDefineFunction();

    String weightedName = (termFunction.getName()).replace("tf@", "tf-idf@");

    // Reuse a previously registered weighted function when one exists.
    DefineFunction weightedFunction = encoder.getDefineFunction(weightedName);
    if (weightedFunction == null) {
        ParameterField weightParameter = new ParameterField(FieldName.create("weight"));

        List<ParameterField> parameters = new ArrayList<>(termFunction.getParameterFields());
        parameters.add(weightParameter);

        Apply weightedExpression = PMMLUtil.createApply(
            PMMLFunctions.MULTIPLY, termFunction.getExpression(), new FieldRef(weightParameter.getName())
        );

        weightedFunction = new DefineFunction(weightedName, OpType.CONTINUOUS, DataType.DOUBLE, parameters, weightedExpression);
        encoder.addDefineFunction(weightedFunction);
    }

    return new WeightedTermFeature(encoder, weightedFunction, getFeature(), getValue(), weight);
}
use of org.dmg.pmml.FieldRef in project jpmml-sparkml by jpmml.
the class SQLTransformerConverter method encodeLogicalPlan.
/**
 * Encodes the expressions of a Spark SQL logical plan as PMML fields.
 *
 * <p>Child plans are encoded first, but their returned fields are discarded. Presumably they
 * are encoded for their side effects on the encoder (TODO confirm). For each of this plan's
 * own expressions:
 * <ul>
 *   <li>a translation that is a plain {@code FieldRef} for which {@code ensureField} returns a
 *       field yields that existing field;</li>
 *   <li>otherwise a {@code DerivedField} is created, named after the alias if the translation
 *       is an {@code AliasExpression}, or else {@code FieldNameUtil.create("sql", ...)} of the
 *       formatted expression. Beforehand, {@code ensureField} is called for every
 *       {@code FieldRef} inside the translated expression.</li>
 * </ul>
 *
 * @param encoder the encoder that receives derived fields
 * @param logicalPlan the logical plan to encode
 * @return one field per expression of {@code logicalPlan}, in order
 */
public static List<Field<?>> encodeLogicalPlan(SparkMLEncoder encoder, LogicalPlan logicalPlan) {
    List<Field<?>> fields = new ArrayList<>();

    for (LogicalPlan childPlan : JavaConversions.seqAsJavaList(logicalPlan.children())) {
        encodeLogicalPlan(encoder, childPlan);
    }

    for (Expression expression : JavaConversions.seqAsJavaList(logicalPlan.expressions())) {
        org.dmg.pmml.Expression translated = ExpressionTranslator.translate(encoder, expression);

        // A bare reference to an already-known field is passed through unchanged.
        if (translated instanceof FieldRef) {
            Field<?> existingField = ensureField(encoder, ((FieldRef) translated).getField());
            if (existingField != null) {
                fields.add(existingField);
                continue;
            }
        }

        FieldName name = (translated instanceof AliasExpression)
            ? FieldName.create(((AliasExpression) translated).getName())
            : FieldNameUtil.create("sql", ExpressionUtil.format(expression));

        DataType dataType = DatasetUtil.translateDataType(expression.dataType());
        OpType opType = ExpressionUtil.getOpType(dataType);

        org.dmg.pmml.Expression unwrapped = AliasExpression.unwrap(translated);

        Visitor fieldEnsurer = new AbstractVisitor() {

            @Override
            public VisitorAction visit(FieldRef fieldRef) {
                ensureField(encoder, fieldRef.getField());
                return super.visit(fieldRef);
            }
        };
        fieldEnsurer.applyTo(unwrapped);

        fields.add(encoder.createDerivedField(name, opType, dataType, unwrapped));
    }

    return fields;
}
Aggregations