Use of io.prestosql.spi.relation.CallExpression in project hetu-core by openLooKeng.
The class TestPageProcessorCompiler, method testNoCaching.
/**
 * Compiles the same single-projection list (array concat) twice and checks that the two
 * {@code compilePageProcessor(...).get()} results are not the same object.
 */
@Test
public void testNoCaching() {
    ArrayType arrayType = new ArrayType(VARCHAR);
    FunctionHandle concatHandle = metadata.getFunctionAndTypeManager().lookupFunction("concat", fromTypes(arrayType, arrayType));
    RowExpression concatCall = new CallExpression("concat", concatHandle, arrayType, ImmutableList.of(field(0, arrayType), field(1, arrayType)));
    ImmutableList<RowExpression> projections = ImmutableList.of(concatCall);

    // Two compilations of identical input must not hand back the same processor instance.
    PageProcessor first = compiler.compilePageProcessor(Optional.empty(), projections).get();
    PageProcessor second = compiler.compilePageProcessor(Optional.empty(), projections).get();
    assertTrue(first != second);
}
Use of io.prestosql.spi.relation.CallExpression in project hetu-core by openLooKeng.
The class DynamicFilters, method createDynamicFilterPredicate.
/**
 * Builds a row-level predicate for a dynamic filter that is a binary comparison between two variables.
 *
 * <p>The returned predicate receives a list whose element 0 is the probe-side value and element 1 is the
 * build-side value. The comparison is evaluated only when both values are {@code Long}; for any other
 * value types the predicate returns {@code true}, so rows it cannot evaluate are kept.
 *
 * @param filter the optional filter expression
 * @return a predicate for {@code <}, {@code <=}, {@code >} and {@code >=} between two
 *         {@link VariableReferenceExpression}s; {@link Optional#empty()} for anything else, including an
 *         absent filter, a non-call expression, a non-built-in function, or any other operator
 */
public static Optional<Predicate<List>> createDynamicFilterPredicate(Optional<RowExpression> filter) {
    if (!filter.isPresent() || !(filter.get() instanceof CallExpression)) {
        return Optional.empty();
    }
    CallExpression call = (CallExpression) filter.get();
    // Only built-in functions carry a (possibly mangled operator) signature; other handle types
    // used to fail with ClassCastException here and cannot describe a comparison we evaluate.
    if (!(call.getFunctionHandle() instanceof BuiltInFunctionHandle)) {
        return Optional.empty();
    }
    String name = ((BuiltInFunctionHandle) call.getFunctionHandle()).getSignature().getNameSuffix();
    List<RowExpression> arguments = call.getArguments();
    // unmangleOperator expects the "$operator$" prefix, so test for the prefix rather than a substring;
    // comparisons are binary and both sides must be plain variable references.
    if (!name.startsWith("$operator$")
            || arguments.size() != 2
            || !(arguments.get(0) instanceof VariableReferenceExpression)
            || !(arguments.get(1) instanceof VariableReferenceExpression)) {
        return Optional.empty();
    }
    // Non-comparison operators and EQUAL / NOT_EQUAL / IS_DISTINCT_FROM fall through to default.
    switch (Signature.unmangleOperator(name)) {
        case LESS_THAN:
            return Optional.of(values -> {
                Integer comparison = compareLongValues(values);
                return comparison == null || comparison < 0;
            });
        case LESS_THAN_OR_EQUAL:
            return Optional.of(values -> {
                Integer comparison = compareLongValues(values);
                return comparison == null || comparison <= 0;
            });
        case GREATER_THAN:
            return Optional.of(values -> {
                Integer comparison = compareLongValues(values);
                return comparison == null || comparison > 0;
            });
        case GREATER_THAN_OR_EQUAL:
            return Optional.of(values -> {
                Integer comparison = compareLongValues(values);
                return comparison == null || comparison >= 0;
            });
        default:
            return Optional.empty();
    }
}

/**
 * Compares the probe value (element 0 of {@code values}) with the build value (element 1).
 *
 * @param values a list holding the probe value at index 0 and the build value at index 1
 * @return the result of {@link Long#compareTo} when both values are {@code Long}, otherwise {@code null}
 */
