Use of org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator in the druid project by druid-io.
Class VarianceSqlAggregatorTest, method testStdDevWithVirtualColumns.
/**
 * Runs STDDEV over the expressions d1*7, f1*7 and l1*7 on numfoo and checks both the planned native
 * timeseries query and the single result row.
 *
 * <p>The expected row is built locally: every row of CalciteTests.ROWS1_WITH_NUMERIC_DIMS is fed into one
 * VarianceAggregatorCollector per column via addToHolder (passing the same 7 used in the SQL), and the
 * square root of getVariance(false) is taken for each collector.
 */
@Test
public void testStdDevWithVirtualColumns() throws Exception {
  final int factor = 7;
  final VarianceAggregatorCollector doubleStats = new VarianceAggregatorCollector();
  final VarianceAggregatorCollector floatStats = new VarianceAggregatorCollector();
  final VarianceAggregatorCollector longStats = new VarianceAggregatorCollector();

  for (InputRow inputRow : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
    addToHolder(doubleStats, inputRow.getRaw("d1"), factor);
    addToHolder(floatStats, inputRow.getRaw("f1"), factor);
    addToHolder(longStats, inputRow.getRaw("l1"), factor);
  }

  // Each expected value is narrowed to the Java type matching its column: double, float, long.
  final double expectedDouble = Math.sqrt(doubleStats.getVariance(false));
  final float expectedFloat = (float) Math.sqrt(floatStats.getVariance(false));
  final long expectedLong = (long) Math.sqrt(longStats.getVariance(false));
  final List<Object[]> expectedResults = ImmutableList.of(
      new Object[] {expectedDouble, expectedFloat, expectedLong}
  );

  final String sql = "SELECT\n"
                     + "STDDEV(d1*7),\n"
                     + "STDDEV(f1*7),\n"
                     + "STDDEV(l1*7)\n"
                     + "FROM numfoo";

  testQuery(
      sql,
      ImmutableList.of(
          Druids.newTimeseriesQueryBuilder()
                .dataSource(CalciteTests.DATASOURCE3)
                .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
                .granularity(Granularities.ALL)
                // One expression virtual column per multiplied input.
                .virtualColumns(
                    BaseCalciteQueryTest.expressionVirtualColumn("v0", "(\"d1\" * 7)", ColumnType.DOUBLE),
                    BaseCalciteQueryTest.expressionVirtualColumn("v1", "(\"f1\" * 7)", ColumnType.FLOAT),
                    BaseCalciteQueryTest.expressionVirtualColumn("v2", "(\"l1\" * 7)", ColumnType.LONG)
                )
                // The variance aggregators read the virtual columns, not the raw columns.
                .aggregators(
                    ImmutableList.of(
                        new VarianceAggregatorFactory("a0:agg", "v0", "sample", "double"),
                        new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
                        new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
                    )
                )
                .postAggregators(
                    new StandardDeviationPostAggregator("a0", "a0:agg", "sample"),
                    new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
                    new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
                )
                .context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
                .build()
      ),
      expectedResults
  );
}
Use of org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator in the druid project by druid-io.
Class VarianceSqlAggregatorTest, method testStdDevPop.
/**
 * Runs STDDEV_POP over d1, f1 and l1 on numfoo and checks both the planned native timeseries query
 * ("population" variance aggregators plus standard-deviation post-aggregators) and the single result row.
 *
 * <p>The expected row is built locally by feeding every row of CalciteTests.ROWS1_WITH_NUMERIC_DIMS into
 * one VarianceAggregatorCollector per column and taking the square root of getVariance(true).
 */
@Test
public void testStdDevPop() throws Exception {
  final VarianceAggregatorCollector doubleStats = new VarianceAggregatorCollector();
  final VarianceAggregatorCollector floatStats = new VarianceAggregatorCollector();
  final VarianceAggregatorCollector longStats = new VarianceAggregatorCollector();

  for (InputRow inputRow : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
    addToHolder(doubleStats, inputRow.getRaw("d1"));
    addToHolder(floatStats, inputRow.getRaw("f1"));
    addToHolder(longStats, inputRow.getRaw("l1"));
  }

  // Each expected value is narrowed to the Java type matching its column: double, float, long.
  final double expectedDouble = Math.sqrt(doubleStats.getVariance(true));
  final float expectedFloat = (float) Math.sqrt(floatStats.getVariance(true));
  final long expectedLong = (long) Math.sqrt(longStats.getVariance(true));
  final List<Object[]> expectedResults = ImmutableList.of(
      new Object[] {expectedDouble, expectedFloat, expectedLong}
  );

  final String sql = "SELECT\n"
                     + "STDDEV_POP(d1),\n"
                     + "STDDEV_POP(f1),\n"
                     + "STDDEV_POP(l1)\n"
                     + "FROM numfoo";

  testQuery(
      sql,
      ImmutableList.of(
          Druids.newTimeseriesQueryBuilder()
                .dataSource(CalciteTests.DATASOURCE3)
                .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
                .granularity(Granularities.ALL)
                .aggregators(
                    ImmutableList.of(
                        new VarianceAggregatorFactory("a0:agg", "d1", "population", "double"),
                        new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
                        new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
                    )
                )
                .postAggregators(
                    ImmutableList.of(
                        new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
                        new StandardDeviationPostAggregator("a1", "a1:agg", "population"),
                        new StandardDeviationPostAggregator("a2", "a2:agg", "population")
                    )
                )
                .context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
                .build()
      ),
      expectedResults
  );
}
Use of org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator in the druid project by druid-io.
Class BaseVarianceSqlAggregator, method toDruidAggregation.
/**
 * Translates a SQL variance or standard-deviation aggregate call into a Druid {@link Aggregation}.
 *
 * <p>The first argument of the call is converted to a Druid expression. If it is a simple extraction it is
 * read directly; otherwise a virtual column is obtained from the registry and read instead. A
 * {@code VarianceAggregatorFactory} named {@code "<name>:agg"} is always produced; for STDDEV_POP,
 * STDDEV_SAMP and STDDEV a {@code StandardDeviationPostAggregator} named {@code name} is added on top.
 * VAR_POP and STDDEV_POP use the "population" estimator, everything else uses "sample".
 *
 * @return the aggregation, or {@code null} when
 *         {@code Aggregations.toDruidExpressionForNumericAggregator} yields no expression for the input
 * @throws IAE if the input's column type cannot be determined or is not numeric
 */
