
Example 1 with LongDecimalWithOverflowAndLongStateFactory

Use of io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory in project trino by trinodb.

From class DecimalAverageAggregation, method specialize.

@Override
public AggregationMetadata specialize(BoundSignature boundSignature) {
    Type type = getOnlyElement(boundSignature.getArgumentTypes());
    checkArgument(type instanceof DecimalType, "type must be Decimal");
    MethodHandle inputFunction;
    MethodHandle outputFunction;
    Class<LongDecimalWithOverflowAndLongState> stateInterface = LongDecimalWithOverflowAndLongState.class;
    LongDecimalWithOverflowAndLongStateSerializer stateSerializer = new LongDecimalWithOverflowAndLongStateSerializer();
    // Short and long decimals have different representations, so each gets its own input/output handles
    if (((DecimalType) type).isShort()) {
        inputFunction = SHORT_DECIMAL_INPUT_FUNCTION;
        outputFunction = SHORT_DECIMAL_OUTPUT_FUNCTION;
    } else {
        inputFunction = LONG_DECIMAL_INPUT_FUNCTION;
        outputFunction = LONG_DECIMAL_OUTPUT_FUNCTION;
    }
    // Bind the decimal type as the output function's leading argument
    outputFunction = outputFunction.bindTo(type);
    return new AggregationMetadata(
            inputFunction,
            Optional.empty(),
            Optional.of(COMBINE_FUNCTION),
            outputFunction,
            ImmutableList.of(new AccumulatorStateDescriptor<>(
                    stateInterface,
                    stateSerializer,
                    new LongDecimalWithOverflowAndLongStateFactory())));
}
Also used : Type(io.trino.spi.type.Type) DecimalType(io.trino.spi.type.DecimalType) AccumulatorStateDescriptor(io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor) LongDecimalWithOverflowAndLongStateSerializer(io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateSerializer) LongDecimalWithOverflowAndLongStateFactory(io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory) LongDecimalWithOverflowAndLongState(io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState) MethodHandle(java.lang.invoke.MethodHandle)
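specialize chooses one of two input/output handle pairs depending on whether the decimal is short. It then calls bindTo so that the type becomes a fixed first argument of the output function. The sketch below reproduces only that java.lang.invoke pattern and is not Trino code. SimpleDecimalType, shortOutput, longOutput, selectOutputFunction and render are all hypothetical. The precision-18 cutoff for "short" is an assumption. For simplicity, both output functions take a plain long.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

// Sketch only: hypothetical stand-ins that mimic how specialize() picks a handle
// and binds the type to it. None of these names are Trino APIs.
class OutputFunctionBinding {
    // Hypothetical decimal type. Assumption: "short" means precision <= 18,
    // so the unscaled value fits in a signed 64-bit long.
    record SimpleDecimalType(int precision, int scale) {
        boolean isShort() {
            return precision <= 18;
        }
    }

    // Hypothetical output functions; the type is the first parameter so bindTo can fix it.
    static String shortOutput(SimpleDecimalType type, long unscaledSum) {
        return "short decimal(" + type.precision() + "," + type.scale() + ") sum=" + unscaledSum;
    }

    static String longOutput(SimpleDecimalType type, long unscaledSum) {
        return "long decimal(" + type.precision() + "," + type.scale() + ") sum=" + unscaledSum;
    }

    // Picks a handle by decimal width, then binds the type as the leading argument,
    // leaving a handle of type (long)String.
    static MethodHandle selectOutputFunction(SimpleDecimalType type) {
        MethodType signature = MethodType.methodType(String.class, SimpleDecimalType.class, long.class);
        String name = type.isShort() ? "shortOutput" : "longOutput";
        try {
            MethodHandle handle = MethodHandles.lookup()
                    .findStatic(OutputFunctionBinding.class, name, signature);
            return handle.bindTo(type);
        }
        catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    // Invokes the bound handle; the caller supplies only the remaining long argument.
    static String render(SimpleDecimalType type, long unscaledSum) {
        try {
            return (String) selectOutputFunction(type).invoke(unscaledSum);
        }
        catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        System.out.println(render(new SimpleDecimalType(10, 2), 42));
        System.out.println(render(new SimpleDecimalType(38, 2), 42));
    }
}
```

Because the type is bound once, anyone holding the returned handle supplies only the remaining argument and never needs to know which width was selected.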

Example 2 with LongDecimalWithOverflowAndLongStateFactory

Use of io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory in project trino by trinodb.

From class TestDecimalAverageAggregation, method testCombineOverflow.

@Test
public void testCombineOverflow() {
    // Each state sums 2^126 + 2^126 = 2^127, which exceeds the largest signed 128-bit value (2^127 - 1)
    addToState(state, TWO.pow(126));
    addToState(state, TWO.pow(126));
    LongDecimalWithOverflowAndLongState otherState = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
    addToState(otherState, TWO.pow(126));
    addToState(otherState, TWO.pow(126));
    DecimalAverageAggregation.combine(state, otherState);
    // Four inputs; the combined sum 2^128 is held as a wrapped value of 0 plus one overflow
    assertEquals(state.getLong(), 4);
    assertEquals(state.getOverflow(), 1);
    assertEquals(getDecimal(state), Int128.ZERO);
    BigInteger expectedAverage = BigInteger.ZERO.add(TWO.pow(126)).add(TWO.pow(126)).add(TWO.pow(126)).add(TWO.pow(126)).divide(BigInteger.valueOf(4));
    assertAverageEquals(expectedAverage);
}
Also used : BigInteger(java.math.BigInteger) LongDecimalWithOverflowAndLongStateFactory(io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory) LongDecimalWithOverflowAndLongState(io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState) Test(org.testng.annotations.Test)
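The assertions in testCombineOverflow are consistent with the following model of the state: a signed 128-bit sum that wraps, plus a signed counter of 2^128 wraps, so that the exact sum is decimal + overflow * 2^128. The sketch below implements that bookkeeping under this interpretation, which is inferred from the test and not taken from Trino's source. OverflowCountingState and its methods are hypothetical, and BigInteger stands in for Int128.

```java
import java.math.BigInteger;

// Sketch only: emulates a wrapping signed 128-bit sum with BigInteger.
// Assumption (inferred from the test assertions): exact sum = decimal + overflow * 2^128.
class OverflowCountingState {
    static final BigInteger TWO_POW_127 = BigInteger.ONE.shiftLeft(127);
    static final BigInteger TWO_POW_128 = BigInteger.ONE.shiftLeft(128);

    BigInteger decimal = BigInteger.ZERO; // kept in [-2^127, 2^127)
    long overflow;                        // signed count of 2^128 wraps
    long count;                           // number of inputs added

    // Adds a value within the signed 128-bit range; the sum wraps at most once.
    void add(BigInteger value) {
        decimal = wrap(decimal.add(value));
        count++;
    }

    // Merges another partial state, analogous to combine(state, otherState) in the test.
    void combine(OverflowCountingState other) {
        overflow += other.overflow;
        decimal = wrap(decimal.add(other.decimal));
        count += other.count;
    }

    // Folds a sum in [-2^128, 2^128) back into [-2^127, 2^127) and records the wrap.
    private BigInteger wrap(BigInteger sum) {
        if (sum.compareTo(TWO_POW_127) >= 0) {
            overflow++;
            return sum.subtract(TWO_POW_128);
        }
        if (sum.compareTo(TWO_POW_127.negate()) < 0) {
            overflow--;
            return sum.add(TWO_POW_128);
        }
        return sum;
    }

    // Reconstructs the exact sum from the wrapped value and the overflow counter.
    BigInteger trueSum() {
        return decimal.add(BigInteger.valueOf(overflow).multiply(TWO_POW_128));
    }

    public static void main(String[] args) {
        BigInteger twoPow126 = BigInteger.ONE.shiftLeft(126);
        OverflowCountingState state = new OverflowCountingState();
        state.add(twoPow126);
        state.add(twoPow126);
        OverflowCountingState other = new OverflowCountingState();
        other.add(twoPow126);
        other.add(twoPow126);
        state.combine(other);
        // Same inputs as testCombineOverflow: count, overflow, wrapped decimal
        System.out.println(state.count + " " + state.overflow + " " + state.decimal);
        System.out.println(state.trueSum().equals(BigInteger.ONE.shiftLeft(128)));
    }
}
```

When run with the test's inputs, main prints 4 1 0 and then true. In this model each partial state has already wrapped once before combine runs, and merging the two wrapped values wraps back the other way, which leaves a net overflow of 1.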

Example 3 with LongDecimalWithOverflowAndLongStateFactory

Use of io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory in project trino by trinodb.

From class TestDecimalAverageAggregation, method testCombineUnderflow.

@Test
public void testCombineUnderflow() {
    // Each state sums -(2^125 + 2^126), which still fits in a signed 128-bit value
    addToState(state, TWO.pow(125).negate());
    addToState(state, TWO.pow(126).negate());
    LongDecimalWithOverflowAndLongState otherState = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
    addToState(otherState, TWO.pow(125).negate());
    addToState(otherState, TWO.pow(126).negate());
    DecimalAverageAggregation.combine(state, otherState);
    // The combined sum -(2^126 + 2^127) is below -2^127: overflow -1 plus a wrapped value of 2^126 (high word 2^62)
    assertEquals(state.getLong(), 4);
    assertEquals(state.getOverflow(), -1);
    assertEquals(getDecimal(state), Int128.valueOf(1L << 62, 0));
    BigInteger expectedAverage = BigInteger.ZERO.add(TWO.pow(126)).add(TWO.pow(126)).add(TWO.pow(125)).add(TWO.pow(125)).negate().divide(BigInteger.valueOf(4));
    assertAverageEquals(expectedAverage);
}
Also used : BigInteger(java.math.BigInteger) LongDecimalWithOverflowAndLongStateFactory(io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory) LongDecimalWithOverflowAndLongState(io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState) Test(org.testng.annotations.Test)
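testCombineUnderflow expects a wrapped value of Int128.valueOf(1L << 62, 0) with overflow -1. Assume that valueOf takes a signed high word and an unsigned low word, and that the exact sum is wrapped + overflow * 2^128. Then the wrapped value is 2^62 * 2^64 = 2^126, and 2^126 - 2^128 = -3 * 2^126 = -(2^126 + 2^127), which is the sum of the four inputs. The sketch below checks this arithmetic with BigInteger. UnderflowCheck, fromHighLow and exactSum are hypothetical helpers, not Trino APIs.

```java
import java.math.BigInteger;

// Sketch only, not Trino code. Assumptions: a 128-bit value is built from a signed
// high word and an unsigned low word, and exact sum = wrapped + overflow * 2^128.
class UnderflowCheck {
    static final BigInteger TWO_POW_64 = BigInteger.ONE.shiftLeft(64);
    static final BigInteger TWO_POW_128 = BigInteger.ONE.shiftLeft(128);

    // Rebuilds the value as high * 2^64 + unsigned(low).
    static BigInteger fromHighLow(long high, long low) {
        BigInteger unsignedLow = BigInteger.valueOf(low);
        if (low < 0) {
            // Reinterpret the negative long as its unsigned 64-bit value
            unsignedLow = unsignedLow.add(TWO_POW_64);
        }
        return BigInteger.valueOf(high).multiply(TWO_POW_64).add(unsignedLow);
    }

    // Exact sum represented by a wrapped 128-bit value plus a signed overflow counter.
    static BigInteger exactSum(BigInteger wrapped, long overflow) {
        return wrapped.add(BigInteger.valueOf(overflow).multiply(TWO_POW_128));
    }

    public static void main(String[] args) {
        // Values asserted in testCombineUnderflow
        BigInteger wrapped = fromHighLow(1L << 62, 0);
        BigInteger sum = exactSum(wrapped, -1);

        BigInteger twoPow125 = BigInteger.ONE.shiftLeft(125);
        BigInteger twoPow126 = BigInteger.ONE.shiftLeft(126);
        // Sum of the four inputs: -(2^125 + 2^126) from each state
        BigInteger inputs = twoPow125.add(twoPow126).multiply(BigInteger.TWO).negate();

        System.out.println(wrapped.equals(twoPow126));
        System.out.println(sum.equals(inputs));
        System.out.println(sum.divide(BigInteger.valueOf(4)));
    }
}
```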

Example 4 with LongDecimalWithOverflowAndLongStateFactory

Use of io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory in project trino by trinodb.

From class TestDecimalAverageAggregation, private helper method testNoOverflow.

private void testNoOverflow(DecimalType type, List<BigInteger> numbers) {
    LongDecimalWithOverflowAndLongState state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
    for (BigInteger number : numbers) {
        addToState(type, state, number);
    }
    // This helper expects the sum to fit in 128 bits, so no overflow is recorded
    assertEquals(state.getOverflow(), 0);
    BigInteger sum = numbers.stream().reduce(BigInteger.ZERO, BigInteger::add);
    assertEquals(getDecimal(state), Int128.valueOf(sum));
    // Expected average: the unscaled sum at the type's scale, divided by the count and rounded HALF_UP to that scale
    BigDecimal expectedAverage = new BigDecimal(sum, type.getScale()).divide(BigDecimal.valueOf(numbers.size()), type.getScale(), HALF_UP);
    assertEquals(decodeBigDecimal(type, average(state, type)), expectedAverage);
}
Also used : BigInteger(java.math.BigInteger) LongDecimalWithOverflowAndLongStateFactory(io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory) LongDecimalWithOverflowAndLongState(io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState) BigDecimal(java.math.BigDecimal)
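testNoOverflow builds its expected average in two steps. It first attaches the type's scale to the unscaled sum, then divides by the count with HALF_UP rounding at that same scale. The sketch below isolates that formula using only java.math types. ScaledAverage and expectedAverage are hypothetical names, and the sample values are made up.

```java
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.util.List;

// Sketch only: mirrors the expected-average formula in testNoOverflow.
class ScaledAverage {
    // Sums the unscaled values, attaches the scale, then divides by the count,
    // keeping the same scale and rounding ties away from zero (HALF_UP).
    static BigDecimal expectedAverage(int scale, List<BigInteger> unscaledValues) {
        BigInteger sum = unscaledValues.stream().reduce(BigInteger.ZERO, BigInteger::add);
        return new BigDecimal(sum, scale)
                .divide(BigDecimal.valueOf(unscaledValues.size()), scale, RoundingMode.HALF_UP);
    }

    public static void main(String[] args) {
        // Unscaled 100, 200, 200 at scale 2 are 1.00, 2.00, 2.00
        System.out.println(expectedAverage(2, List.of(
                BigInteger.valueOf(100), BigInteger.valueOf(200), BigInteger.valueOf(200))));
        // Unscaled -2, -3 at scale 2 are -0.02, -0.03
        System.out.println(expectedAverage(2, List.of(
                BigInteger.valueOf(-2), BigInteger.valueOf(-3))));
    }
}
```

At scale 2, the inputs 1.00, 2.00 and 2.00 average to 1.67. The inputs -0.02 and -0.03 average to -0.03, because HALF_UP rounds the tie -0.025 away from zero.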

Aggregations (classes used across these examples, with the number of examples that use each)

LongDecimalWithOverflowAndLongState (io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState): 4
LongDecimalWithOverflowAndLongStateFactory (io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory): 4
BigInteger (java.math.BigInteger): 3
Test (org.testng.annotations.Test): 2
AccumulatorStateDescriptor (io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor): 1
LongDecimalWithOverflowAndLongStateSerializer (io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateSerializer): 1
DecimalType (io.trino.spi.type.DecimalType): 1
Type (io.trino.spi.type.Type): 1
MethodHandle (java.lang.invoke.MethodHandle): 1
BigDecimal (java.math.BigDecimal): 1