private static Integer compareLongValues(List values) {
    Object probeValue = values.get(0);
    Object buildValue = values.get(1);
    if (!(probeValue instanceof Long) || !(buildValue instanceof Long)) {
        return null;
    }
    return ((Long) probeValue).compareTo((Long) buildValue);
}
Use of io.prestosql.spi.relation.CallExpression in project hetu-core by openLooKeng.
The class DynamicFilters, method getDescriptor.
/**
 * Extracts the dynamic-filter descriptor from a call to the dynamic-filter function.
 *
 * <p>The call must have exactly two arguments. The first must be a constant holding the filter id, either
 * as a {@code String} or as a UTF-8 {@code Slice}. The second argument and the call's filter are passed
 * through to the {@link Descriptor}.
 *
 * @param expression any row expression
 * @return the descriptor, or {@link Optional#empty()} if {@code expression} is not a call whose display
 *         name equals {@code Function.NAME}
 * @throws IllegalArgumentException if the argument count is not 2 or the first argument is not a constant
 */
public static Optional<Descriptor> getDescriptor(RowExpression expression) {
    if (!(expression instanceof CallExpression)) {
        return Optional.empty();
    }
    CallExpression call = (CallExpression) expression;
    String displayName = call.getDisplayName();
    if (!displayName.equals(Function.NAME)) {
        return Optional.empty();
    }

    List<RowExpression> args = call.getArguments();
    checkArgument(args.size() == 2, "invalid arguments count: %s", args.size());

    RowExpression idArgument = args.get(0);
    checkArgument(idArgument instanceof ConstantExpression, "firstArgument is expected to be an instance of ConstantExpression: %s", idArgument.getClass().getSimpleName());

    Object rawId = ((ConstantExpression) idArgument).getValue();
    String id;
    if (rawId instanceof String) {
        id = (String) rawId;
    }
    else {
        id = ((Slice) rawId).toStringUtf8();
    }

    // Fixme(Nitin): Resolve the filter expression from the dynamic filter
    return Optional.of(new Descriptor(id, args.get(1), call.getFilter()));
}
Use of io.prestosql.spi.relation.CallExpression in project hetu-core by openLooKeng.
The class TransformUnCorrelatedInPredicateSubQuerySelfJoinToAggregate, method transformProjectNode.
/**
 * Tries to rewrite an IN-predicate subquery plan of the shape
 * Project -> Filter -> Join (optionally reached through a CTEScanNode, via getChildFilterNode)
 * into Project -> Filter(count > 1) -> Aggregation(count) -> TableScan.
 *
 * <p>Returns {@code Optional.empty()} when the project has more than one output symbol, when its source
 * is not a FilterNode directly over a JoinNode, or when the join is not a self join (per isSelfJoin) and
 * neither join side is a ProjectNode that could itself be rewritten. In the non-self-join case, ProjectNode
 * children of the join are rewritten recursively and the Filter/Project pair is rebuilt over the updated join.
 *
 * <p>NOTE(review): in the self-join case both join children are cast to TableScanNode without a check;
 * any other child type throws ClassCastException. Confirm that isSelfJoin guarantees table-scan children.
 *
 * @param context rule context supplying the lookup, symbol allocator and id allocator
 * @param projectNode the project node on top of the candidate subquery
 * @return the rewritten project node, or empty if the pattern does not apply
 */
