use of com.facebook.presto.common.relation.Predicate in project presto by prestodb.
the class RowExpressionPredicateCompiler method compilePredicateInternal.
/**
 * Compiles {@code predicate} into a generated {@link Predicate} class and returns a supplier
 * that creates a new instance of that class on every call.
 *
 * @param sqlFunctionProperties forwarded to {@code definePredicateClass}
 * @param sessionFunctions forwarded to {@code definePredicateClass}
 * @param predicate the expression to compile; must not be null
 * @return a supplier that instantiates the generated class through its no-argument constructor
 * @throws PrestoException with {@code COMPILER_ERROR} when the class cannot be defined; the
 *         returned supplier throws the same error code when instantiation fails
 */
private Supplier<Predicate> compilePredicateInternal(SqlFunctionProperties sqlFunctionProperties, Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions, RowExpression predicate) {
    requireNonNull(predicate, "predicate is null");

    // The rewrite yields both the rewritten expression and the input channels it reads.
    PageFieldsToInputParametersRewriter.Result result = rewritePageFieldsToInputParameters(predicate);
    int[] inputChannels = result.getInputChannels().getInputChannels().stream().mapToInt(Integer::intValue).toArray();

    CallSiteBinder callSiteBinder = new CallSiteBinder();
    ClassDefinition classDefinition = definePredicateClass(sqlFunctionProperties, sessionFunctions, result.getRewrittenExpression(), inputChannels, callSiteBinder);

    Class<? extends Predicate> predicateClass;
    try {
        predicateClass = defineClass(classDefinition, Predicate.class, callSiteBinder.getBindings(), getClass().getClassLoader());
    } catch (Exception e) {
        // Report the underlying cause when there is one, but never discard the failure:
        // getCause() returns null for exceptions raised without a cause.
        Throwable cause = e.getCause() != null ? e.getCause() : e;
        throw new PrestoException(COMPILER_ERROR, predicate.toString(), cause);
    }

    return () -> {
        try {
            return predicateClass.getConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new PrestoException(COMPILER_ERROR, e);
        }
    };
}
use of com.facebook.presto.common.relation.Predicate in project presto by prestodb.
the class TestRowExpressionPredicateCompiler method test.
/**
 * Compiles two simple predicates over BIGINT inputs and checks both the input channels the
 * compiled predicates report and the result of evaluating them at each page position.
 */
