use of com.facebook.presto.spi.memory.MemoryPoolId in project presto by prestodb.
the class TestMemoryPools method testBlocking.
/**
 * Checks that a query's drivers block while the memory pool is fully reserved by another (fake) query,
 * and that they then make progress until all of them finish once that reservation is released.
 */
@Test
public void testBlocking() throws Exception {
    Session session = testSessionBuilder().setCatalog("tpch").setSchema("tiny").setSystemProperty("task_default_concurrency", "1").build();
    LocalQueryRunner localQueryRunner = queryRunnerWithInitialTransaction(session);
    // add tpch
    localQueryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of());
    // reserve all the memory in the pool on behalf of a fake query
    MemoryPool pool = new MemoryPool(new MemoryPoolId("test"), new DataSize(10, MEGABYTE));
    QueryId fakeQueryId = new QueryId("fake");
    assertTrue(pool.tryReserve(fakeQueryId, TEN_MEGABYTES));
    MemoryPool systemPool = new MemoryPool(new MemoryPoolId("testSystem"), new DataSize(10, MEGABYTE));
    QueryContext queryContext = new QueryContext(new QueryId("query"), new DataSize(10, MEGABYTE), pool, systemPool, localQueryRunner.getExecutor());
    // discard all output
    OutputFactory outputFactory = new PageConsumerOutputFactory(types -> (page -> {
    }));
    TaskContext taskContext = createTaskContext(queryContext, localQueryRunner.getExecutor(), session);
    List<Driver> drivers = localQueryRunner.createDrivers("SELECT COUNT(*) FROM orders JOIN lineitem USING (orderkey)", outputFactory, taskContext);
    // run the drivers until isWaitingForMemory reports them blocked; if they all finish first,
    // fail instead of spinning forever (finished drivers never start waiting for memory)
    while (!isWaitingForMemory(drivers)) {
        assertFalse(drivers.stream().allMatch(Driver::isFinished), "drivers finished without ever blocking on memory");
        for (Driver driver : drivers) {
            driver.process();
        }
    }
    // driver should be blocked waiting for memory
    for (Driver driver : drivers) {
        assertFalse(driver.isFinished());
    }
    assertTrue(pool.getFreeBytes() <= 0);
    // release the fake query's reservation so the real query can proceed
    pool.free(fakeQueryId, TEN_MEGABYTES);
    do {
        assertFalse(isWaitingForMemory(drivers));
        boolean progress = false;
        for (Driver driver : drivers) {
            ListenableFuture<?> blocked = driver.process();
            progress |= blocked.isDone();
        }
        // query should not block
        assertTrue(progress);
    } while (!drivers.stream().allMatch(Driver::isFinished));
}
use of com.facebook.presto.spi.memory.MemoryPoolId in project presto by prestodb.
the class ClusterMemoryManager method process.
/**
 * Runs one enforcement pass over the given queries: applies cluster memory and CPU limits and refreshes
 * pool/node memory assignments.
 *
 * <p>A query is failed when any of these holds:
 * <ul>
 *   <li>the cluster is out of memory and the query's session has resource overcommit enabled;</li>
 *   <li>resource overcommit is disabled and the query's total memory reservation exceeds the smaller of the
 *       configured and the session-level maximum query memory;</li>
 *   <li>its total CPU time exceeds the smaller of the configured and the session-level maximum CPU time.</li>
 * </ul>
 * When {@code killOnOutOfMemory} is set, the query with the largest total reservation in the general pool is
 * also killed if all of the following hold: the cluster has been out of memory for longer than
 * {@code killOnOutOfMemoryDelay}, the previously killed query no longer holds a general-pool reservation, and
 * no query was already failed for memory in this pass.
 *
 * <p>Does nothing when the manager is not enabled.
 *
 * @param queries the queries to inspect; iterated several times, so it must be re-iterable
 */
