Usage of org.jpmml.evaluator.EvaluationContext in the jpmml-sparkml project (by jpmml).
Class ExpressionTranslatorTest, method checkValue:
/**
 * Asserts that a SQL expression evaluates to {@code expectedValue} both in Spark and
 * after translation to PMML.
 *
 * <p>The expression is wrapped as {@code SELECT <sqlExpression> FROM __THIS__} and passed
 * to {@code translateInternal}. The resulting Spark expression is evaluated against an
 * empty row. Its result is then normalized to the type of {@code expectedValue}
 * ({@code String}, {@code Integer}, {@code Float} or {@code Double}) before comparison.
 * Any other expected type is compared against the raw result.
 *
 * <p>The same Spark expression is then translated to PMML and its top-level alias is
 * unwrapped. The PMML expression is evaluated in a context that resolves derived fields
 * from the encoded PMML's transformation dictionary on demand. The PMML result is compared
 * against {@code expectedValue} as-is, with no type normalization.
 *
 * @param expectedValue the value both evaluations must produce
 * @param sqlExpression a SQL expression fragment; presumably it should not reference input
 *                      columns, since it is evaluated against an empty row (TODO confirm)
 */
public static void checkValue(Object expectedValue, String sqlExpression) {
	SparkMLEncoder encoder = new SparkMLEncoder(ExpressionTranslatorTest.schema, new ConverterFactory(Collections.emptyMap()));

	Expression sparkExpression = translateInternal("SELECT " + sqlExpression + " FROM __THIS__");

	// --- Spark side ---
	Object rawSparkValue = sparkExpression.eval(InternalRow.empty());

	// Coerce the raw result to the expected value's type, so that a result of any
	// Number subtype (or a non-java.lang.String string representation) can be compared.
	Object normalizedSparkValue;
	if(expectedValue instanceof String){
		normalizedSparkValue = rawSparkValue.toString();
	} else

	if(expectedValue instanceof Integer){
		normalizedSparkValue = ((Number)rawSparkValue).intValue();
	} else

	if(expectedValue instanceof Float){
		normalizedSparkValue = ((Number)rawSparkValue).floatValue();
	} else

	if(expectedValue instanceof Double){
		normalizedSparkValue = ((Number)rawSparkValue).doubleValue();
	} else

	{
		normalizedSparkValue = rawSparkValue;
	}

	assertEquals(expectedValue, normalizedSparkValue);

	// --- PMML side ---
	org.dmg.pmml.Expression translatedExpression = ExpressionTranslator.translate(encoder, sparkExpression);
	org.dmg.pmml.Expression unwrappedExpression = AliasExpression.unwrap(translatedExpression);

	// Encode only after translation, so that the PMML reflects whatever the translation
	// registered with the encoder.
	final PMML encodedPmml = encoder.encodePMML();

	EvaluationContext evaluationContext = new VirtualEvaluationContext(){

		/**
		 * Resolves a field by first looking for a derived field of that name in the
		 * encoded PMML's transformation dictionary and evaluating it in this context.
		 * Any other name is delegated to the superclass.
		 */
		@Override
		public FieldValue resolve(FieldName name){
			TransformationDictionary dictionary = encodedPmml.getTransformationDictionary();

			if(dictionary == null || !dictionary.hasDerivedFields()){
				return super.resolve(name);
			}

			for(DerivedField candidate : dictionary.getDerivedFields()){

				if(Objects.equals(candidate.getName(), name)){
					return ExpressionUtil.evaluate(candidate, this);
				}
			}

			return super.resolve(name);
		}
	};
	evaluationContext.declareAll(Collections.emptyMap());

	Object pmmlValue = FieldValueUtil.getValue(ExpressionUtil.evaluate(unwrappedExpression, evaluationContext));

	assertEquals(expectedValue, pmmlValue);
}
Aggregations