use of io.prestosql.sql.tree.ArithmeticBinaryExpression in project hetu-core by openlookeng.
the class TestDesugarTryExpressionRewriter method testTryExpressionDesugaringRewriter.
/**
 * Checks that {@code DesugarTryExpressionRewriter.rewrite} replaces a nested
 * {@code TryExpression} with a call to the try function whose single argument
 * is a zero-parameter lambda wrapping the original operand.
 */
@Test
public void testTryExpressionDesugaringRewriter() {
    // Input tree: 1 + try(2)
    Expression input = new ArithmeticBinaryExpression(
            ADD,
            new DecimalLiteral("1"),
            new TryExpression(new DecimalLiteral("2")));

    // Expected tree: 1 + try_function(() -> 2)
    Expression expected = new ArithmeticBinaryExpression(
            ADD,
            new DecimalLiteral("1"),
            new FunctionCallBuilder(tester().getMetadata())
                    .setName(QualifiedName.of(TryFunction.NAME))
                    .addArgument(
                            new FunctionType(ImmutableList.of(), createDecimalType(1)),
                            new LambdaExpression(ImmutableList.of(), new DecimalLiteral("2")))
                    .build());

    assertEquals(
            DesugarTryExpressionRewriter.rewrite(
                    input,
                    tester().getMetadata(),
                    tester().getTypeAnalyzer(),
                    tester().getSession(),
                    new PlanSymbolAllocator()),
            expected);
}
use of io.prestosql.sql.tree.ArithmeticBinaryExpression in project hetu-core by openlookeng.
the class TestSqlParser method testPrecedenceAndAssociativity.
/**
 * Verifies how the parser groups operators: AND binds tighter than OR, NOT binds
 * tighter than AND/OR, subtraction and division group to the left, and
 * multiplication binds tighter than addition.
 */
@Test
public void testPrecedenceAndAssociativity() {
    LogicalBinaryExpression.Operator and = LogicalBinaryExpression.Operator.AND;
    LogicalBinaryExpression.Operator or = LogicalBinaryExpression.Operator.OR;
    ArithmeticBinaryExpression.Operator add = ArithmeticBinaryExpression.Operator.ADD;
    ArithmeticBinaryExpression.Operator subtract = ArithmeticBinaryExpression.Operator.SUBTRACT;
    ArithmeticBinaryExpression.Operator multiply = ArithmeticBinaryExpression.Operator.MULTIPLY;

    // AND is grouped before OR, regardless of which one comes first in the text
    assertExpression(
            "1 AND 2 OR 3",
            new LogicalBinaryExpression(
                    or,
                    new LogicalBinaryExpression(and, new LongLiteral("1"), new LongLiteral("2")),
                    new LongLiteral("3")));
    assertExpression(
            "1 OR 2 AND 3",
            new LogicalBinaryExpression(
                    or,
                    new LongLiteral("1"),
                    new LogicalBinaryExpression(and, new LongLiteral("2"), new LongLiteral("3"))));

    // NOT applies to its immediate operand only
    assertExpression(
            "NOT 1 AND 2",
            new LogicalBinaryExpression(and, new NotExpression(new LongLiteral("1")), new LongLiteral("2")));
    assertExpression(
            "NOT 1 OR 2",
            new LogicalBinaryExpression(or, new NotExpression(new LongLiteral("1")), new LongLiteral("2")));

    // A leading minus on a number is expected to parse as the literal "-1"
    assertExpression(
            "-1 + 2",
            new ArithmeticBinaryExpression(add, new LongLiteral("-1"), new LongLiteral("2")));

    // Subtraction and division are left-associative
    assertExpression(
            "1 - 2 - 3",
            new ArithmeticBinaryExpression(
                    subtract,
                    new ArithmeticBinaryExpression(subtract, new LongLiteral("1"), new LongLiteral("2")),
                    new LongLiteral("3")));
    assertExpression(
            "1 / 2 / 3",
            new ArithmeticBinaryExpression(
                    DIVIDE,
                    new ArithmeticBinaryExpression(DIVIDE, new LongLiteral("1"), new LongLiteral("2")),
                    new LongLiteral("3")));

    // Multiplication is grouped before addition
    assertExpression(
            "1 + 2 * 3",
            new ArithmeticBinaryExpression(
                    add,
                    new LongLiteral("1"),
                    new ArithmeticBinaryExpression(multiply, new LongLiteral("2"), new LongLiteral("3"))));
}
use of io.prestosql.sql.tree.ArithmeticBinaryExpression in project hetu-core by openlookeng.
the class AggregationRewriteWithCube method rewrite.
/**
 * Builds a replacement plan for an aggregation so that it reads from the cube table
 * named by {@code cubeMetadata.getCubeName()}.
 *
 * <p>The resulting plan is assembled bottom-up:
 * <ol>
 *   <li>a table scan produced by {@code createScanNode};</li>
 *   <li>a {@code FilterNode} when {@code filterNode} is non-null, with its predicate
 *       rewritten through {@code rewrittenMappings};</li>
 *   <li>when the query's grouping columns do not exactly match the cube's groups, an
 *       {@code AggregationNode} over the scan, followed by a {@code ProjectNode} if any
 *       AVG aggregations are present;</li>
 *   <li>a final {@code ProjectNode} when the output symbols still differ from those of
 *       {@code originalAggregationNode}.</li>
 * </ol>
 *
 * @param originalAggregationNode the aggregation being replaced; its grouping keys and
 *        output symbols drive the rewrite
 * @param filterNode the filter to re-apply on top of the cube scan, or {@code null} for
 *        none; when non-null it is cast to {@code FilterNode}
 * @return the root of the rewritten plan
 * @throws CubeNotFoundException if {@code metadata} returns no table handle for the cube name
 * @throws ColumnNotFoundException if {@code cubeMetadata} has no aggregation signature for a
 *         cube column used by an aggregation
 */