public synchronized void process(Iterable<QueryExecution> queries) {
    if (!enabled) {
        return;
    }
    boolean outOfMemory = isClusterOutOfMemory();
    if (!outOfMemory) {
        lastTimeNotOutOfMemory = System.nanoTime();
    }
    boolean queryKilled = false;
    long totalBytes = 0;
    for (QueryExecution query : queries) {
        long bytes = query.getTotalMemoryReservation();
        DataSize sessionMaxQueryMemory = getQueryMaxMemory(query.getSession());
        long queryMemoryLimit = Math.min(maxQueryMemory.toBytes(), sessionMaxQueryMemory.toBytes());
        totalBytes += bytes;
        if (resourceOvercommit(query.getSession()) && outOfMemory) {
            // If a query has requested resource overcommit, only kill it if the cluster has run out of memory
            DataSize memory = succinctBytes(bytes);
            query.fail(new PrestoException(CLUSTER_OUT_OF_MEMORY, format("The cluster is out of memory and %s=true, so this query was killed. It was using %s of memory", RESOURCE_OVERCOMMIT, memory)));
            queryKilled = true;
        }
        if (!resourceOvercommit(query.getSession()) && bytes > queryMemoryLimit) {
            DataSize maxMemory = succinctBytes(queryMemoryLimit);
            query.fail(exceededGlobalLimit(maxMemory));
            queryKilled = true;
        }
    }
    clusterMemoryUsageBytes.set(totalBytes);
    if (killOnOutOfMemory) {
        boolean shouldKillQuery = nanosSince(lastTimeNotOutOfMemory).compareTo(killOnOutOfMemoryDelay) > 0 && outOfMemory;
        boolean lastKilledQueryIsGone = (lastKilledQuery == null);
        if (!lastKilledQueryIsGone) {
            ClusterMemoryPool generalPool = pools.get(GENERAL_POOL);
            if (generalPool != null) {
                // The previous victim is gone only once it NO LONGER holds a reservation in the general pool.
                // Waiting for that keeps a single out-of-memory episode from killing several queries while
                // the first victim's memory is still being released.
                lastKilledQueryIsGone = !generalPool.getQueryMemoryReservations().containsKey(lastKilledQuery);
            }
        }
        if (shouldKillQuery && lastKilledQueryIsGone && !queryKilled) {
            // Kill the biggest query in the general pool
            QueryExecution biggestQuery = null;
            long maxMemory = -1;
            for (QueryExecution query : queries) {
                long bytesUsed = query.getTotalMemoryReservation();
                if (bytesUsed > maxMemory && query.getMemoryPool().getId().equals(GENERAL_POOL)) {
                    biggestQuery = query;
                    maxMemory = bytesUsed;
                }
            }
            if (biggestQuery != null) {
                biggestQuery.fail(new PrestoException(CLUSTER_OUT_OF_MEMORY, "The cluster is out of memory, and your query was killed. Please try again in a few minutes."));
                queriesKilledDueToOutOfMemory.incrementAndGet();
                lastKilledQuery = biggestQuery.getQueryId();
            }
        }
    }
    // number of queries currently assigned to each memory pool
    Map<MemoryPoolId, Integer> countByPool = new HashMap<>();
    for (QueryExecution query : queries) {
        countByPool.merge(query.getMemoryPool().getId(), 1, Integer::sum);
    }
    updatePools(countByPool);
    updateNodes(updateAssignments(queries));
    // check if CPU usage is over limit
    for (QueryExecution query : queries) {
        Duration cpuTime = query.getTotalCpuTime();
        Duration sessionLimit = getQueryMaxCpuTime(query.getSession());
        Duration limit = maxQueryCpuTime.compareTo(sessionLimit) < 0 ? maxQueryCpuTime : sessionLimit;
        if (cpuTime.compareTo(limit) > 0) {
            query.fail(new ExceededCpuLimitException(limit));
        }
    }
}
use of com.facebook.presto.spi.memory.MemoryPoolId in project presto by prestodb.
the class PrestoSparkQueryExecutionFactory method createQueryInfo.
/**
 * Assembles a {@link QueryInfo} for a query executed on Spark, deriving memory and task statistics by
 * aggregating over the latest attempt of every stage reachable from {@code rootStage}.
 *
 * @param session the query session
 * @param query the query text
 * @param queryState the state to report; {@code FAILED} requires {@code failureInfo} to be present
 * @param planAndMore optional plan details (field names, update type, inputs, output, query type)
 * @param sparkQueueName optional Spark queue name, reported as the resource group id
 * @param failureInfo optional failure details
 * @param queryStateTimer timer used to build the query stats
 * @param rootStage optional root stage of the query
 * @param warningCollector source of the reported warnings
 * @return the assembled query info
 * @throws IllegalArgumentException if {@code queryState} is {@code FAILED} but no failure info is given
 */
