use of com.facebook.presto.spi.relation.LambdaDefinitionExpression in project presto by prestodb.
the class TestTranslateExpressions method testTranslateIntermediateAggregationWithLambda.
@Test
public void testTranslateIntermediateAggregationWithLambda() {
PlanNode result = tester().assertThat(new TranslateExpressions(METADATA, new SqlParser()).aggregationRowExpressionRewriteRule()).on(p -> p.aggregation(builder -> builder.globalGrouping().addAggregation(variable("reduce_agg", INTEGER), new AggregationNode.Aggregation(new CallExpression("reduce_agg", REDUCE_AGG, INTEGER, ImmutableList.of(castToRowExpression(expression("input")), castToRowExpression(expression("(x,y) -> x*y")), castToRowExpression(expression("(a,b) -> a*b")))), Optional.of(castToRowExpression(expression("input > 10"))), Optional.empty(), false, Optional.empty())).source(p.values(p.variable("input", INTEGER))))).get();
AggregationNode.Aggregation translated = ((AggregationNode) result).getAggregations().get(variable("reduce_agg", INTEGER));
assertEquals(translated, new AggregationNode.Aggregation(new CallExpression("reduce_agg", REDUCE_AGG, INTEGER, ImmutableList.of(variable("input", INTEGER), new LambdaDefinitionExpression(Optional.empty(), ImmutableList.of(INTEGER, INTEGER), ImmutableList.of("x", "y"), multiply(variable("x", INTEGER), variable("y", INTEGER))), new LambdaDefinitionExpression(Optional.empty(), ImmutableList.of(INTEGER, INTEGER), ImmutableList.of("a", "b"), multiply(variable("a", INTEGER), variable("b", INTEGER))))), Optional.of(greaterThan(variable("input", INTEGER), constant(10L, INTEGER))), Optional.empty(), false, Optional.empty()));
assertFalse(isUntranslated(translated));
}
use of com.facebook.presto.spi.relation.LambdaDefinitionExpression in project presto by prestodb.
the class TestTranslateExpressions method testTranslateAggregationWithLambda.
@Test
public void testTranslateAggregationWithLambda() {
PlanNode result = tester().assertThat(new TranslateExpressions(METADATA, new SqlParser()).aggregationRowExpressionRewriteRule()).on(p -> p.aggregation(builder -> builder.globalGrouping().addAggregation(variable("reduce_agg", INTEGER), new AggregationNode.Aggregation(new CallExpression("reduce_agg", REDUCE_AGG, INTEGER, ImmutableList.of(castToRowExpression(expression("input")), castToRowExpression(expression("0")), castToRowExpression(expression("(x,y) -> x*y")), castToRowExpression(expression("(a,b) -> a*b")))), Optional.of(castToRowExpression(expression("input > 10"))), Optional.empty(), false, Optional.empty())).source(p.values(p.variable("input", INTEGER))))).get();
// TODO migrate this to RowExpressionMatcher
AggregationNode.Aggregation translated = ((AggregationNode) result).getAggregations().get(variable("reduce_agg", INTEGER));
assertEquals(translated, new AggregationNode.Aggregation(new CallExpression("reduce_agg", REDUCE_AGG, INTEGER, ImmutableList.of(variable("input", INTEGER), constant(0L, INTEGER), new LambdaDefinitionExpression(Optional.empty(), ImmutableList.of(INTEGER, INTEGER), ImmutableList.of("x", "y"), multiply(variable("x", INTEGER), variable("y", INTEGER))), new LambdaDefinitionExpression(Optional.empty(), ImmutableList.of(INTEGER, INTEGER), ImmutableList.of("a", "b"), multiply(variable("a", INTEGER), variable("b", INTEGER))))), Optional.of(greaterThan(variable("input", INTEGER), constant(10L, INTEGER))), Optional.empty(), false, Optional.empty()));
assertFalse(isUntranslated(translated));
}
use of com.facebook.presto.spi.relation.LambdaDefinitionExpression in project presto by prestodb.
the class CursorProcessorCompiler method fieldReferenceCompiler.
static RowExpressionVisitor<BytecodeNode, Scope> fieldReferenceCompiler(Map<VariableReferenceExpression, CommonSubExpressionFields> variableMap) {
return new RowExpressionVisitor<BytecodeNode, Scope>() {
@Override
public BytecodeNode visitInputReference(InputReferenceExpression node, Scope scope) {
int field = node.getField();
Type type = node.getType();
Variable wasNullVariable = scope.getVariable("wasNull");
Variable cursorVariable = scope.getVariable("cursor");
Class<?> javaType = type.getJavaType();
if (!javaType.isPrimitive() && javaType != Slice.class) {
javaType = Object.class;
}
IfStatement ifStatement = new IfStatement();
ifStatement.condition().setDescription(format("cursor.get%s(%d)", type, field)).getVariable(cursorVariable).push(field).invokeInterface(RecordCursor.class, "isNull", boolean.class, int.class);
ifStatement.ifTrue().putVariable(wasNullVariable, true).pushJavaDefault(javaType);
ifStatement.ifFalse().getVariable(cursorVariable).push(field).invokeInterface(RecordCursor.class, "get" + Primitives.wrap(javaType).getSimpleName(), javaType, int.class);
return ifStatement;
}
@Override
public BytecodeNode visitCall(CallExpression call, Scope scope) {
throw new UnsupportedOperationException("not yet implemented");
}
@Override
public BytecodeNode visitConstant(ConstantExpression literal, Scope scope) {
throw new UnsupportedOperationException("not yet implemented");
}
@Override
public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Scope context) {
throw new UnsupportedOperationException();
}
@Override
public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Scope context) {
CommonSubExpressionFields fields = variableMap.get(reference);
return new BytecodeBlock().append(context.getThis().invoke(fields.getMethodName(), fields.getResultType(), context.getVariable("properties"), context.getVariable("cursor"))).append(unboxPrimitiveIfNecessary(context, Primitives.wrap(reference.getType().getJavaType())));
}
@Override
public BytecodeNode visitSpecialForm(SpecialFormExpression specialForm, Scope context) {
throw new UnsupportedOperationException();
}
};
}
use of com.facebook.presto.spi.relation.LambdaDefinitionExpression in project presto by prestodb.
the class TestSubExpressions method testExtract.
@Test
void testExtract() {
RowExpression a = variable("a", BIGINT);
RowExpression b = constant(1L, BIGINT);
RowExpression c = call(ADD, a, b);
RowExpression d = new LambdaDefinitionExpression(Optional.empty(), ImmutableList.of(BIGINT), ImmutableList.of("a"), c);
RowExpression e = constant(1L, BIGINT);
RowExpression f = specialForm(BIND, new FunctionType(ImmutableList.of(BIGINT), BIGINT), e, d);
assertEquals(subExpressions(a), ImmutableList.of(a));
assertEquals(subExpressions(b), ImmutableList.of(b));
assertEquals(subExpressions(c), ImmutableList.of(c, a, b));
assertEquals(subExpressions(d), ImmutableList.of(d, c, a, b));
assertEquals(subExpressions(f), ImmutableList.of(f, e, d, c, a, b));
assertEqualsIgnoreOrder(uniqueSubExpressions(f), ImmutableSet.of(a, b, c, d, f));
}
use of com.facebook.presto.spi.relation.LambdaDefinitionExpression in project presto by prestodb.
the class PageFunctionCompiler method defineFilterClass.
private ClassDefinition defineFilterClass(SqlFunctionProperties sqlFunctionProperties, Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions, RowExpression filter, InputChannels inputChannels, CallSiteBinder callSiteBinder, boolean isOptimizeCommonSubExpression, Optional<String> classNameSuffix) {
ClassDefinition classDefinition = new ClassDefinition(a(PUBLIC, FINAL), generateFilterClassName(classNameSuffix), type(Object.class), type(PageFilter.class));
CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);
Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap = generateMethodsForLambda(classDefinition, callSiteBinder, cachedInstanceBinder, filter, metadata, sqlFunctionProperties, sessionFunctions);
// cse
Map<VariableReferenceExpression, CommonSubExpressionFields> cseFields = ImmutableMap.of();
RowExpressionCompiler compiler = new RowExpressionCompiler(classDefinition, callSiteBinder, cachedInstanceBinder, new FieldAndVariableReferenceCompiler(callSiteBinder, cseFields), metadata, sqlFunctionProperties, sessionFunctions, compiledLambdaMap);
if (isOptimizeCommonSubExpression) {
Map<Integer, Map<RowExpression, VariableReferenceExpression>> commonSubExpressionsByLevel = collectCSEByLevel(filter);
if (!commonSubExpressionsByLevel.isEmpty()) {
cseFields = declareCommonSubExpressionFields(classDefinition, commonSubExpressionsByLevel);
compiler = new RowExpressionCompiler(classDefinition, callSiteBinder, cachedInstanceBinder, new FieldAndVariableReferenceCompiler(callSiteBinder, cseFields), metadata, sqlFunctionProperties, sessionFunctions, compiledLambdaMap);
generateCommonSubExpressionMethods(classDefinition, compiler, commonSubExpressionsByLevel, cseFields);
Map<RowExpression, VariableReferenceExpression> commonSubExpressions = commonSubExpressionsByLevel.values().stream().flatMap(m -> m.entrySet().stream()).collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
filter = rewriteExpressionWithCSE(filter, commonSubExpressions);
if (log.isDebugEnabled()) {
log.debug("Extracted %d common sub-expressions", commonSubExpressions.size());
commonSubExpressions.entrySet().forEach(entry -> log.debug("\t%s = %s", entry.getValue(), entry.getKey()));
log.debug("Rewrote filter: %s", filter);
}
}
}
generateFilterMethod(classDefinition, compiler, filter, cseFields);
FieldDefinition selectedPositions = classDefinition.declareField(a(PRIVATE), "selectedPositions", boolean[].class);
generatePageFilterMethod(classDefinition, selectedPositions);
// isDeterministic
classDefinition.declareMethod(a(PUBLIC), "isDeterministic", type(boolean.class)).getBody().append(constantBoolean(determinismEvaluator.isDeterministic(filter))).retBoolean();
// getInputChannels
classDefinition.declareMethod(a(PUBLIC), "getInputChannels", type(InputChannels.class)).getBody().append(invoke(callSiteBinder.bind(inputChannels, InputChannels.class), "getInputChannels")).retObject();
// toString
String toStringResult = toStringHelper(classDefinition.getType().getJavaClassName()).add("filter", filter).toString();
classDefinition.declareMethod(a(PUBLIC), "toString", type(String.class)).getBody().append(invoke(callSiteBinder.bind(toStringResult, String.class), "toString")).retObject();
// constructor
MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC));
BytecodeBlock body = constructorDefinition.getBody();
Variable thisVariable = constructorDefinition.getThis();
body.comment("super();").append(thisVariable).invokeConstructor(Object.class).append(thisVariable.setField(selectedPositions, newArray(type(boolean[].class), 0)));
initializeCommonSubExpressionFields(cseFields.values(), thisVariable, body);
cachedInstanceBinder.generateInitializations(thisVariable, body);
body.ret();
return classDefinition;
}
Aggregations