Usage of org.dmg.pmml.DerivedField in the jpmml-sparkml project (by jpmml).
Class ExpressionTranslator, method translateInternal.
/**
 * Recursively translates a Spark SQL expression tree into an equivalent PMML expression.
 * <p>
 * Each supported expression type is matched with {@code instanceof}. Child expressions are
 * translated through recursive calls to this method. The {@code Cast} branch may register a
 * new derived field with the encoder as a side effect.
 * NOTE(review): the input classes (Alias, AttributeReference, CaseWhen, ...) appear to be Spark
 * Catalyst expression types; confirm against the file's imports.
 *
 * @param expression the Spark expression to translate
 * @return the translated PMML expression; an {@code Alias} yields an {@code AliasExpression} wrapper
 * @throws IllegalArgumentException if the expression type, operator symbol or option is not supported
 */
private org.dmg.pmml.Expression translateInternal(Expression expression) {
SparkMLEncoder encoder = getEncoder();
// Alias: translate the aliased child, keeping the alias name in a wrapper so callers (e.g. Cast) can reuse it
if (expression instanceof Alias) {
Alias alias = (Alias) expression;
String name = alias.name();
Expression child = alias.child();
org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
return new AliasExpression(name, pmmlExpression);
}
// Column reference -> FieldRef by column name
if (expression instanceof AttributeReference) {
AttributeReference attributeReference = (AttributeReference) expression;
String name = attributeReference.name();
return new FieldRef(FieldName.create(name));
} else if (expression instanceof BinaryMathExpression) {
// Two-argument math functions; only Hypot and Pow are supported
BinaryMathExpression binaryMathExpression = (BinaryMathExpression) expression;
Expression left = binaryMathExpression.left();
Expression right = binaryMathExpression.right();
String function;
if (binaryMathExpression instanceof Hypot) {
function = PMMLFunctions.HYPOT;
} else if (binaryMathExpression instanceof Pow) {
function = PMMLFunctions.POW;
} else {
throw new IllegalArgumentException(formatMessage(binaryMathExpression));
}
return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
} else if (expression instanceof BinaryOperator) {
// Logical, arithmetic and comparison operators: map the Spark operator symbol to a PMML function.
// Any other BinaryOperator subclass is rejected by the final else
BinaryOperator binaryOperator = (BinaryOperator) expression;
String symbol = binaryOperator.symbol();
Expression left = binaryOperator.left();
Expression right = binaryOperator.right();
String function;
if (expression instanceof And || expression instanceof Or) {
switch(symbol) {
case "&&":
function = PMMLFunctions.AND;
break;
case "||":
function = PMMLFunctions.OR;
break;
default:
throw new IllegalArgumentException(formatMessage(binaryOperator));
}
} else if (expression instanceof Add || expression instanceof Divide || expression instanceof Multiply || expression instanceof Subtract) {
BinaryArithmetic binaryArithmetic = (BinaryArithmetic) binaryOperator;
switch(symbol) {
case "+":
function = PMMLFunctions.ADD;
break;
case "/":
function = PMMLFunctions.DIVIDE;
break;
case "*":
function = PMMLFunctions.MULTIPLY;
break;
case "-":
function = PMMLFunctions.SUBTRACT;
break;
default:
throw new IllegalArgumentException(formatMessage(binaryArithmetic));
}
} else if (expression instanceof EqualTo || expression instanceof GreaterThan || expression instanceof GreaterThanOrEqual || expression instanceof LessThan || expression instanceof LessThanOrEqual) {
BinaryComparison binaryComparison = (BinaryComparison) binaryOperator;
switch(symbol) {
case "=":
function = PMMLFunctions.EQUAL;
break;
case ">":
function = PMMLFunctions.GREATERTHAN;
break;
case ">=":
function = PMMLFunctions.GREATEROREQUAL;
break;
case "<":
function = PMMLFunctions.LESSTHAN;
break;
case "<=":
function = PMMLFunctions.LESSOREQUAL;
break;
default:
throw new IllegalArgumentException(formatMessage(binaryComparison));
}
} else {
throw new IllegalArgumentException(formatMessage(binaryOperator));
}
return PMMLUtil.createApply(function, translateInternal(left), translateInternal(right));
} else if (expression instanceof CaseWhen) {
// CASE WHEN p1 THEN v1 WHEN p2 THEN v2 ... [ELSE e] becomes a chain of nested PMML "if" applies.
// Each branch's if(pN, vN) is appended as the third (else) argument of the previous branch's if.
// The optional ELSE value is appended to the last branch. Without an ELSE, the last "if" keeps
// only two arguments. NOTE(review): presumably this yields a missing value when no predicate
// matches; confirm against the PMML "if" semantics.
// The do-while calls next() before hasNext(), so it assumes at least one branch.
CaseWhen caseWhen = (CaseWhen) expression;
List<Tuple2<Expression, Expression>> branches = JavaConversions.seqAsJavaList(caseWhen.branches());
Option<Expression> elseValue = caseWhen.elseValue();
Apply apply = null;
Iterator<Tuple2<Expression, Expression>> branchIt = branches.iterator();
Apply prevBranchApply = null;
do {
Tuple2<Expression, Expression> branch = branchIt.next();
Expression predicate = branch._1();
Expression value = branch._2();
Apply branchApply = PMMLUtil.createApply(PMMLFunctions.IF, translateInternal(predicate), translateInternal(value));
// The first branch becomes the root of the returned chain
if (apply == null) {
apply = branchApply;
}
if (prevBranchApply != null) {
prevBranchApply.addExpressions(branchApply);
}
prevBranchApply = branchApply;
} while (branchIt.hasNext());
if (elseValue.isDefined()) {
Expression value = elseValue.get();
prevBranchApply.addExpressions(translateInternal(value));
}
return apply;
} else if (expression instanceof Cast) {
// Cast: if the translated child can carry a data type, overwrite that type in place.
// Otherwise materialize the child as a DerivedField of the target type and return a reference to it.
Cast cast = (Cast) expression;
Expression child = cast.child();
org.dmg.pmml.Expression pmmlExpression = translateInternal(child);
DataType dataType = DatasetUtil.translateDataType(cast.dataType());
if (pmmlExpression instanceof HasDataType) {
HasDataType<?> hasDataType = (HasDataType<?>) pmmlExpression;
hasDataType.setDataType(dataType);
return pmmlExpression;
} else {
FieldName name;
// Prefer the user-supplied alias as the field name; otherwise derive one from the data type and child text
if (pmmlExpression instanceof AliasExpression) {
AliasExpression aliasExpression = (AliasExpression) pmmlExpression;
name = FieldName.create(aliasExpression.getName());
} else {
name = FieldNameUtil.create(dataType, ExpressionUtil.format(child));
}
OpType opType = ExpressionUtil.getOpType(dataType);
// Strip the alias wrapper before storing the expression in the DerivedField
pmmlExpression = AliasExpression.unwrap(pmmlExpression);
DerivedField derivedField = encoder.createDerivedField(name, opType, dataType, pmmlExpression);
return new FieldRef(derivedField.getName());
}
} else if (expression instanceof Concat) {
// concat(a, b, ...) -> PMML concat over all translated children
Concat concat = (Concat) expression;
List<Expression> children = JavaConversions.seqAsJavaList(concat.children());
Apply apply = PMMLUtil.createApply(PMMLFunctions.CONCAT);
for (Expression child : children) {
apply.addExpressions(translateInternal(child));
}
return apply;
} else if (expression instanceof Greatest) {
// greatest(...) -> PMML max over all translated children
Greatest greatest = (Greatest) expression;
List<Expression> children = JavaConversions.seqAsJavaList(greatest.children());
Apply apply = PMMLUtil.createApply(PMMLFunctions.MAX);
for (Expression child : children) {
apply.addExpressions(translateInternal(child));
}
return apply;
} else if (expression instanceof If) {
// if(predicate, trueValue, falseValue) -> three-argument PMML "if"
If _if = (If) expression;
Expression predicate = _if.predicate();
Expression trueValue = _if.trueValue();
Expression falseValue = _if.falseValue();
return PMMLUtil.createApply(PMMLFunctions.IF, translateInternal(predicate), translateInternal(trueValue), translateInternal(falseValue));
} else if (expression instanceof In) {
// value IN (e1, e2, ...) -> PMML isIn(value, e1, e2, ...)
In in = (In) expression;
Expression value = in.value();
List<Expression> elements = JavaConversions.seqAsJavaList(in.list());
Apply apply = PMMLUtil.createApply(PMMLFunctions.ISIN, translateInternal(value));
for (Expression element : elements) {
apply.addExpressions(translateInternal(element));
}
return apply;
} else if (expression instanceof Least) {
// least(...) -> PMML min over all translated children
Least least = (Least) expression;
List<Expression> children = JavaConversions.seqAsJavaList(least.children());
Apply apply = PMMLUtil.createApply(PMMLFunctions.MIN);
for (Expression child : children) {
apply.addExpressions(translateInternal(child));
}
return apply;
} else if (expression instanceof Length) {
Length length = (Length) expression;
Expression child = length.child();
return PMMLUtil.createApply(PMMLFunctions.STRINGLENGTH, translateInternal(child));
} else if (expression instanceof Literal) {
// Literal: null becomes a missing-value constant; everything else becomes a typed constant
Literal literal = (Literal) expression;
Object value = literal.value();
if (value == null) {
return PMMLUtil.createMissingConstant();
}
DataType dataType;
// XXX
// Decimal literals are emitted as STRING constants via Decimal.toString(), not as a numeric type
if (value instanceof Decimal) {
Decimal decimal = (Decimal) value;
dataType = DataType.STRING;
value = decimal.toString();
} else {
dataType = DatasetUtil.translateDataType(literal.dataType());
value = toSimpleObject(value);
}
return PMMLUtil.createConstant(value, dataType);
} else if (expression instanceof RegExpReplace) {
// regexp_replace(subject, regexp, rep) -> PMML replace
RegExpReplace regexpReplace = (RegExpReplace) expression;
Expression subject = regexpReplace.subject();
Expression regexp = regexpReplace.regexp();
Expression rep = regexpReplace.rep();
return PMMLUtil.createApply(PMMLFunctions.REPLACE, translateInternal(subject), translateInternal(regexp), translateInternal(rep));
} else if (expression instanceof RLike) {
// left RLIKE right -> PMML matches
RLike rlike = (RLike) expression;
Expression left = rlike.left();
Expression right = rlike.right();
return PMMLUtil.createApply(PMMLFunctions.MATCHES, translateInternal(left), translateInternal(right));
} else if (expression instanceof StringTrim) {
// Only the default trim is supported; a custom trim-character argument is rejected (message-less exception)
StringTrim stringTrim = (StringTrim) expression;
Expression srcStr = stringTrim.srcStr();
Option<Expression> trimStr = stringTrim.trimStr();
if (trimStr.isDefined()) {
throw new IllegalArgumentException();
}
return PMMLUtil.createApply(PMMLFunctions.TRIMBLANKS, translateInternal(srcStr));
} else if (expression instanceof Substring) {
// substring(str, pos, len): pos and len must be Literals (a non-literal fails the cast with ClassCastException).
// Only positive start positions are accepted.
Substring substring = (Substring) expression;
Expression str = substring.str();
Literal pos = (Literal) substring.pos();
Literal len = (Literal) substring.len();
int posValue = ValueUtil.asInt((Number) pos.value());
if (posValue <= 0) {
throw new IllegalArgumentException("Expected absolute start position, got relative start position " + (pos));
}
int lenValue = ValueUtil.asInt((Number) len.value());
// XXX
// Clamp the length to MAX_STRING_LENGTH. NOTE(review): presumably this keeps "take the rest of the string" lengths bounded; confirm
lenValue = Math.min(lenValue, MAX_STRING_LENGTH);
return PMMLUtil.createApply(PMMLFunctions.SUBSTRING, translateInternal(str), PMMLUtil.createConstant(posValue), PMMLUtil.createConstant(lenValue));
} else if (expression instanceof UnaryExpression) {
// Single-argument functions map one-to-one onto PMML built-ins. Exceptions: UnaryMinus (negation helper),
// UnaryPositive (identity) and IsNaN (approximated as isNotValid)
UnaryExpression unaryExpression = (UnaryExpression) expression;
Expression child = unaryExpression.child();
if (expression instanceof Abs) {
return PMMLUtil.createApply(PMMLFunctions.ABS, translateInternal(child));
} else if (expression instanceof Acos) {
return PMMLUtil.createApply(PMMLFunctions.ACOS, translateInternal(child));
} else if (expression instanceof Asin) {
return PMMLUtil.createApply(PMMLFunctions.ASIN, translateInternal(child));
} else if (expression instanceof Atan) {
return PMMLUtil.createApply(PMMLFunctions.ATAN, translateInternal(child));
} else if (expression instanceof Ceil) {
return PMMLUtil.createApply(PMMLFunctions.CEIL, translateInternal(child));
} else if (expression instanceof Cos) {
return PMMLUtil.createApply(PMMLFunctions.COS, translateInternal(child));
} else if (expression instanceof Cosh) {
return PMMLUtil.createApply(PMMLFunctions.COSH, translateInternal(child));
} else if (expression instanceof Exp) {
return PMMLUtil.createApply(PMMLFunctions.EXP, translateInternal(child));
} else if (expression instanceof Expm1) {
return PMMLUtil.createApply(PMMLFunctions.EXPM1, translateInternal(child));
} else if (expression instanceof Floor) {
return PMMLUtil.createApply(PMMLFunctions.FLOOR, translateInternal(child));
} else if (expression instanceof Log) {
return PMMLUtil.createApply(PMMLFunctions.LN, translateInternal(child));
} else if (expression instanceof Log10) {
return PMMLUtil.createApply(PMMLFunctions.LOG10, translateInternal(child));
} else if (expression instanceof Log1p) {
return PMMLUtil.createApply(PMMLFunctions.LN1P, translateInternal(child));
} else if (expression instanceof Lower) {
return PMMLUtil.createApply(PMMLFunctions.LOWERCASE, translateInternal(child));
} else if (expression instanceof IsNaN) {
// XXX
// Approximation: isNaN is translated as PMML isNotValid, which is not a NaN-specific test
return PMMLUtil.createApply(PMMLFunctions.ISNOTVALID, translateInternal(child));
} else if (expression instanceof IsNotNull) {
return PMMLUtil.createApply(PMMLFunctions.ISNOTMISSING, translateInternal(child));
} else if (expression instanceof IsNull) {
return PMMLUtil.createApply(PMMLFunctions.ISMISSING, translateInternal(child));
} else if (expression instanceof Not) {
return PMMLUtil.createApply(PMMLFunctions.NOT, translateInternal(child));
} else if (expression instanceof Rint) {
return PMMLUtil.createApply(PMMLFunctions.RINT, translateInternal(child));
} else if (expression instanceof Sin) {
return PMMLUtil.createApply(PMMLFunctions.SIN, translateInternal(child));
} else if (expression instanceof Sinh) {
return PMMLUtil.createApply(PMMLFunctions.SINH, translateInternal(child));
} else if (expression instanceof Sqrt) {
return PMMLUtil.createApply(PMMLFunctions.SQRT, translateInternal(child));
} else if (expression instanceof Tan) {
return PMMLUtil.createApply(PMMLFunctions.TAN, translateInternal(child));
} else if (expression instanceof Tanh) {
return PMMLUtil.createApply(PMMLFunctions.TANH, translateInternal(child));
} else if (expression instanceof UnaryMinus) {
return PMMLUtil.toNegative(translateInternal(child));
} else if (expression instanceof UnaryPositive) {
// Unary plus is a no-op: return the translated child unchanged
return translateInternal(child);
} else if (expression instanceof Upper) {
return PMMLUtil.createApply(PMMLFunctions.UPPERCASE, translateInternal(child));
} else {
throw new IllegalArgumentException(formatMessage(unaryExpression));
}
} else {
throw new IllegalArgumentException(formatMessage(expression));
}
}
Usage of org.dmg.pmml.DerivedField in the jpmml-sparkml project (by jpmml).
Class PMMLBuilder, method build.
/**
 * Converts the configured pipeline model into a PMML document.
 * <p>
 * Every pipeline stage is converted in order. Feature converters register features; model
 * converters register a model. Multiple models are combined into a model chain. Derived fields
 * created after the last model are moved into the final model's Output as TRANSFORMED_VALUE
 * output fields. If verification data is configured and the pipeline exposes prediction or
 * probability columns, a ModelVerification element is attached to the top-level model.
 *
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a stage converts to neither a feature converter nor a model converter
 */