public PlanNode rewrite(AggregationNode originalAggregationNode, PlanNode filterNode) {
// Resolve the cube table and load its column handles and column metadata.
QualifiedObjectName starTreeTableName = QualifiedObjectName.valueOf(cubeMetadata.getCubeName());
TableHandle cubeTableHandle = metadata.getTableHandle(session, starTreeTableName).orElseThrow(() -> new CubeNotFoundException(starTreeTableName.toString()));
Map<String, ColumnHandle> cubeColumnsMap = metadata.getColumnHandles(session, cubeTableHandle);
TableMetadata cubeTableMetadata = metadata.getTableMetadata(session, cubeTableHandle);
List<ColumnMetadata> cubeColumnMetadataList = cubeTableMetadata.getColumns();
// Collect the group-by columns: each grouping key is looked up by name in symbolMappings,
// and only entries that are ColumnHandles contribute a symbol (named after the column).
List<Symbol> groupings = new ArrayList<>(originalAggregationNode.getGroupingKeys().size());
for (Symbol symbol : originalAggregationNode.getGroupingKeys()) {
Object column = symbolMappings.get(symbol.getName());
if (column instanceof ColumnHandle) {
groupings.add(new Symbol(((ColumnHandle) column).getColumnName()));
}
}
// The groups match exactly only when the counts are equal and every lower-cased grouping
// name is contained in the cube's group set.
Set<String> cubeGroups = cubeMetadata.getGroup();
boolean exactGroupsMatch = false;
if (groupings.size() == cubeGroups.size()) {
exactGroupsMatch = groupings.stream().map(Symbol::getName).map(String::toLowerCase).allMatch(cubeGroups::contains);
}
CubeRewriteResult cubeRewriteResult = createScanNode(originalAggregationNode, filterNode, cubeTableHandle, cubeColumnsMap, cubeColumnMetadataList, exactGroupsMatch);
PlanNode planNode = cubeRewriteResult.getTableScanNode();
// Re-apply the original filter on top of the cube scan, with its predicate rewritten
// through rewrittenMappings.
if (filterNode != null) {
Expression expression = castToExpression(((FilterNode) filterNode).getPredicate());
expression = rewriteExpression(expression, rewrittenMappings);
planNode = new FilterNode(idAllocator.getNextId(), planNode, castToRowExpression(expression));
}
// When the groups do not match exactly, the cube rows are aggregated again below.
if (!exactGroupsMatch) {
// Remembers, for each cube scan symbol, the original aggregation symbol it feeds.
Map<Symbol, Symbol> cubeScanToAggOutputMap = new HashMap<>();
// Rewrite AggregationNode using Cube table
ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
ColumnHandle cubeColHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(aggregatorSource.getScanSymbol());
ColumnMetadata cubeColumnMetadata = cubeRewriteResult.getSymbolMetadataMap().get(aggregatorSource.getScanSymbol());
Type type = cubeColumnMetadata.getType();
AggregationSignature aggregationSignature = cubeMetadata.getAggregationSignature(cubeColumnMetadata.getName()).orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName(starTreeTableName.getSchemaName(), starTreeTableName.getObjectName()), cubeColHandle.getColumnName()));
// A cube column whose signature is COUNT is re-aggregated with "sum"; any other
// function is re-applied under its own name.
String aggFunction = COUNT.getName().equals(aggregationSignature.getFunction()) ? "sum" : aggregationSignature.getFunction();
SymbolReference argument = toSymbolReference(aggregatorSource.getScanSymbol());
FunctionHandle functionHandle = metadata.getFunctionAndTypeManager().lookupFunction(aggFunction, TypeSignatureProvider.fromTypeSignatures(type.getTypeSignature()));
cubeScanToAggOutputMap.put(aggregatorSource.getScanSymbol(), aggregatorSource.getOriginalAggSymbol());
aggregationsBuilder.put(aggregatorSource.getOriginalAggSymbol(), new AggregationNode.Aggregation(new CallExpression(aggFunction, functionHandle, type, ImmutableList.of(OriginalExpressionUtils.castToRowExpression(argument))), ImmutableList.of(OriginalExpressionUtils.castToRowExpression(argument)), false, Optional.empty(), Optional.empty(), Optional.empty()));
}
// Grouping keys of the new aggregation are the original keys translated by name
// through rewrittenMappings.
List<Symbol> groupingKeys = originalAggregationNode.getGroupingKeys().stream().map(Symbol::getName).map(rewrittenMappings::get).collect(Collectors.toList());
planNode = new AggregationNode(idAllocator.getNextId(), planNode, aggregationsBuilder.build(), singleGroupingSet(groupingKeys), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
AggregationNode aggNode = (AggregationNode) planNode;
// AVG aggregations need an extra projection on top of the aggregation node.
if (!cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
if (!cubeRewriteResult.getComputeAvgDividingSumByCount()) {
// AVG is not derived from SUM/COUNT here: project each original aggregation symbol
// from a reference to its cube scan symbol.
// NOTE(review): this projection contains only the aggregation symbols (no grouping
// keys) and refers to scan symbols rather than aggNode's aggregation keys — confirm
// this is intended.
Map<Symbol, Expression> aggregateAssignments = new HashMap<>();
for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
aggregateAssignments.put(aggregatorSource.getOriginalAggSymbol(), toSymbolReference(aggregatorSource.getScanSymbol()));
}
planNode = new ProjectNode(idAllocator.getNextId(), aggNode, new Assignments(aggregateAssignments.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
} else {
// If there was an AVG aggregation, map it to AVG = SUM/COUNT
Map<Symbol, Expression> projections = new HashMap<>();
// Pass the grouping keys through unchanged.
aggNode.getGroupingKeys().forEach(symbol -> projections.put(symbol, toSymbolReference(symbol)));
// Pass through only those aggregation outputs whose name appears as a value in symbolMappings.
aggNode.getAggregations().keySet().stream().filter(symbol -> symbolMappings.containsValue(symbol.getName())).forEach(aggSymbol -> projections.put(aggSymbol, toSymbolReference(aggSymbol)));
// Add AVG = SUM / COUNT
for (CubeRewriteResult.AverageAggregatorSource avgAggSource : cubeRewriteResult.getAvgAggregationColumns()) {
// Translate the cube scan symbols for SUM and COUNT into the aggregation output symbols.
// NOTE(review): either lookup yields null if the symbol was not registered in the loop above — confirm callers guarantee it.
Symbol sumSymbol = cubeScanToAggOutputMap.get(avgAggSource.getSum());
Symbol countSymbol = cubeScanToAggOutputMap.get(avgAggSource.getCount());
Type avgResultType = typeProvider.get(avgAggSource.getOriginalAggSymbol());
// Both operands are cast to the AVG result type before dividing.
ArithmeticBinaryExpression division = new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, new Cast(toSymbolReference(sumSymbol), avgResultType.getTypeSignature().toString()), new Cast(toSymbolReference(countSymbol), avgResultType.getTypeSignature().toString()));
projections.put(avgAggSource.getOriginalAggSymbol(), division);
}
planNode = new ProjectNode(idAllocator.getNextId(), aggNode, new Assignments(projections.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
}
}
}
// Safety check to remove redundant symbols and rename original column names to intermediate names
if (!planNode.getOutputSymbols().equals(originalAggregationNode.getOutputSymbols())) {
// Map new symbol names to the old symbols
Map<Symbol, Expression> assignments = new HashMap<>();
Set<Symbol> planNodeOutput = new HashSet<>(planNode.getOutputSymbols());
for (Symbol originalAggOutputSymbol : originalAggregationNode.getOutputSymbols()) {
if (!planNodeOutput.contains(originalAggOutputSymbol)) {
// Must be grouping key: it is not produced by the new plan under its original name,
// so reference the symbol that rewrittenMappings holds for that name.
assignments.put(originalAggOutputSymbol, toSymbolReference(rewrittenMappings.get(originalAggOutputSymbol.getName())));
} else {
// Should be an expression and must have the same name in the new plan node
assignments.put(originalAggOutputSymbol, toSymbolReference(originalAggOutputSymbol));
}
}
planNode = new ProjectNode(idAllocator.getNextId(), planNode, new Assignments(assignments.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
}
return planNode;
}
use of io.prestosql.sql.tree.ArithmeticBinaryExpression in project hetu-core by openlookeng.
the class GroupingOperationRewriter method rewriteGroupingOperation.
/**
 * Rewrites a {@code GROUPING} operation into an expression that can be evaluated per row.
 *
 * <p>With a single grouping set the result is the literal {@code 0}. Otherwise one value is
 * computed per grouping set via {@code calculateGrouping}, the values are placed in an array
 * constructor, and the result is a subscript into that array indexed by
 * {@code groupIdSymbol + 1}.
 *
 * @param expression the grouping operation to rewrite
 * @param groupingSets the grouping sets of the query, as sets of integers
 * @param columnReferenceFields lookup from column-reference expressions to their field ids
 * @param groupIdSymbol the group-id symbol; must be present when there is more than one grouping set
 * @return the replacement expression
 * @throws NullPointerException if {@code groupIdSymbol} is null
 * @throws IllegalStateException if {@code groupIdSymbol} is empty while several grouping sets
 *         exist, or a grouping column is missing from {@code columnReferenceFields}
 */
public static Expression rewriteGroupingOperation(GroupingOperation expression, List<Set<Integer>> groupingSets, Map<NodeRef<Expression>, FieldId> columnReferenceFields, Optional<Symbol> groupIdSymbol) {
    requireNonNull(groupIdSymbol, "groupIdSymbol is null");

    // See SQL:2011:4.16.2 and SQL:2011:6.9.10.
    if (groupingSets.size() == 1) {
        return new LongLiteral("0");
    }

    checkState(groupIdSymbol.isPresent(), "groupId symbol is missing");

    // The relation of the first grouping column is the reference point for translating fields.
    RelationId relationId = columnReferenceFields.get(NodeRef.of(expression.getGroupingColumns().get(0))).getRelationId();

    // Translate every grouping column to an integer, validating each one just before it is looked up.
    List<Integer> columns = expression.getGroupingColumns().stream()
            .map(NodeRef::of)
            .map(columnRef -> {
                checkState(columnReferenceFields.containsKey(columnRef), "the grouping column is not in the columnReferencesField map");
                return translateFieldToInteger(columnReferenceFields.get(columnRef), relationId);
            })
            .collect(toImmutableList());

    // One precomputed grouping value per grouping set, in grouping-set order.
    List<Expression> groupingResults = groupingSets.stream()
            .map(groupingSet -> new LongLiteral(String.valueOf(calculateGrouping(groupingSet, columns))))
            .collect(toImmutableList());

    ArrayConstructor groupingValues = new ArrayConstructor(groupingResults);
    // It is necessary to add a 1 to the groupId because the underlying array is indexed starting at 1
    ArithmeticBinaryExpression oneBasedIndex = new ArithmeticBinaryExpression(ADD, toSymbolReference(groupIdSymbol.get()), new GenericLiteral("BIGINT", "1"));
    return new SubscriptExpression(groupingValues, oneBasedIndex);
}
use of io.prestosql.sql.tree.ArithmeticBinaryExpression in project hetu-core by openlookeng.
the class CubeOptimizer method rewriteAggregationNode.
/**
 * Builds the aggregation that sits on top of the cube scan, plus a projection when AVG
 * aggregations are involved.
 *
 * <p>Each aggregation column from {@code cubeRewriteResult} is re-aggregated over its cube
 * scan symbol: a cube column whose signature is COUNT is aggregated with SUM, any other
 * function is re-applied under its own name. If there are no AVG columns the aggregation
 * node is returned directly. Otherwise a {@code ProjectNode} is placed on top, which either
 * derives each AVG as {@code CAST(sum) / CAST(count)} or projects the aggregation symbols
 * from their scan symbols, depending on {@code getComputeAvgDividingSumByCount()}.
 *
 * @param cubeRewriteResult the cube scan description: aggregation columns, AVG columns,
 *        symbol metadata and the scan node
 * @param inputPlanNode the node the new aggregation reads from
 * @return the aggregation node, or the projection placed above it
 * @throws ColumnNotFoundException if {@code cubeMetadata} has no aggregation signature for a cube column
 */
private PlanNode rewriteAggregationNode(CubeRewriteResult cubeRewriteResult, PlanNode inputPlanNode) {
    TypeProvider typeProvider = context.getSymbolAllocator().getTypes();

    // Grouping keys: original key name -> rewritten column -> symbol in the optimized plan.
    List<Symbol> groupings = aggregationNode.getGroupingKeys().stream()
            .map(Symbol::getName)
            .map(columnRewritesMap::get)
            .map(optimizedPlanMappings::get)
            .collect(Collectors.toList());

    // Remembers, for each cube scan symbol, the original aggregation symbol it feeds.
    Map<Symbol, Symbol> scanToAggOutput = new HashMap<>();
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
    for (CubeRewriteResult.AggregatorSource source : cubeRewriteResult.getAggregationColumns()) {
        Type type = cubeRewriteResult.getSymbolMetadataMap().get(source.getOriginalAggSymbol()).getType();
        TypeSignature typeSignature = type.getTypeSignature();
        ColumnHandle cubeColHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(source.getScanSymbol());
        ColumnMetadata cubeColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColHandle);
        // NOTE(review): the error path builds a SchemaTableName from two empty strings — confirm that is accepted.
        AggregationSignature aggregationSignature = cubeMetadata.getAggregationSignature(cubeColumnMetadata.getName())
                .orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName("", ""), cubeColHandle.getColumnName()));

        // Pre-aggregated counts are combined by summing them.
        String aggFunction;
        if (COUNT.getName().equals(aggregationSignature.getFunction())) {
            aggFunction = SUM.getName();
        } else {
            aggFunction = aggregationSignature.getFunction();
        }

        SymbolReference argument = toSymbolReference(source.getScanSymbol());
        FunctionHandle functionHandle = metadata.getFunctionAndTypeManager()
                .lookupFunction(aggFunction, TypeSignatureProvider.fromTypeSignatures(typeSignature));
        scanToAggOutput.put(source.getScanSymbol(), source.getOriginalAggSymbol());
        aggregations.put(
                source.getOriginalAggSymbol(),
                new AggregationNode.Aggregation(
                        new CallExpression(aggFunction, functionHandle, type, ImmutableList.of(castToRowExpression(argument))),
                        ImmutableList.of(castToRowExpression(argument)),
                        false,
                        Optional.empty(),
                        Optional.empty(),
                        Optional.empty()));
    }

    AggregationNode aggNode = new AggregationNode(
            context.getIdAllocator().getNextId(),
            inputPlanNode,
            aggregations.build(),
            singleGroupingSet(groupings),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());

    // Without AVG columns no projection is needed.
    if (cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
        return aggNode;
    }

    if (cubeRewriteResult.getComputeAvgDividingSumByCount()) {
        // If there was an AVG aggregation, map it to AVG = SUM/COUNT
        Map<Symbol, Expression> projections = new HashMap<>();
        for (Symbol groupingKey : aggNode.getGroupingKeys()) {
            projections.put(groupingKey, toSymbolReference(groupingKey));
        }
        // Pass through only the aggregation outputs that appear as values in originalAggregationsMap.
        for (Symbol aggSymbol : aggNode.getAggregations().keySet()) {
            if (originalAggregationsMap.containsValue(aggSymbol)) {
                projections.put(aggSymbol, toSymbolReference(aggSymbol));
            }
        }
        // Add AVG = SUM / COUNT, with both operands cast to the AVG result type
        for (CubeRewriteResult.AverageAggregatorSource avgSource : cubeRewriteResult.getAvgAggregationColumns()) {
            Symbol sumSymbol = scanToAggOutput.get(avgSource.getSum());
            Symbol countSymbol = scanToAggOutput.get(avgSource.getCount());
            Type avgResultType = typeProvider.get(avgSource.getOriginalAggSymbol());
            String avgTypeName = avgResultType.getTypeSignature().toString();
            ArithmeticBinaryExpression division = new ArithmeticBinaryExpression(
                    ArithmeticBinaryExpression.Operator.DIVIDE,
                    new Cast(toSymbolReference(sumSymbol), avgTypeName),
                    new Cast(toSymbolReference(countSymbol), avgTypeName));
            projections.put(avgSource.getOriginalAggSymbol(), division);
        }
        return new ProjectNode(
                context.getIdAllocator().getNextId(),
                aggNode,
                new Assignments(projections.entrySet().stream()
                        .collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
    }

    // AVG is not derived from SUM/COUNT: project each original aggregation symbol from its scan symbol.
    Map<Symbol, Expression> aggregateAssignments = new HashMap<>();
    for (CubeRewriteResult.AggregatorSource source : cubeRewriteResult.getAggregationColumns()) {
        aggregateAssignments.put(source.getOriginalAggSymbol(), toSymbolReference(source.getScanSymbol()));
    }
    return new ProjectNode(
            context.getIdAllocator().getNextId(),
            aggNode,
            new Assignments(aggregateAssignments.entrySet().stream()
                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
}
Aggregations