@Test
public void test() {
    InputReferenceExpression a = new InputReferenceExpression(Optional.empty(), 0, BIGINT);
    InputReferenceExpression b = new InputReferenceExpression(Optional.empty(), 1, BIGINT);
    Block aBlock = createLongBlock(5, 5, 5, 5, 5);
    Block bBlock = createLongBlock(1, 3, 5, 7, 0);

    // b - a >= 0
    // The display name now matches the GREATER_THAN_OR_EQUAL function being called.
    RowExpression sum = call(">=", functionResolution.comparisonFunction(GREATER_THAN_OR_EQUAL, BIGINT, BIGINT), BOOLEAN, call("b - a", functionResolution.arithmeticFunction(SUBTRACT, BIGINT, BIGINT), BIGINT, b, a), constant(0L, BIGINT));
    PredicateCompiler compiler = new RowExpressionPredicateCompiler(metadata, 10_000);
    Predicate compiledSum = compiler.compilePredicate(SESSION.getSqlFunctionProperties(), SESSION.getSessionFunctions(), sum).get();
    assertEquals(Arrays.asList(1, 0), Ints.asList(compiledSum.getInputChannels()));

    // Page blocks are supplied in the order of the reported input channels: b first, then a.
    Page page = new Page(bBlock, aBlock);
    // With a = 5 everywhere and b = 1, 3, 5, 7, 0 the predicate holds only at positions 2 and 3.
    assertFalse(compiledSum.evaluate(SESSION.getSqlFunctionProperties(), page, 0));
    assertFalse(compiledSum.evaluate(SESSION.getSqlFunctionProperties(), page, 1));
    assertTrue(compiledSum.evaluate(SESSION.getSqlFunctionProperties(), page, 2));
    assertTrue(compiledSum.evaluate(SESSION.getSqlFunctionProperties(), page, 3));
    assertFalse(compiledSum.evaluate(SESSION.getSqlFunctionProperties(), page, 4));

    // b * 2 < 10
    // The display name now matches the LESS_THAN function being called.
    RowExpression timesTwo = call("<", functionResolution.comparisonFunction(LESS_THAN, BIGINT, BIGINT), BOOLEAN, call("b * 2", functionResolution.arithmeticFunction(MULTIPLY, BIGINT, BIGINT), BIGINT, b, constant(2L, BIGINT)), constant(10L, BIGINT));
    Predicate compiledTimesTwo = compiler.compilePredicate(SESSION.getSqlFunctionProperties(), SESSION.getSessionFunctions(), timesTwo).get();
    assertEquals(Arrays.asList(1), Ints.asList(compiledTimesTwo.getInputChannels()));

    page = new Page(bBlock);
    // With b = 1, 3, 5, 7, 0 the predicate holds at positions 0, 1 and 4.
    assertTrue(compiledTimesTwo.evaluate(SESSION.getSqlFunctionProperties(), page, 0));
    assertTrue(compiledTimesTwo.evaluate(SESSION.getSqlFunctionProperties(), page, 1));
    assertFalse(compiledTimesTwo.evaluate(SESSION.getSqlFunctionProperties(), page, 2));
    assertFalse(compiledTimesTwo.evaluate(SESSION.getSqlFunctionProperties(), page, 3));
    assertTrue(compiledTimesTwo.evaluate(SESSION.getSqlFunctionProperties(), page, 4));
}
use of com.facebook.presto.common.relation.Predicate in project presto by prestodb.
the class OrcSelectivePageSourceFactory method toFilterFunctions.
/**
 * Turns a filter expression into the list of {@link FilterFunction}s to apply.
 * <p>
 * A bucket adapter, when present, always contributes the first function. Dynamic filters are
 * stripped from the expression, and what remains is either compiled as a single function or,
 * when adaptive filter reordering is enabled and there are several conjuncts, split into
 * groups of conjuncts that depend on the same set of inputs with one function per group.
 */
private static List<FilterFunction> toFilterFunctions(RowExpression filter, Optional<BucketAdapter> bucketAdapter, ConnectorSession session, DeterminismEvaluator determinismEvaluator, PredicateCompiler predicateCompiler) {
    ImmutableList.Builder<FilterFunction> functions = ImmutableList.builder();
    if (bucketAdapter.isPresent()) {
        functions.add(new FilterFunction(session.getSqlFunctionProperties(), true, bucketAdapter.get()));
    }
    if (TRUE_CONSTANT.equals(filter)) {
        return functions.build();
    }

    // dynamic filter will be added through subfield pushdown
    RowExpression staticFilter = and(extractDynamicFilters(filter).getStaticConjuncts());

    if (!isAdaptiveFilterReorderingEnabled(session)) {
        functions.add(compileFilterFunction(staticFilter, session, determinismEvaluator, predicateCompiler));
        return functions.build();
    }

    List<RowExpression> conjuncts = extractConjuncts(staticFilter);
    if (conjuncts.size() == 1) {
        functions.add(compileFilterFunction(staticFilter, session, determinismEvaluator, predicateCompiler));
        return functions.build();
    }

    // Use LinkedHashMap to preserve user-specified order of conjuncts. This will be the initial order in which filters are applied.
    Map<Set<Integer>, List<RowExpression>> conjunctsByInputs = new LinkedHashMap<>();
    for (RowExpression conjunct : conjuncts) {
        conjunctsByInputs.computeIfAbsent(extractInputs(conjunct), inputs -> new ArrayList<>()).add(conjunct);
    }
    for (List<RowExpression> group : conjunctsByInputs.values()) {
        functions.add(compileFilterFunction(binaryExpression(AND, group), session, determinismEvaluator, predicateCompiler));
    }
    return functions.build();
}

/**
 * Compiles a single expression into a {@link FilterFunction}, recording whether the
 * expression is deterministic.
 */