private static QueryInfo createQueryInfo(Session session, String query, QueryState queryState, Optional<PlanAndMore> planAndMore, Optional<String> sparkQueueName, Optional<ExecutionFailureInfo> failureInfo, QueryStateTimer queryStateTimer, Optional<StageInfo> rootStage, WarningCollector warningCollector) {
    checkArgument(failureInfo.isPresent() || queryState != FAILED, "unexpected query state: %s", queryState);

    int taskCount = 0;
    long summedUserBytes = 0;
    long summedTotalBytes = 0;
    long largestTaskUserBytes = 0;
    long largestTaskTotalBytes = 0;
    long largestNodeTotalBytes = 0;
    for (StageInfo stage : getAllStages(rootStage)) {
        for (TaskInfo task : stage.getLatestAttemptExecutionInfo().getTasks()) {
            // Spark gives no visibility into how many tasks actually ran at once,
            // so every task is counted as if all of them ran in parallel
            taskCount++;
            long userBytes = task.getStats().getPeakUserMemoryInBytes();
            long totalBytes = task.getStats().getPeakTotalMemoryInBytes();
            summedUserBytes += userBytes;
            summedTotalBytes += totalBytes;
            largestTaskUserBytes = max(largestTaskUserBytes, userBytes);
            largestTaskTotalBytes = max(largestTaskTotalBytes, totalBytes);
            largestNodeTotalBytes = max(task.getStats().getPeakNodeTotalMemoryInBytes(), largestNodeTotalBytes);
        }
    }

    QueryStats queryStats = QueryStats.create(
            queryStateTimer,
            rootStage,
            taskCount,
            succinctBytes(summedUserBytes),
            succinctBytes(summedTotalBytes),
            succinctBytes(largestTaskUserBytes),
            succinctBytes(largestTaskTotalBytes),
            succinctBytes(largestNodeTotalBytes),
            session.getRuntimeStats());

    return new QueryInfo(
            session.getQueryId(),
            session.toSessionRepresentation(),
            queryState,
            new MemoryPoolId("spark-memory-pool"),
            queryStats.isScheduled(),
            URI.create("http://fake.invalid/query/" + session.getQueryId()),
            planAndMore.map(PlanAndMore::getFieldNames).orElse(ImmutableList.of()),
            query,
            Optional.empty(),
            Optional.empty(),
            queryStats,
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            ImmutableSet.of(),
            ImmutableMap.of(),
            ImmutableMap.of(),
            ImmutableSet.of(),
            Optional.empty(),
            false,
            planAndMore.flatMap(PlanAndMore::getUpdateType).orElse(null),
            rootStage,
            failureInfo.orElse(null),
            failureInfo.map(ExecutionFailureInfo::getErrorCode).orElse(null),
            warningCollector.getWarnings(),
            planAndMore.map(PlanAndMore::getInputs).orElse(ImmutableSet.of()),
            planAndMore.flatMap(PlanAndMore::getOutput),
            true,
            sparkQueueName.map(ResourceGroupId::new),
            planAndMore.flatMap(PlanAndMore::getQueryType),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            ImmutableSet.of());
}
use of com.facebook.presto.spi.memory.MemoryPoolId in project presto by prestodb.
the class PrestoSparkTaskExecutorFactory method doCreate.
/**
 * Creates and starts a Presto-on-Spark task executor for one Spark partition.
 *
 * <p>It decodes the task descriptor, builds the query, task and memory contexts, and wires the task's
 * remote-source inputs and output buffer. It then plans the fragment, starts execution, and returns an
 * executor over the task's output.
 *
 * <p>When spilling is enabled for the session, a listener on the task's memory pool does the following:
 * <ul>
 *   <li>tracks the peak node memory;</li>
 *   <li>once the reservation exceeds {@code memoryRevokingThreshold} of the pool, asynchronously asks
 *       operators to revoke memory;</li>
 *   <li>fails with an exceeded-local-total-memory error when the reservation is above
 *       {@code maxTotalMemory} and no revocation is in progress or pending.</li>
 * </ul>
 *
 * @param partitionId Spark partition id; used as the task id (in stage execution 0) and for the
 *        pre-determined output partition
 * @param attemptNumber Spark attempt number; not currently used (see TODO on the task id)
 * @param serializedTaskDescriptor JSON-encoded {@link PrestoSparkTaskDescriptor}
 * @param serializedTaskSources serialized task sources (splits) for this task
 * @param inputs shuffle, broadcast and in-memory inputs, keyed by source fragment id
 * @param taskInfoCollector accumulator handed to the returned executor
 * @param shuffleStatsCollector shuffle statistics accumulator, used by the remote source factory and the executor
 * @param outputType requested task output type; selects the output configuration
 * @return an executor for the started task
 * @throws IllegalArgumentException if a remote source has no input or more than one kind of input, or if the
 *         session's memory revoking target exceeds its threshold
 */