@Nullable
@Override
public Aggregation toDruidAggregation(PlannerContext plannerContext, RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, RexBuilder rexBuilder, String name, AggregateCall aggregateCall, Project project, List<Aggregation> existingAggregations, boolean finalizeAggregations) {
  final RexNode inputOperand = Expressions.fromFieldAccess(rowSignature, project, aggregateCall.getArgList().get(0));
  final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(plannerContext, rowSignature, inputOperand);
  if (input == null) {
    return null;
  }

  final RelDataType dataType = inputOperand.getType();
  final ColumnType inputType = Calcites.getColumnTypeForRelDataType(dataType);
  final String aggName = StringUtils.format("%s:agg", name);
  final SqlAggFunction func = calciteFunction();

  // Validate the input type BEFORE it is handed to any dimension-spec factory and before the virtual
  // column registry is touched: an unusable type must surface as the IAE below, and a rejected input
  // must not leave a freshly registered virtual column behind.
  if (inputType == null) {
    throw new IAE("VarianceSqlAggregator[%s] has invalid inputType", func);
  }
  if (!inputType.isNumeric()) {
    throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType.asTypeString());
  }
  final String inputTypeName = StringUtils.toLowerCase(inputType.getType().name());

  final DimensionSpec dimensionSpec;
  if (input.isSimpleExtraction()) {
    dimensionSpec = input.getSimpleExtraction().toDimensionSpec(null, inputType);
  } else {
    // Not a plain column read: aggregate over a virtual column that evaluates the expression.
    final String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(input, dataType);
    dimensionSpec = new DefaultDimensionSpec(virtualColumnName, null, inputType);
  }

  final String estimator;
  if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
    estimator = "population";
  } else {
    estimator = "sample";
  }

  final AggregatorFactory aggregatorFactory = new VarianceAggregatorFactory(aggName, dimensionSpec.getDimension(), estimator, inputTypeName);

  // Only the standard-deviation functions get a post-aggregator; the variance functions pass null.
  final PostAggregator postAggregator;
  if (func == SqlStdOperatorTable.STDDEV_POP || func == SqlStdOperatorTable.STDDEV_SAMP || func == SqlStdOperatorTable.STDDEV) {
    postAggregator = new StandardDeviationPostAggregator(name, aggregatorFactory.getName(), estimator);
  } else {
    postAggregator = null;
  }

  return Aggregation.create(ImmutableList.of(aggregatorFactory), postAggregator);
}
Use of org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator in the druid project by druid-io.
Class VarianceSqlAggregatorTest, method testStdDevSamp.
/**
 * Runs STDDEV_SAMP over d1, f1 and l1 on numfoo and checks both the planned native timeseries query
 * ("sample" variance aggregators plus standard-deviation post-aggregators) and the single result row.
 *
 * <p>The expected row is built locally by feeding every row of CalciteTests.ROWS1_WITH_NUMERIC_DIMS into
 * one VarianceAggregatorCollector per column and taking the square root of getVariance(false).
 */
@Test
public void testStdDevSamp() throws Exception {
  final VarianceAggregatorCollector doubleStats = new VarianceAggregatorCollector();
  final VarianceAggregatorCollector floatStats = new VarianceAggregatorCollector();
  final VarianceAggregatorCollector longStats = new VarianceAggregatorCollector();

  for (InputRow inputRow : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
    addToHolder(doubleStats, inputRow.getRaw("d1"));
    addToHolder(floatStats, inputRow.getRaw("f1"));
    addToHolder(longStats, inputRow.getRaw("l1"));
  }

  // Each expected value is narrowed to the Java type matching its column: double, float, long.
  final double expectedDouble = Math.sqrt(doubleStats.getVariance(false));
  final float expectedFloat = (float) Math.sqrt(floatStats.getVariance(false));
  final long expectedLong = (long) Math.sqrt(longStats.getVariance(false));
  final List<Object[]> expectedResults = ImmutableList.of(
      new Object[] {expectedDouble, expectedFloat, expectedLong}
  );

  final String sql = "SELECT\n"
                     + "STDDEV_SAMP(d1),\n"
                     + "STDDEV_SAMP(f1),\n"
                     + "STDDEV_SAMP(l1)\n"
                     + "FROM numfoo";

  testQuery(
      sql,
      ImmutableList.of(
          Druids.newTimeseriesQueryBuilder()
                .dataSource(CalciteTests.DATASOURCE3)
                .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
                .granularity(Granularities.ALL)
                .aggregators(
                    ImmutableList.of(
                        new VarianceAggregatorFactory("a0:agg", "d1", "sample", "double"),
                        new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
                        new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
                    )
                )
                .postAggregators(
                    new StandardDeviationPostAggregator("a0", "a0:agg", "sample"),
                    new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
                    new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
                )
                .context(BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT)
                .build()
      ),
      expectedResults
  );
}
Aggregations