private static FilterFunction compileFilterFunction(RowExpression expression, ConnectorSession session, DeterminismEvaluator determinismEvaluator, PredicateCompiler predicateCompiler) {
    return new FilterFunction(
            session.getSqlFunctionProperties(),
            determinismEvaluator.isDeterministic(expression),
            predicateCompiler.compilePredicate(session.getSqlFunctionProperties(), session.getSessionFunctions(), expression).get());
}
use of com.facebook.presto.common.relation.Predicate in project presto by prestodb.
the class TestOrcReaderPositions method testRowGroupSkipping.
/**
 * Writes a sequential file and reads it back through a predicate that accepts only the row
 * groups whose minimum value is 50_000 or 60_000, then checks that exactly the rows
 * 50_000..69_999 are returned and that the reader positions track them.
 */
@Test
public void testRowGroupSkipping() throws Exception {
    try (TempFile file = new TempFile()) {
        // create single stripe file with multiple row groups
        int rowCount = 142_000;
        createSequentialFile(file.getFile(), rowCount);

        // test reading two row groups from middle of file
        OrcPredicate predicate = (rows, columnStatistics) -> {
            if (rows != rowCount) {
                IntegerStatistics integerStatistics = columnStatistics.get(0).getIntegerStatistics();
                return (integerStatistics.getMin() == 50_000) || (integerStatistics.getMin() == 60_000);
            }
            // a statistics entry covering every row is always accepted
            return true;
        };

        try (OrcBatchRecordReader reader = createCustomOrcRecordReader(file, ORC, predicate, BIGINT, MAX_BATCH_SIZE, false, false)) {
            assertEquals(reader.getFileRowCount(), rowCount);
            assertEquals(reader.getReaderRowCount(), rowCount);
            assertEquals(reader.getFilePosition(), 0);
            assertEquals(reader.getReaderPosition(), 0);

            long expectedPosition = 50_000;
            // nextBatch() returning -1 ends the loop
            for (int batchSize = reader.nextBatch(); batchSize != -1; batchSize = reader.nextBatch()) {
                Block block = reader.readBlock(0);
                for (int offset = 0; offset < batchSize; offset++) {
                    assertEquals(BIGINT.getLong(block, offset), expectedPosition + offset);
                }
                assertEquals(reader.getFilePosition(), expectedPosition);
                assertEquals(reader.getReaderPosition(), expectedPosition);
                expectedPosition += batchSize;
            }

            assertEquals(expectedPosition, 70_000);
            assertEquals(reader.getFilePosition(), rowCount);
            assertEquals(reader.getReaderPosition(), rowCount);
        }
    }
}
use of com.facebook.presto.common.relation.Predicate in project presto by prestodb.
the class TestOrcReaderPositions method testRowGroupSkippingWithAppendRowNumber.
/**
 * Writes a sequential file and reads it through a selective reader using a predicate that
 * accepts only the row groups whose minimum value is 50_000 or 70_000, then delegates to
 * {@code verifyAppendNumber} with the values 50_000..59_999 and 70_000..79_999.
 */
@Test
public void testRowGroupSkippingWithAppendRowNumber() throws Exception {
    try (TempFile tempFile = new TempFile()) {
        // create single stripe file with multiple row groups
        int rowCount = 142_000;
        createSequentialFile(tempFile.getFile(), rowCount);

        // test reading two non-adjacent row groups from middle of file
        OrcPredicate predicate = (numberOfRows, statisticsByColumnIndex) -> {
            if (numberOfRows == rowCount) {
                return true;
            }
            IntegerStatistics stats = statisticsByColumnIndex.get(0).getIntegerStatistics();
            return (stats.getMin() == 50_000) || (stats.getMin() == 70_000);
        };

        List<Long> expectedValues = new ArrayList<>();
        LongStream.range(50_000, 60_000).forEach(expectedValues::add);
        LongStream.range(70_000, 80_000).forEach(expectedValues::add);

        // Close the reader when the test finishes, even if verification fails.
        try (OrcSelectiveRecordReader reader = createCustomOrcSelectiveRecordReader(tempFile, ORC, predicate, BIGINT, MAX_BATCH_SIZE, false, true)) {
            verifyAppendNumber(expectedValues, reader);
        }
    }
}
Aggregations