public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> doCreate(int partitionId, int attemptNumber, SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor, Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources, PrestoSparkTaskInputs inputs, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector, Class<T> outputType) {
    PrestoSparkTaskDescriptor taskDescriptor = taskDescriptorJsonCodec.fromJson(serializedTaskDescriptor.getBytes());
    ImmutableMap.Builder<String, TokenAuthenticator> extraAuthenticators = ImmutableMap.builder();
    authenticatorProviders.forEach(provider -> extraAuthenticators.putAll(provider.getTokenAuthenticators()));
    Session session = taskDescriptor.getSession().toSession(sessionPropertyManager, taskDescriptor.getExtraCredentials(), extraAuthenticators.build());
    PlanFragment fragment = taskDescriptor.getFragment();
    StageId stageId = new StageId(session.getQueryId(), fragment.getId().getId());
    // Clear the cache if the cache does not have broadcast table for current stageId.
    // We will only cache 1 HT at any time. If the stageId changes, we will drop the old cached HT
    prestoSparkBroadcastTableCacheManager.removeCachedTablesForStagesOtherThan(stageId);
    // TODO: include attemptId in taskId
    TaskId taskId = new TaskId(new StageExecutionId(stageId, 0), partitionId);
    List<TaskSource> taskSources = getTaskSources(serializedTaskSources);
    log.info("Task [%s] received %d splits.", taskId, taskSources.stream().mapToInt(taskSource -> taskSource.getSplits().size()).sum());
    OptionalLong totalSplitSize = computeAllSplitsSize(taskSources);
    if (totalSplitSize.isPresent()) {
        log.info("Total split size: %s bytes.", totalSplitSize.getAsLong());
    }
    // TODO: Remove this once we can display the plan on Spark UI.
    log.info(PlanPrinter.textPlanFragment(fragment, functionAndTypeManager, session, true));

    // Per-node limits: the smaller of the node config and the session property
    DataSize maxUserMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryMemoryPerNode().toBytes(), getQueryMaxMemoryPerNode(session).toBytes()), BYTE);
    DataSize maxTotalMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryTotalMemoryPerNode().toBytes(), getQueryMaxTotalMemoryPerNode(session).toBytes()), BYTE);
    DataSize maxBroadcastMemory = getSparkBroadcastJoinMaxMemoryOverride(session);
    if (maxBroadcastMemory == null) {
        maxBroadcastMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryBroadcastMemory().toBytes(), getQueryMaxBroadcastMemory(session).toBytes()), BYTE);
    }
    MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("spark-executor-memory-pool"), maxTotalMemory);
    SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(maxQuerySpillPerNode);
    QueryContext queryContext = new QueryContext(session.getQueryId(), maxUserMemory, maxTotalMemory, maxBroadcastMemory, maxRevocableMemory, memoryPool, new TestingGcMonitor(), notificationExecutor, yieldExecutor, maxQuerySpillPerNode, spillSpaceTracker, memoryReservationSummaryJsonCodec);
    queryContext.setVerboseExceededMemoryLimitErrorsEnabled(isVerboseExceededMemoryLimitErrorsEnabled(session));
    queryContext.setHeapDumpOnExceededMemoryLimitEnabled(isHeapDumpOnExceededMemoryLimitEnabled(session));
    String heapDumpFilePath = Paths.get(getHeapDumpFileDirectory(session), format("%s_%s.hprof", session.getQueryId().getId(), stageId.getId())).toString();
    queryContext.setHeapDumpFilePath(heapDumpFilePath);
    TaskStateMachine taskStateMachine = new TaskStateMachine(taskId, notificationExecutor);
    TaskContext taskContext = queryContext.addTaskContext(taskStateMachine, session, // Plan has to be retained only if verbose memory exceeded errors are requested
            isVerboseExceededMemoryLimitErrorsEnabled(session) ? Optional.of(fragment.getRoot()) : Optional.empty(), perOperatorCpuTimerEnabled, cpuTimerEnabled, perOperatorAllocationTrackingEnabled, allocationTrackingEnabled, false);
    final double memoryRevokingThreshold = getMemoryRevokingThreshold(session);
    final double memoryRevokingTarget = getMemoryRevokingTarget(session);
    checkArgument(memoryRevokingTarget <= memoryRevokingThreshold, "memoryRevokingTarget should be less than or equal memoryRevokingThreshold, but got %s and %s respectively", memoryRevokingTarget, memoryRevokingThreshold);
    if (isSpillEnabled(session)) {
        memoryPool.addListener((pool, queryId, totalMemoryReservationBytes) -> {
            if (totalMemoryReservationBytes > queryContext.getPeakNodeTotalMemory()) {
                queryContext.setPeakNodeTotalMemory(totalMemoryReservationBytes);
            }
            // Start at most one asynchronous revocation request at a time (guarded by compareAndSet)
            if (totalMemoryReservationBytes > pool.getMaxBytes() * memoryRevokingThreshold && memoryRevokeRequestInProgress.compareAndSet(false, true)) {
                memoryRevocationExecutor.execute(() -> {
                    try {
                        // Aim to bring the reservation down to memoryRevokingTarget of the pool, minus what is already being revoked
                        AtomicLong remainingBytesToRevoke = new AtomicLong(totalMemoryReservationBytes - (long) (memoryRevokingTarget * pool.getMaxBytes()));
                        remainingBytesToRevoke.addAndGet(-MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(ImmutableList.of(taskContext), remainingBytesToRevoke.get()));
                        taskContext.accept(new VoidTraversingQueryContextVisitor<AtomicLong>() {
                            @Override
                            public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remainingBytesToRevoke) {
                                if (remainingBytesToRevoke.get() > 0) {
                                    long revokedBytes = operatorContext.requestMemoryRevoking();
                                    if (revokedBytes > 0) {
                                        memoryRevokePending.set(true);
                                        remainingBytesToRevoke.addAndGet(-revokedBytes);
                                    }
                                }
                                return null;
                            }
                        }, remainingBytesToRevoke);
                    } catch (Exception e) {
                        log.error(e, "Error requesting memory revoking");
                    } finally {
                        // Always clear the flag. If it stayed set after a failure, no further revocation could
                        // ever be requested and the exceeded-memory check below would never fire.
                        memoryRevokeRequestInProgress.set(false);
                    }
                });
            }
            // Get the latest memory reservation info since it might have changed due to revoke
            long totalReservedMemory = pool.getQueryMemoryReservation(queryId) + pool.getQueryRevocableMemoryReservation(queryId);
            // If total memory usage is over maxTotalMemory and memory revoke request is not pending, fail the query with EXCEEDED_MEMORY_LIMIT error
            if (totalReservedMemory > maxTotalMemory.toBytes() && !memoryRevokeRequestInProgress.get() && !isMemoryRevokePending(taskContext)) {
                throw exceededLocalTotalMemoryLimit(maxTotalMemory, queryContext.getAdditionalFailureInfo(totalReservedMemory, 0) + format("Total reserved memory: %s, Total revocable memory: %s", succinctBytes(pool.getQueryMemoryReservation(queryId)), succinctBytes(pool.getQueryRevocableMemoryReservation(queryId))), isHeapDumpOnExceededMemoryLimitEnabled(session), Optional.ofNullable(heapDumpFilePath));
            }
        });
    }

    // Route each remote source's inputs by kind: shuffle rows, serialized pages (in-memory), or broadcast lists.
    // A single source fragment must provide exactly one kind of input.
    ImmutableMap.Builder<PlanNodeId, List<PrestoSparkShuffleInput>> shuffleInputs = ImmutableMap.builder();
    ImmutableMap.Builder<PlanNodeId, List<java.util.Iterator<PrestoSparkSerializedPage>>> pageInputs = ImmutableMap.builder();
    ImmutableMap.Builder<PlanNodeId, List<?>> broadcastInputs = ImmutableMap.builder();
    for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) {
        List<PrestoSparkShuffleInput> remoteSourceRowInputs = new ArrayList<>();
        List<java.util.Iterator<PrestoSparkSerializedPage>> remoteSourcePageInputs = new ArrayList<>();
        List<List<?>> broadcastInputsList = new ArrayList<>();
        for (PlanFragmentId sourceFragmentId : remoteSource.getSourceFragmentIds()) {
            Iterator<Tuple2<MutablePartitionId, PrestoSparkMutableRow>> shuffleInput = inputs.getShuffleInputs().get(sourceFragmentId.toString());
            Broadcast<?> broadcastInput = inputs.getBroadcastInputs().get(sourceFragmentId.toString());
            List<PrestoSparkSerializedPage> inMemoryInput = inputs.getInMemoryInputs().get(sourceFragmentId.toString());
            if (shuffleInput != null) {
                checkArgument(broadcastInput == null, "single remote source is not expected to accept different kind of inputs");
                checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
                remoteSourceRowInputs.add(new PrestoSparkShuffleInput(sourceFragmentId.getId(), shuffleInput));
                continue;
            }
            if (broadcastInput != null) {
                checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
                // TODO: Enable NullifyingIterator once migrated to one task per JVM model
                // NullifyingIterator removes element from the list upon return
                // This allows GC to gradually reclaim memory
                // remoteSourcePageInputs.add(getNullifyingIterator(broadcastInput.value()));
                broadcastInputsList.add((List<?>) broadcastInput.value());
                continue;
            }
            if (inMemoryInput != null) {
                // for inmemory inputs pages can be released incrementally to save memory
                remoteSourcePageInputs.add(getNullifyingIterator(inMemoryInput));
                continue;
            }
            throw new IllegalArgumentException("Input not found for sourceFragmentId: " + sourceFragmentId);
        }
        if (!remoteSourceRowInputs.isEmpty()) {
            shuffleInputs.put(remoteSource.getId(), remoteSourceRowInputs);
        }
        if (!remoteSourcePageInputs.isEmpty()) {
            pageInputs.put(remoteSource.getId(), remoteSourcePageInputs);
        }
        if (!broadcastInputsList.isEmpty()) {
            broadcastInputs.put(remoteSource.getId(), broadcastInputsList);
        }
    }

    OutputBufferMemoryManager memoryManager = new OutputBufferMemoryManager(sinkMaxBufferSize.toBytes(), () -> queryContext.getTaskContextByTaskId(taskId).localSystemMemoryContext(), notificationExecutor);
    Optional<OutputPartitioning> preDeterminedPartition = Optional.empty();
    if (fragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION)) {
        int partitionCount = getHashPartitionCount(session);
        preDeterminedPartition = Optional.of(new OutputPartitioning(new PreDeterminedPartitionFunction(partitionId % partitionCount, partitionCount), ImmutableList.of(), ImmutableList.of(), false, OptionalInt.empty()));
    }
    TempDataOperationContext tempDataOperationContext = new TempDataOperationContext(session.getSource(), session.getQueryId().getId(), session.getClientInfo(), Optional.of(session.getClientTags()), session.getIdentity());
    TempStorage tempStorage = tempStorageManager.getTempStorage(storageBasedBroadcastJoinStorage);
    Output<T> output = configureOutput(outputType, blockEncodingManager, memoryManager, getShuffleOutputTargetAverageRowSize(session), preDeterminedPartition, tempStorage, tempDataOperationContext, getStorageBasedBroadcastJoinWriteBufferSize(session));
    PrestoSparkOutputBuffer<?> outputBuffer = output.getOutputBuffer();
    LocalExecutionPlan localExecutionPlan = localExecutionPlanner.plan(taskContext, fragment.getRoot(), fragment.getPartitioningScheme(), fragment.getStageExecutionDescriptor(), fragment.getTableScanSchedulingOrder(), output.getOutputFactory(), new PrestoSparkRemoteSourceFactory(blockEncodingManager, shuffleInputs.build(), pageInputs.build(), broadcastInputs.build(), partitionId, shuffleStatsCollector, tempStorage, tempDataOperationContext, prestoSparkBroadcastTableCacheManager, stageId), taskDescriptor.getTableWriteInfo(), true);
    // Signal end-of-output to the buffer once the task reaches a done state
    taskStateMachine.addStateChangeListener(state -> {
        if (state.isDone()) {
            outputBuffer.setNoMoreRows();
        }
    });
    PrestoSparkTaskExecution taskExecution = new PrestoSparkTaskExecution(taskStateMachine, taskContext, localExecutionPlan, taskExecutor, splitMonitor, notificationExecutor, memoryUpdateExecutor);
    taskExecution.start(taskSources);
    return new PrestoSparkTaskExecutor<>(taskContext, taskStateMachine, output.getOutputSupplier(), taskInfoCodec, taskInfoCollector, shuffleStatsCollector, executionExceptionFactory, output.getOutputBufferType(), outputBuffer, tempStorage, tempDataOperationContext);
}
use of com.facebook.presto.spi.memory.MemoryPoolId in project presto by prestodb.
the class GroupByHashYieldAssertion method finishOperatorWithYieldingGroupByHash.
/**
 * Feeds {@code input} to an operator while keeping the shared memory pool almost full, so any hash table
 * growth inside the operator has to yield, and checks the operator's memory accounting around each yield.
 *
 * @param input pages to feed to the operator
 * @param hashKeyType type of the group-by key; selects the expected per-entry hash table overhead
 * @param operatorFactory creates an Operator that should directly or indirectly contain GroupByHash
 * @param getHashCapacity returns the hash table capacity for the input operator
 * @param additionalMemoryInBytes the memory used in addition to the GroupByHash in the operator (e.g., aggregator);
 *        must be less than {@code 1L << 21}
 * @return a result built from the yield count, the expected extra reservation computed at the last yield
 *         (0 if the operator never yielded), and every page the operator produced
 */
