Search in sources:

Example 1 with VirtualEvaluationContext

Use of org.jpmml.evaluator.VirtualEvaluationContext in project jpmml-sparkml by jpmml.

From class ExpressionTranslatorTest, method checkValue:

/**
 * Checks a Spark SQL expression against an expected value twice: first by evaluating the
 * Catalyst expression directly on an empty row, then by translating it to PMML and
 * evaluating the translation with the JPMML evaluator.
 *
 * <p>For the Spark side only, {@code String}, {@code Integer}, {@code Float} and
 * {@code Double} expectations coerce the raw result (via {@code toString()} or the matching
 * {@link Number} accessor) before comparing. The PMML side is compared as returned.
 *
 * @param expectedValue the value both evaluations must produce
 * @param sqlExpression SQL expression text, embedded as {@code SELECT <expr> FROM __THIS__}
 */
public static void checkValue(Object expectedValue, String sqlExpression) {
    SparkMLEncoder encoder =
            new SparkMLEncoder(ExpressionTranslatorTest.schema, new ConverterFactory(Collections.emptyMap()));

    Expression sparkExpression = translateInternal("SELECT " + sqlExpression + " FROM __THIS__");

    // Spark-side check: normalize the raw result to the expected value's type first.
    Object rawSparkValue = sparkExpression.eval(InternalRow.empty());
    Object comparableSparkValue;
    if (expectedValue instanceof String) {
        comparableSparkValue = rawSparkValue.toString();
    } else if (expectedValue instanceof Integer) {
        comparableSparkValue = ((Number) rawSparkValue).intValue();
    } else if (expectedValue instanceof Float) {
        comparableSparkValue = ((Number) rawSparkValue).floatValue();
    } else if (expectedValue instanceof Double) {
        comparableSparkValue = ((Number) rawSparkValue).doubleValue();
    } else {
        comparableSparkValue = rawSparkValue;
    }
    assertEquals(expectedValue, comparableSparkValue);

    // PMML-side check: translate, strip any alias wrapper, then encode the PMML document
    // so derived fields created during translation can be resolved below.
    org.dmg.pmml.Expression pmmlExpression =
            AliasExpression.unwrap(ExpressionTranslator.translate(encoder, sparkExpression));
    final PMML pmml = encoder.encodePMML();

    EvaluationContext context = new VirtualEvaluationContext() {

        @Override
        public FieldValue resolve(FieldName name) {
            TransformationDictionary dictionary = pmml.getTransformationDictionary();
            if (dictionary == null || !dictionary.hasDerivedFields()) {
                return super.resolve(name);
            }

            // A derived field with a matching name is evaluated on demand in this context.
            for (DerivedField candidate : dictionary.getDerivedFields()) {
                if (Objects.equals(candidate.getName(), name)) {
                    return ExpressionUtil.evaluate(candidate, this);
                }
            }
            return super.resolve(name);
        }
    };
    context.declareAll(Collections.emptyMap());

    Object pmmlValue = FieldValueUtil.getValue(ExpressionUtil.evaluate(pmmlExpression, context));
    assertEquals(expectedValue, pmmlValue);
}
Also used: TransformationDictionary (org.dmg.pmml.TransformationDictionary), Expression (org.apache.spark.sql.catalyst.expressions.Expression), PMML (org.dmg.pmml.PMML), EvaluationContext (org.jpmml.evaluator.EvaluationContext), VirtualEvaluationContext (org.jpmml.evaluator.VirtualEvaluationContext), FieldValue (org.jpmml.evaluator.FieldValue), FieldName (org.dmg.pmml.FieldName), DerivedField (org.dmg.pmml.DerivedField)

Aggregations

Expression (org.apache.spark.sql.catalyst.expressions.Expression): 1, DerivedField (org.dmg.pmml.DerivedField): 1, FieldName (org.dmg.pmml.FieldName): 1, PMML (org.dmg.pmml.PMML): 1, TransformationDictionary (org.dmg.pmml.TransformationDictionary): 1, EvaluationContext (org.jpmml.evaluator.EvaluationContext): 1, FieldValue (org.jpmml.evaluator.FieldValue): 1, VirtualEvaluationContext (org.jpmml.evaluator.VirtualEvaluationContext): 1