private Optional<ProjectNode> transformProjectNode(Context context, ProjectNode projectNode) {
// IN predicate requires only one projection
if (projectNode.getOutputSymbols().size() > 1) {
return Optional.empty();
}
PlanNode source = context.getLookup().resolve(projectNode.getSource());
if (source instanceof CTEScanNode) {
// Look through the CTE: continue with whatever getChildFilterNode finds under the CTE's source.
source = getChildFilterNode(context, context.getLookup().resolve(((CTEScanNode) source).getSource()));
}
// The pattern requires a FilterNode whose (resolved) source is a JoinNode.
if (!(source instanceof FilterNode && context.getLookup().resolve(((FilterNode) source).getSource()) instanceof JoinNode)) {
return Optional.empty();
}
FilterNode filter = (FilterNode) source;
Expression predicate = OriginalExpressionUtils.castToExpression(filter.getPredicate());
// Every symbol referenced by the filter predicate; these are the candidates for the count() arguments below.
List<SymbolReference> allPredicateSymbols = new ArrayList<>();
getAllSymbols(predicate, allPredicateSymbols);
JoinNode joinNode = (JoinNode) context.getLookup().resolve(((FilterNode) source).getSource());
if (!isSelfJoin(projectNode, predicate, joinNode, context.getLookup())) {
// Check next level for Self Join
// Recursively rewrite ProjectNodes sitting directly on either join side. If either side changed,
// rebuild the join, the filter and the project (keeping their ids) over the rewritten children.
PlanNode left = context.getLookup().resolve(joinNode.getLeft());
boolean changed = false;
if (left instanceof ProjectNode) {
Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) left);
if (transformResult.isPresent()) {
joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), transformResult.get(), joinNode.getRight(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
changed = true;
}
}
PlanNode right = context.getLookup().resolve(joinNode.getRight());
if (right instanceof ProjectNode) {
Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) right);
if (transformResult.isPresent()) {
joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), joinNode.getLeft(), transformResult.get(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
changed = true;
}
}
if (changed) {
FilterNode transformedFilter = new FilterNode(filter.getId(), joinNode, filter.getPredicate());
ProjectNode transformedProject = new ProjectNode(projectNode.getId(), transformedFilter, projectNode.getAssignments());
return Optional.of(transformedProject);
}
return Optional.empty();
}
// Choose the table to use based on projected output.
// NOTE(review): unchecked casts; a join child that is not a TableScanNode throws ClassCastException here.
TableScanNode leftTable = (TableScanNode) context.getLookup().resolve(joinNode.getLeft());
TableScanNode rightTable = (TableScanNode) context.getLookup().resolve(joinNode.getRight());
TableScanNode tableToUse;
List<RowExpression> aggregationSymbols;
AggregationNode.GroupingSetDescriptor groupingSetDescriptor;
// Only assigned (and only used) on the CTE path below.
Assignments projectionsForCTE = null;
Assignments projectionsFromFilter = null;
// Use non-projected column for aggregation
if (context.getLookup().resolve(projectNode.getSource()) instanceof CTEScanNode) {
CTEScanNode cteScanNode = (CTEScanNode) context.getLookup().resolve(projectNode.getSource());
ProjectNode childProjectOfCte = (ProjectNode) context.getLookup().resolve(cteScanNode.getSource());
// Map the project's output symbols through the CTE child's assignments to table-scan output
// symbols, matching by name. The cast below assumes those assignments are plain symbol references.
List<Symbol> completeOutputSymbols = new ArrayList<>();
rightTable.getOutputSymbols().forEach((s) -> completeOutputSymbols.add(s));
leftTable.getOutputSymbols().forEach((s) -> completeOutputSymbols.add(s));
List<Symbol> outputSymbols = new ArrayList<>();
for (int i = 0; i < completeOutputSymbols.size(); i++) {
Symbol outputSymbol = completeOutputSymbols.get(i);
for (Symbol symbol : projectNode.getOutputSymbols()) {
if (childProjectOfCte.getAssignments().getMap().containsKey(symbol)) {
if (((SymbolReference) OriginalExpressionUtils.castToExpression(childProjectOfCte.getAssignments().getMap().get(symbol))).getName().equals(outputSymbol.getName())) {
outputSymbols.add(outputSymbol);
}
}
}
}
// Keep only the CTE assignment for the single projected symbol: once under its own name (CTE
// projection) and once re-keyed to the matching table-scan symbol (projection over the new filter).
// getOnlyElement(outputSymbols) throws unless exactly one table-scan symbol matched above.
// NOTE(review): both if-conditions below are identical and could be merged.
Map<Symbol, RowExpression> projectionsForCTEMap = new HashMap<>();
Map<Symbol, RowExpression> projectionsFromFilterMap = new HashMap<>();
for (Map.Entry entry : childProjectOfCte.getAssignments().getMap().entrySet()) {
if (entry.getKey().equals(getOnlyElement(projectNode.getOutputSymbols()))) {
projectionsForCTEMap.put((Symbol) entry.getKey(), (RowExpression) entry.getValue());
}
if (entry.getKey().equals(getOnlyElement(projectNode.getOutputSymbols()))) {
projectionsFromFilterMap.put(getOnlyElement(outputSymbols), (RowExpression) entry.getValue());
}
}
projectionsForCTE = new Assignments(projectionsForCTEMap);
projectionsFromFilter = new Assignments(projectionsFromFilterMap);
tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(outputSymbols)) ? leftTable : rightTable;
// Aggregate over predicate symbols that belong to tableToUse and are not themselves projected.
aggregationSymbols = allPredicateSymbols.stream().filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s))).filter(s -> !outputSymbols.contains(SymbolUtils.from(s))).map(OriginalExpressionUtils::castToRowExpression).collect(Collectors.toList());
// Create aggregation
groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(outputSymbols), 1, ImmutableSet.of());
} else {
tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(projectNode.getOutputSymbols())) ? leftTable : rightTable;
// Aggregate over predicate symbols that belong to tableToUse and are not themselves projected.
aggregationSymbols = allPredicateSymbols.stream().filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s))).filter(s -> !projectNode.getOutputSymbols().contains(SymbolUtils.from(s))).map(OriginalExpressionUtils::castToRowExpression).collect(Collectors.toList());
// Create aggregation
groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(projectNode.getOutputSymbols()), 1, ImmutableSet.of());
}
// count(...) grouped by the projected symbol; per the inline comment, the 'true' argument marks it DISTINCT.
AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(new CallExpression("count", functionResolution.countFunction(), BIGINT, aggregationSymbols), aggregationSymbols, // mark DISTINCT since NOT_EQUALS predicate
true, Optional.empty(), Optional.empty(), Optional.empty());
ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
Symbol countSymbol = context.getSymbolAllocator().newSymbol(aggregation.getFunctionCall().getDisplayName(), BIGINT);
aggregationsBuilder.put(countSymbol, aggregation);
AggregationNode aggregationNode = new AggregationNode(context.getIdAllocator().getNextId(), tableToUse, aggregationsBuilder.build(), groupingSetDescriptor, ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
// Keep only groups whose count is greater than 1 (the predicate below is count > 1), to match the
// NOT_EQUALS clause in the original query. NOTE(review): an earlier wording of this comment said
// "count < 1", which contradicts the predicate; confirm the intended threshold.
FilterNode filterNode = new FilterNode(context.getIdAllocator().getNextId(), aggregationNode, OriginalExpressionUtils.castToRowExpression(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, SymbolUtils.toSymbolReference(countSymbol), new GenericLiteral("BIGINT", "1"))));
// Project the aggregated+filtered rows.
ProjectNode transformedSubquery = new ProjectNode(projectNode.getId(), filterNode, projectNode.getAssignments());
if (context.getLookup().resolve(projectNode.getSource()) instanceof CTEScanNode) {
// CTE path: rebuild the CTE's child projections over the new filter, re-wrap them in the CTE scan,
// and project from the CTE scan instead of directly from the filter.
CTEScanNode cteScanNode = (CTEScanNode) context.getLookup().resolve(projectNode.getSource());
PlanNode projectNodeForCTE = context.getLookup().resolve(cteScanNode.getSource());
// NOTE(review): resolve() is applied to an already-resolved node. If it returns such a node unchanged
// (presumably), projectNodeFromFilter is projectNodeForCTE itself, so both ProjectNodes built below
// reuse the same PlanNodeId. Confirm whether resolve(projectNodeForCTE.getSource()) was intended.
PlanNode projectNodeFromFilter = context.getLookup().resolve(projectNodeForCTE);
projectNodeFromFilter = new ProjectNode(projectNodeFromFilter.getId(), filterNode, projectionsFromFilter);
projectNodeForCTE = new ProjectNode(projectNodeForCTE.getId(), projectNodeFromFilter, projectionsForCTE);
cteScanNode = (CTEScanNode) cteScanNode.replaceChildren(ImmutableList.of(projectNodeForCTE));
cteScanNode.setOutputSymbols(projectNode.getOutputSymbols());
transformedSubquery = new ProjectNode(projectNode.getId(), cteScanNode, projectNode.getAssignments());
}
return Optional.of(transformedSubquery);
}
Use of io.prestosql.spi.relation.CallExpression in project hetu-core by openLooKeng.
The class AggregationRewriteWithCube, method rewrite.
/**
 * Rewrites an aggregation (optionally with a filter beneath it) to read from the cube table named by
 * {@code cubeMetadata.getCubeName()}.
 *
 * <p>Steps:
 * <ol>
 *   <li>Resolve the cube table, throwing CubeNotFoundException if it does not exist.</li>
 *   <li>Map the original grouping keys to cube column names via {@code symbolMappings}.</li>
 *   <li>Build the cube scan via {@code createScanNode}.</li>
 *   <li>Re-apply the filter predicate, rewritten with {@code rewrittenMappings}.</li>
 *   <li>If the grouping does not exactly match the cube's group set, re-aggregate the pre-aggregated cube
 *       columns. COUNT is rolled up with {@code sum}; other functions are rolled up with themselves. AVG is
 *       either projected through or computed as CAST(sum) / CAST(count).</li>
 *   <li>Add a projection if the resulting output symbols differ from the original aggregation's.</li>
 * </ol>
 *
 * @param originalAggregationNode the aggregation being replaced
 * @param filterNode the filter below the aggregation, or {@code null} if there is none; when non-null it is
 *        cast to FilterNode
 * @return the root of the rewritten plan
 */
