use of com.facebook.presto.operator.aggregation.InternalAggregationFunction in project presto by prestodb.
the class TestLearnAggregations method testLearnLibSvm.
@Test
public void testLearnLibSvm() throws Exception {
Type mapType = typeManager.getParameterizedType("map", ImmutableList.of(TypeSignatureParameter.of(parseTypeSignature(StandardTypes.BIGINT)), TypeSignatureParameter.of(parseTypeSignature(StandardTypes.DOUBLE))));
InternalAggregationFunction aggregation = AggregationCompiler.generateAggregationBindableFunction(LearnLibSvmClassifierAggregation.class, ClassifierType.BIGINT_CLASSIFIER.getTypeSignature(), ImmutableList.of(BigintType.BIGINT.getTypeSignature(), mapType.getTypeSignature(), VarcharType.getParametrizedVarcharSignature("x"))).specialize(BoundVariables.builder().setLongVariable("x", (long) Integer.MAX_VALUE).build(), 3, typeManager);
assertLearnClassifer(aggregation.bind(ImmutableList.of(0, 1, 2), Optional.empty()).createAccumulator());
}
use of com.facebook.presto.operator.aggregation.InternalAggregationFunction in project presto by prestodb.
the class TestLearnAggregations method testLearn.
@Test
public void testLearn() throws Exception {
Type mapType = typeManager.getParameterizedType("map", ImmutableList.of(TypeSignatureParameter.of(parseTypeSignature(StandardTypes.BIGINT)), TypeSignatureParameter.of(parseTypeSignature(StandardTypes.DOUBLE))));
InternalAggregationFunction aggregation = generateInternalAggregationFunction(LearnClassifierAggregation.class, ClassifierType.BIGINT_CLASSIFIER.getTypeSignature(), ImmutableList.of(BigintType.BIGINT.getTypeSignature(), mapType.getTypeSignature()), typeManager);
assertLearnClassifer(aggregation.bind(ImmutableList.of(0, 1), Optional.empty()).createAccumulator());
}
use of com.facebook.presto.operator.aggregation.InternalAggregationFunction in project presto by prestodb.
the class DoubleSumAggregationBenchmark method createOperatorFactories.
@Override
protected List<? extends OperatorFactory> createOperatorFactories() {
OperatorFactory tableScanOperator = createTableScanOperator(0, new PlanNodeId("test"), "orders", "totalprice");
InternalAggregationFunction doubleSum = MetadataManager.createTestMetadataManager().getFunctionRegistry().getAggregateFunctionImplementation(new Signature("sum", AGGREGATE, DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature()));
AggregationOperatorFactory aggregationOperator = new AggregationOperatorFactory(1, new PlanNodeId("test"), Step.SINGLE, ImmutableList.of(doubleSum.bind(ImmutableList.of(0), Optional.empty())));
return ImmutableList.of(tableScanOperator, aggregationOperator);
}
use of com.facebook.presto.operator.aggregation.InternalAggregationFunction in project presto by prestodb.
the class TestHashAggregationOperator method testHashAggregationWithGlobals.
@Test(dataProvider = "hashEnabled")
public void testHashAggregationWithGlobals(boolean hashEnabled) throws Exception {
MetadataManager metadata = MetadataManager.createTestMetadataManager();
InternalAggregationFunction countVarcharColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation(new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.VARCHAR)));
InternalAggregationFunction countBooleanColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation(new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BOOLEAN)));
InternalAggregationFunction maxVarcharColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation(new Signature("max", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR)));
Optional<Integer> groupIdChannel = Optional.of(1);
List<Integer> groupByChannels = Ints.asList(1, 2);
List<Integer> globalAggregationGroupIds = Ints.asList(42, 49);
RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, groupByChannels, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BOOLEAN);
List<Page> input = rowPagesBuilder.build();
HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory(0, new PlanNodeId("test"), ImmutableList.of(VARCHAR, BIGINT), groupByChannels, globalAggregationGroupIds, Step.SINGLE, ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()), LONG_SUM.bind(ImmutableList.of(4), Optional.empty()), LONG_AVERAGE.bind(ImmutableList.of(4), Optional.empty()), maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty()), countVarcharColumn.bind(ImmutableList.of(0), Optional.empty()), countBooleanColumn.bind(ImmutableList.of(5), Optional.empty())), rowPagesBuilder.getHashChannel(), groupIdChannel, 100_000, new DataSize(16, MEGABYTE), joinCompiler);
MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, BIGINT, BIGINT, BIGINT, DOUBLE, VARCHAR, BIGINT, BIGINT).row(null, 42L, 0L, null, null, null, 0L, 0L).row(null, 49L, 0L, null, null, null, 0L, 0L).build();
assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, Optional.of(groupByChannels.size()));
}
use of com.facebook.presto.operator.aggregation.InternalAggregationFunction in project presto by prestodb.
the class TestHashAggregationOperator method testSpillerFailure.
@Test(expectedExceptions = RuntimeException.class, expectedExceptionsMessageRegExp = ".* Failed to spill")
public void testSpillerFailure() {
MetadataManager metadata = MetadataManager.createTestMetadataManager();
InternalAggregationFunction maxVarcharColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation(new Signature("max", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR)));
List<Integer> hashChannels = Ints.asList(1);
RowPagesBuilder rowPagesBuilder = rowPagesBuilder(false, hashChannels, VARCHAR, BIGINT, VARCHAR, BIGINT);
List<Page> input = rowPagesBuilder.addSequencePage(10, 100, 0, 100, 0).addSequencePage(10, 100, 0, 200, 0).addSequencePage(10, 100, 0, 300, 0).build();
DriverContext driverContext = createTaskContext(executor, TEST_SESSION, new DataSize(10, Unit.BYTE)).addPipelineContext(0, true, true).addDriverContext();
HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory(0, new PlanNodeId("test"), ImmutableList.of(BIGINT), hashChannels, ImmutableList.of(), Step.SINGLE, ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()), LONG_SUM.bind(ImmutableList.of(3), Optional.empty()), LONG_AVERAGE.bind(ImmutableList.of(3), Optional.empty()), maxVarcharColumn.bind(ImmutableList.of(2), Optional.empty())), rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, new DataSize(16, MEGABYTE), true, succinctBytes(8), succinctBytes(Integer.MAX_VALUE), new FailingSpillerFactory(), joinCompiler);
toPages(operatorFactory, driverContext, input);
}
Aggregations