public PMML build() {
StructType schema = getSchema();
PipelineModel pipelineModel = getPipelineModel();
Map<RegexKey, ? extends Map<String, ?>> options = getOptions();
Verification verification = getVerification();
ConverterFactory converterFactory = new ConverterFactory(options);
SparkMLEncoder encoder = new SparkMLEncoder(schema, converterFactory);
// NOTE(review): presumably a live view of the encoder's derived fields. It is read again after more
// stages are registered and after removeDerivedField, so a copy would break the pre/post split; confirm
Map<FieldName, DerivedField> derivedFields = encoder.getDerivedFields();
List<org.dmg.pmml.Model> models = new ArrayList<>();
List<String> predictionColumns = new ArrayList<>();
List<String> probabilityColumns = new ArrayList<>();
// Transformations preceding the last model
List<FieldName> preProcessorNames = Collections.emptyList();
Iterable<Transformer> transformers = getTransformers(pipelineModel);
for (Transformer transformer : transformers) {
TransformerConverter<?> converter = converterFactory.newConverter(transformer);
if (converter instanceof FeatureConverter) {
FeatureConverter<?> featureConverter = (FeatureConverter<?>) converter;
featureConverter.registerFeatures(encoder);
} else if (converter instanceof ModelConverter) {
ModelConverter<?> modelConverter = (ModelConverter<?>) converter;
org.dmg.pmml.Model model = modelConverter.registerModel(encoder);
models.add(model);
// Feature importances are recorded only when the estimate option is enabled (default FALSE)
featureImportances: if (modelConverter instanceof HasFeatureImportances) {
HasFeatureImportances hasFeatureImportances = (HasFeatureImportances) modelConverter;
Boolean estimateFeatureImportances = (Boolean) modelConverter.getOption(HasTreeOptions.OPTION_ESTIMATE_FEATURE_IMPORTANCES, Boolean.FALSE);
if (!estimateFeatureImportances) {
break featureImportances;
}
List<Double> featureImportances = VectorUtil.toList(hasFeatureImportances.getFeatureImportances());
List<Feature> features = modelConverter.getFeatures(encoder);
SchemaUtil.checkSize(featureImportances.size(), features);
for (int i = 0; i < featureImportances.size(); i++) {
Double featureImportance = featureImportances.get(i);
Feature feature = features.get(i);
encoder.addFeatureImportance(model, feature, featureImportance);
}
}
// Remember the prediction column for verification
hasPredictionCol: if (transformer instanceof HasPredictionCol) {
HasPredictionCol hasPredictionCol = (HasPredictionCol) transformer;
// XXX
// A GLR model that encodes to a classification model has its prediction column excluded from verification
if ((transformer instanceof GeneralizedLinearRegressionModel) && (MiningFunction.CLASSIFICATION).equals(model.getMiningFunction())) {
break hasPredictionCol;
}
predictionColumns.add(hasPredictionCol.getPredictionCol());
}
if (transformer instanceof HasProbabilityCol) {
HasProbabilityCol hasProbabilityCol = (HasProbabilityCol) transformer;
probabilityColumns.add(hasProbabilityCol.getProbabilityCol());
}
// Snapshot of every derived field that exists up to (and including) this model
preProcessorNames = new ArrayList<>(derivedFields.keySet());
} else {
throw new IllegalArgumentException("Expected a subclass of " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + ", got " + (converter != null ? ("class " + (converter.getClass()).getName()) : null));
}
}
// Transformations following the last model
List<FieldName> postProcessorNames = new ArrayList<>(derivedFields.keySet());
postProcessorNames.removeAll(preProcessorNames);
// No models -> null (transformations only); one model -> itself; several -> a model chain
org.dmg.pmml.Model model;
if (models.size() == 0) {
model = null;
} else if (models.size() == 1) {
model = Iterables.getOnlyElement(models);
} else {
model = MiningModelUtil.createModelChain(models);
}
// Move post-model derived fields into the final model's Output as TRANSFORMED_VALUE output fields
if ((model != null) && (postProcessorNames.size() > 0)) {
org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(model);
Output output = ModelUtil.ensureOutput(finalModel);
for (FieldName postProcessorName : postProcessorNames) {
DerivedField derivedField = derivedFields.get(postProcessorName);
encoder.removeDerivedField(postProcessorName);
OutputField outputField = new OutputField(derivedField.getName(), derivedField.getOpType(), derivedField.getDataType()).setResultFeature(ResultFeature.TRANSFORMED_VALUE).setExpression(derivedField.getExpression());
output.addOutputFields(outputField);
}
}
PMML pmml = encoder.encodePMML(model);
// Optional verification. NOTE(review): it mutates `model` after encodePMML, which presumably works only
// because the returned PMML references the same model object; confirm
if ((model != null) && (predictionColumns.size() > 0 || probabilityColumns.size() > 0) && (verification != null)) {
Dataset<Row> dataset = verification.getDataset();
Dataset<Row> transformedDataset = verification.getTransformedDataset();
Double precision = verification.getPrecision();
Double zeroThreshold = verification.getZeroThreshold();
// Inputs are the ACTIVE mining fields of the top-level model
List<String> inputColumns = new ArrayList<>();
MiningSchema miningSchema = model.getMiningSchema();
List<MiningField> miningFields = miningSchema.getMiningFields();
for (MiningField miningField : miningFields) {
MiningField.UsageType usageType = miningField.getUsageType();
switch(usageType) {
case ACTIVE:
FieldName name = miningField.getName();
inputColumns.add(name.getValue());
break;
default:
break;
}
}
Map<VerificationField, List<?>> data = new LinkedHashMap<>();
for (String inputColumn : inputColumns) {
VerificationField verificationField = ModelUtil.createVerificationField(FieldName.create(inputColumn));
data.put(verificationField, getColumn(dataset, inputColumn));
}
// Expected outputs come from the transformed dataset. Probability vectors expand to one field per element
for (String predictionColumn : predictionColumns) {
Feature feature = encoder.getOnlyFeature(predictionColumn);
VerificationField verificationField = ModelUtil.createVerificationField(feature.getName()).setPrecision(precision).setZeroThreshold(zeroThreshold);
data.put(verificationField, getColumn(transformedDataset, predictionColumn));
}
for (String probabilityColumn : probabilityColumns) {
List<Feature> features = encoder.getFeatures(probabilityColumn);
for (int i = 0; i < features.size(); i++) {
Feature feature = features.get(i);
VerificationField verificationField = ModelUtil.createVerificationField(feature.getName()).setPrecision(precision).setZeroThreshold(zeroThreshold);
data.put(verificationField, getVectorColumn(transformedDataset, probabilityColumn, i));
}
}
model.setModelVerification(ModelUtil.createModelVerification(data));
}
return pmml;
}
Usage of org.dmg.pmml.DerivedField in the jpmml-sparkml project (by jpmml).
Class PCAModelConverter, method encodeFeatures.
/**
 * Projects the input features onto the principal components of the fitted PCA model.
 * <p>
 * For each of the model's {@code k} components, a continuous DOUBLE derived field is created
 * whose expression is the sum over all inputs of {@code x_j * pc(j, component)}. Terms whose
 * coefficient is exactly one are emitted as the bare input reference.
 *
 * @param encoder the encoder that supplies the input features and receives the new derived fields
 * @return one continuous feature per principal component, in component order
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
	PCAModel pcaModel = getTransformer();
	DenseMatrix components = pcaModel.pc();
	List<Feature> inputFeatures = encoder.getFeatures(pcaModel.getInputCol());

	// The principal-component matrix must have one row per input feature
	MatrixUtil.checkRows(inputFeatures.size(), components);

	int numComponents = pcaModel.getK();
	List<Feature> projections = new ArrayList<>();

	for (int component = 0; component < numComponents; component++) {
		Apply weightedSum = PMMLUtil.createApply(PMMLFunctions.SUM);

		for (int row = 0; row < inputFeatures.size(); row++) {
			Expression term = inputFeatures.get(row).toContinuousFeature().ref();
			Double weight = components.apply(row, component);

			// Skip the redundant multiplication for unit weights
			if (!ValueUtil.isOne(weight)) {
				term = PMMLUtil.createApply(PMMLFunctions.MULTIPLY, term, PMMLUtil.createConstant(weight));
			}

			weightedSum.addExpressions(term);
		}

		DerivedField projectionField = encoder.createDerivedField(formatName(pcaModel, component, numComponents), OpType.CONTINUOUS, DataType.DOUBLE, weightedSum);
		projections.add(new ContinuousFeature(encoder, projectionField));
	}

	return projections;
}
Usage of org.dmg.pmml.DerivedField in the jpmml-sparkml project (by jpmml).
Class StringIndexerModelConverter, method encodeFeatures.
/**
 * Encodes every input column of the fitted {@code StringIndexerModel} as a categorical feature.
 * <p>
 * The categories of input column {@code i} are {@code labelsArray[i]}. When the model's
 * {@code handleInvalid} parameter is {@code "keep"}, an extra catch-all category is appended:
 * {@code "-999"} for INTEGER/FLOAT/DOUBLE inputs and {@code "__unknown"} for every other data type.
 * <ul>
 *   <li>If the input is backed by a {@code DataField}, invalid-value handling is attached as a
 *       decorator: {@code AS_VALUE} (the catch-all category) for "keep", {@code RETURN_INVALID} for "error".</li>
 *   <li>If the input is backed by a {@code DerivedField} and the strategy is "keep", a new derived field
 *       maps every value outside the known labels to the catch-all category. For "error" nothing
 *       is added, on the assumption that a derived field never produces an erroneous value.</li>
 * </ul>
 *
 * @param encoder the encoder that supplies the input features and receives new fields/decorators
 * @return one categorical feature per input column, in input-column order
 * @throws IllegalArgumentException if the {@code handleInvalid} strategy is neither "keep" nor "error",
 *         or if the input feature is backed by neither a {@code DataField} nor a {@code DerivedField}
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
	StringIndexerModel transformer = getTransformer();
	String[][] labelsArray = transformer.labelsArray();
	InOutMode inputMode = getInputMode();

	List<Feature> result = new ArrayList<>();

	String[] inputCols = inputMode.getInputCols(transformer);
	for (int i = 0; i < inputCols.length; i++) {
		String inputCol = inputCols[i];
		String[] labels = labelsArray[i];

		Feature feature = encoder.getOnlyFeature(inputCol);

		// Mutable copy: the "keep" strategy appends the catch-all category below
		List<String> categories = new ArrayList<>(Arrays.asList(labels));

		// The catch-all category must be representable in the input's data type
		String invalidCategory;
		DataType dataType = feature.getDataType();
		switch(dataType) {
			case INTEGER:
			case FLOAT:
			case DOUBLE:
				invalidCategory = "-999";
				break;
			default:
				invalidCategory = "__unknown";
				break;
		}

		String handleInvalid = transformer.getHandleInvalid();

		Field<?> field = encoder.toCategorical(feature.getName(), categories);

		if (field instanceof DataField) {
			DataField dataField = (DataField) field;

			InvalidValueDecorator invalidValueDecorator;
			switch(handleInvalid) {
				case "keep":
					{
						invalidValueDecorator = new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_VALUE, invalidCategory);
						categories.add(invalidCategory);
					}
					break;
				case "error":
					{
						invalidValueDecorator = new InvalidValueDecorator(InvalidValueTreatmentMethod.RETURN_INVALID, null);
					}
					break;
				default:
					throw new IllegalArgumentException("Invalid value handling strategy " + handleInvalid + " is not supported");
			}

			encoder.addDecorator(dataField, invalidValueDecorator);
		} else if (field instanceof DerivedField) {
			switch(handleInvalid) {
				case "keep":
					{
						// isIn(feature, label_1, ..., label_n) -> feature, otherwise the catch-all category
						Apply setApply = PMMLUtil.createApply(PMMLFunctions.ISIN, feature.ref());
						for (String category : categories) {
							setApply.addExpressions(PMMLUtil.createConstant(category, dataType));
						}

						categories.add(invalidCategory);

						Apply apply = PMMLUtil.createApply(PMMLFunctions.IF).addExpressions(setApply).addExpressions(feature.ref(), PMMLUtil.createConstant(invalidCategory, dataType));

						field = encoder.createDerivedField(FieldNameUtil.create("handleInvalid", feature), OpType.CATEGORICAL, dataType, apply);
					}
					break;
				case "error":
					{
						// Ignored: Assume that a DerivedField element can never return an erroneous field value
					}
					break;
				default:
					// Same message as the DataField branch, so callers see one consistent error for this condition
					throw new IllegalArgumentException("Invalid value handling strategy " + handleInvalid + " is not supported");
			}
		} else {
			// Report what was actually returned, instead of throwing a message-less exception
			throw new IllegalArgumentException("Expected a " + DataField.class.getName() + " or " + DerivedField.class.getName() + ", got " + (field != null ? ("class " + (field.getClass()).getName()) : null));
		}

		result.add(new CategoricalFeature(encoder, field, categories));
	}

	return result;
}
Usage of org.dmg.pmml.DerivedField in the jpmml-sparkml project (by jpmml).
Class MaxAbsScalerModelConverter, method encodeFeatures.
/**
 * Rescales each input feature by dividing it by its maximum absolute value.
 * <p>
 * A maximum of zero is replaced by one. Any feature whose effective divisor is one is passed
 * through unchanged. Every other feature is replaced by a continuous DOUBLE derived field
 * holding {@code x / maxAbs}.
 *
 * @param encoder the encoder that supplies the input features and receives the new derived fields
 * @return the scaled (or passed-through) features, in input order
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
	MaxAbsScalerModel scalerModel = getTransformer();
	Vector maxAbsValues = scalerModel.maxAbs();
	List<Feature> inputFeatures = encoder.getFeatures(scalerModel.getInputCol());

	// One maximum per input feature
	SchemaUtil.checkSize(maxAbsValues.size(), inputFeatures);

	int count = inputFeatures.size();
	List<Feature> scaled = new ArrayList<>();

	for (int index = 0; index < count; index++) {
		Feature input = inputFeatures.get(index);

		double rawMaxAbs = maxAbsValues.apply(index);
		// A zero maximum (e.g. an all-zero column) would divide by zero; treat it as one instead
		double divisor = (rawMaxAbs == 0d) ? 1d : rawMaxAbs;

		// Dividing by one is the identity, so keep the original feature
		if (ValueUtil.isOne(divisor)) {
			scaled.add(input);
			continue;
		}

		ContinuousFeature continuousInput = input.toContinuousFeature();
		Expression quotient = PMMLUtil.createApply(PMMLFunctions.DIVIDE, continuousInput.ref(), PMMLUtil.createConstant(divisor));
		DerivedField scaledField = encoder.createDerivedField(formatName(scalerModel, index, count), OpType.CONTINUOUS, DataType.DOUBLE, quotient);

		scaled.add(new ContinuousFeature(encoder, scaledField));
	}

	return scaled;
}
Aggregations