public PlanNode rewrite(AggregationNode originalAggregationNode, PlanNode filterNode) {
QualifiedObjectName starTreeTableName = QualifiedObjectName.valueOf(cubeMetadata.getCubeName());
TableHandle cubeTableHandle = metadata.getTableHandle(session, starTreeTableName).orElseThrow(() -> new CubeNotFoundException(starTreeTableName.toString()));
Map<String, ColumnHandle> cubeColumnsMap = metadata.getColumnHandles(session, cubeTableHandle);
TableMetadata cubeTableMetadata = metadata.getTableMetadata(session, cubeTableHandle);
List<ColumnMetadata> cubeColumnMetadataList = cubeTableMetadata.getColumns();
// Add group by
// Grouping keys whose mapping is not a ColumnHandle are skipped, so groupings may be shorter than the original keys.
List<Symbol> groupings = new ArrayList<>(originalAggregationNode.getGroupingKeys().size());
for (Symbol symbol : originalAggregationNode.getGroupingKeys()) {
Object column = symbolMappings.get(symbol.getName());
if (column instanceof ColumnHandle) {
groupings.add(new Symbol(((ColumnHandle) column).getColumnName()));
}
}
// Exact match: same size and every (lower-cased) grouping column is in the cube's group set.
// NOTE(review): only the groupings are lower-cased; this assumes cube group names are stored lower-case.
Set<String> cubeGroups = cubeMetadata.getGroup();
boolean exactGroupsMatch = false;
if (groupings.size() == cubeGroups.size()) {
exactGroupsMatch = groupings.stream().map(Symbol::getName).map(String::toLowerCase).allMatch(cubeGroups::contains);
}
CubeRewriteResult cubeRewriteResult = createScanNode(originalAggregationNode, filterNode, cubeTableHandle, cubeColumnsMap, cubeColumnMetadataList, exactGroupsMatch);
PlanNode planNode = cubeRewriteResult.getTableScanNode();
// Add filter node
if (filterNode != null) {
Expression expression = castToExpression(((FilterNode) filterNode).getPredicate());
expression = rewriteExpression(expression, rewrittenMappings);
planNode = new FilterNode(idAllocator.getNextId(), planNode, castToRowExpression(expression));
}
// Groups differ from the cube's: aggregate the cube rows again to the requested grouping.
if (!exactGroupsMatch) {
Map<Symbol, Symbol> cubeScanToAggOutputMap = new HashMap<>();
// Rewrite AggregationNode using Cube table
ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
// cubeColHandle is only used for the error message if the aggregation signature is missing.
ColumnHandle cubeColHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(aggregatorSource.getScanSymbol());
ColumnMetadata cubeColumnMetadata = cubeRewriteResult.getSymbolMetadataMap().get(aggregatorSource.getScanSymbol());
Type type = cubeColumnMetadata.getType();
AggregationSignature aggregationSignature = cubeMetadata.getAggregationSignature(cubeColumnMetadata.getName()).orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName(starTreeTableName.getSchemaName(), starTreeTableName.getObjectName()), cubeColHandle.getColumnName()));
// Pre-aggregated counts are combined by summing them; other functions are re-applied as-is.
String aggFunction = COUNT.getName().equals(aggregationSignature.getFunction()) ? "sum" : aggregationSignature.getFunction();
SymbolReference argument = toSymbolReference(aggregatorSource.getScanSymbol());
FunctionHandle functionHandle = metadata.getFunctionAndTypeManager().lookupFunction(aggFunction, TypeSignatureProvider.fromTypeSignatures(type.getTypeSignature()));
cubeScanToAggOutputMap.put(aggregatorSource.getScanSymbol(), aggregatorSource.getOriginalAggSymbol());
aggregationsBuilder.put(aggregatorSource.getOriginalAggSymbol(), new AggregationNode.Aggregation(new CallExpression(aggFunction, functionHandle, type, ImmutableList.of(OriginalExpressionUtils.castToRowExpression(argument))), ImmutableList.of(OriginalExpressionUtils.castToRowExpression(argument)), false, Optional.empty(), Optional.empty(), Optional.empty()));
}
// NOTE(review): rewrittenMappings.get returns null for an unmapped key, which would put a null symbol into the grouping set.
List<Symbol> groupingKeys = originalAggregationNode.getGroupingKeys().stream().map(Symbol::getName).map(rewrittenMappings::get).collect(Collectors.toList());
planNode = new AggregationNode(idAllocator.getNextId(), planNode, aggregationsBuilder.build(), singleGroupingSet(groupingKeys), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
AggregationNode aggNode = (AggregationNode) planNode;
if (!cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
if (!cubeRewriteResult.getComputeAvgDividingSumByCount()) {
// NOTE(review): this projection maps each original aggregation symbol to its *scan* symbol, but the
// aggregations above are keyed by the original aggregation symbols. The scan symbols are inputs of aggNode,
// presumably not its outputs, and the grouping keys are not carried through. Confirm this is correct.
Map<Symbol, Expression> aggregateAssignments = new HashMap<>();
for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
aggregateAssignments.put(aggregatorSource.getOriginalAggSymbol(), toSymbolReference(aggregatorSource.getScanSymbol()));
}
planNode = new ProjectNode(idAllocator.getNextId(), aggNode, new Assignments(aggregateAssignments.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
} else {
// If there was an AVG aggregation, map it to AVG = SUM/COUNT
Map<Symbol, Expression> projections = new HashMap<>();
aggNode.getGroupingKeys().forEach(symbol -> projections.put(symbol, toSymbolReference(symbol)));
aggNode.getAggregations().keySet().stream().filter(symbol -> symbolMappings.containsValue(symbol.getName())).forEach(aggSymbol -> projections.put(aggSymbol, toSymbolReference(aggSymbol)));
// Add AVG = SUM / COUNT
for (CubeRewriteResult.AverageAggregatorSource avgAggSource : cubeRewriteResult.getAvgAggregationColumns()) {
Symbol sumSymbol = cubeScanToAggOutputMap.get(avgAggSource.getSum());
Symbol countSymbol = cubeScanToAggOutputMap.get(avgAggSource.getCount());
Type avgResultType = typeProvider.get(avgAggSource.getOriginalAggSymbol());
// Both operands are cast to the AVG result type before dividing.
ArithmeticBinaryExpression division = new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, new Cast(toSymbolReference(sumSymbol), avgResultType.getTypeSignature().toString()), new Cast(toSymbolReference(countSymbol), avgResultType.getTypeSignature().toString()));
projections.put(avgAggSource.getOriginalAggSymbol(), division);
}
planNode = new ProjectNode(idAllocator.getNextId(), aggNode, new Assignments(projections.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
}
}
}
// Safety check to remove redundant symbols and rename original column names to intermediate names
// NOTE(review): assignments below are collected from a HashMap, whose iteration order is unspecified, so the
// projection's output symbol order may differ from originalAggregationNode's. Confirm whether order matters.
if (!planNode.getOutputSymbols().equals(originalAggregationNode.getOutputSymbols())) {
// Map new symbol names to the old symbols
Map<Symbol, Expression> assignments = new HashMap<>();
Set<Symbol> planNodeOutput = new HashSet<>(planNode.getOutputSymbols());
for (Symbol originalAggOutputSymbol : originalAggregationNode.getOutputSymbols()) {
if (!planNodeOutput.contains(originalAggOutputSymbol)) {
// Must be grouping key
assignments.put(originalAggOutputSymbol, toSymbolReference(rewrittenMappings.get(originalAggOutputSymbol.getName())));
} else {
// Should be an expression and must have the same name in the new plan node
assignments.put(originalAggOutputSymbol, toSymbolReference(originalAggOutputSymbol));
}
}
planNode = new ProjectNode(idAllocator.getNextId(), planNode, new Assignments(assignments.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
}
return planNode;
}
Aggregations