public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<Page> input, Type hashKeyType, OperatorFactory operatorFactory, Function<Operator, Integer> getHashCapacity, long additionalMemoryInBytes) {
    assertLessThan(additionalMemoryInBytes, 1L << 21, "additionalMemoryInBytes should be a relatively small number");
    List<Page> producedPages = new LinkedList<>();

    // "test_query1" soaks up almost the whole pool on demand; "test_query2" owns the operator under test
    QueryId hogQueryId = new QueryId("test_query1");
    QueryId operatorQueryId = new QueryId("test_query2");
    MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(1, GIGABYTE));
    QueryContext queryContext = new QueryContext(operatorQueryId, new DataSize(512, MEGABYTE), new DataSize(1024, MEGABYTE), new DataSize(512, MEGABYTE), new DataSize(1, GIGABYTE), memoryPool, new TestingGcMonitor(), EXECUTOR, SCHEDULED_EXECUTOR, new DataSize(512, MEGABYTE), new SpillSpaceTracker(new DataSize(512, MEGABYTE)), listJsonCodec(TaskMemoryReservationSummary.class));
    DriverContext driverContext = createTaskContext(queryContext, EXECUTOR, TEST_SESSION).addPipelineContext(0, true, true, false).addDriverContext();
    Operator operator = operatorFactory.createOperator(driverContext);

    // below this usage the operator's growth is dominated by rehash vs. aggregator noise, so the page is skipped
    long smallUsageThreshold = new DataSize(4, MEGABYTE).toBytes();
    int yields = 0;
    long expectedExtraBytes = 0;
    for (Page page : input) {
        // the operator must be ready for more input
        assertTrue(operator.needsInput());

        // fill the pool so that only additionalMemoryInBytes remains free
        long hogBytes = memoryPool.getFreeBytes() - additionalMemoryInBytes;
        memoryPool.reserve(hogQueryId, "test", hogBytes);
        long usageBefore = operator.getOperatorContext().getDriverContext().getMemoryUsage();
        int capacityBefore = getHashCapacity.apply(operator);

        // push the page, then pull output so the input gets consumed
        operator.addInput(page);
        Page produced = operator.getOutput();
        if (produced != null) {
            producedPages.add(produced);
        }
        long usageAfter = operator.getOperatorContext().getDriverContext().getMemoryUsage();

        if (usageAfter < smallUsageThreshold) {
            memoryPool.free(hogQueryId, "test", hogBytes);
            // needed in case the input is blocked
            operator.getOutput();
            continue;
        }

        long usageDelta = usageAfter - usageBefore;
        if (!operator.needsInput()) {
            // The page could not be fully processed: the operator yielded
            yields++;
            assertFalse(operator.getOperatorContext().isWaitingForMemory().isDone());
            // the hash table has not grown yet
            assertEquals(capacityBefore, (long) getHashCapacity.apply(operator));
            // Per-entry cost of doubling the hash table:
            // BIGINT keys: groupIds and values double by hashCapacity; valuesByGroupId doubles by maxFill = hashCapacity / 0.75
            // other keys: groupAddressByHash, groupIdsByHash and rawHashByHashPosition double by hashCapacity;
            //             groupAddressByGroupId doubles by maxFill = hashCapacity / 0.75
            long bytesPerEntry = hashKeyType == BIGINT
                    ? (long) (Long.BYTES * 1.75 + Integer.BYTES)
                    : (long) (Long.BYTES * 1.75 + Integer.BYTES + Byte.BYTES);
            expectedExtraBytes = capacityBefore * bytesPerEntry + page.getRetainedSizeInBytes();
            // growth is at least the hash table estimate and at most that plus the aggregator's memory
            assertBetweenInclusive(usageDelta, expectedExtraBytes, expectedExtraBytes + additionalMemoryInBytes);
            // output is blocked as well
            assertNull(operator.getOutput());

            // release the pool, then let the operator resume through getOutput()
            memoryPool.free(hogQueryId, "test", hogBytes);
            produced = operator.getOutput();
            if (produced != null) {
                producedPages.add(produced);
            }
            assertTrue(operator.needsInput());
            // the hash table has now grown
            assertGreaterThan(getHashCapacity.apply(operator), capacityBefore);
            // the reservation made before the rehash should match the post-rehash usage within 1%
            long usageAfterRehash = operator.getOperatorContext().getDriverContext().getMemoryUsage();
            assertBetweenInclusive(usageAfterRehash * 1.0 / usageAfter, 0.99, 1.01);
            assertTrue(operator.needsInput());
        } else {
            // The page was absorbed without yielding: not blocked, no rehash,
            // so any growth must come from the aggregator alone
            assertTrue(operator.getOperatorContext().isWaitingForMemory().isDone());
            assertTrue(capacityBefore == getHashCapacity.apply(operator));
            assertLessThan(usageDelta, additionalMemoryInBytes);
            memoryPool.free(hogQueryId, "test", hogBytes);
        }
    }
    producedPages.addAll(finishOperator(operator));
    return new GroupByHashYieldResult(yields, expectedExtraBytes, producedPages);
}
Aggregations