Example 1 with MemoryPool

use of com.facebook.presto.memory.MemoryPool in project presto by prestodb.

the class PrestoSparkTaskExecutorFactory method doCreate.

public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> doCreate(int partitionId, int attemptNumber, SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor, Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources, PrestoSparkTaskInputs inputs, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector, Class<T> outputType) {
    PrestoSparkTaskDescriptor taskDescriptor = taskDescriptorJsonCodec.fromJson(serializedTaskDescriptor.getBytes());
    ImmutableMap.Builder<String, TokenAuthenticator> extraAuthenticators = ImmutableMap.builder();
    authenticatorProviders.forEach(provider -> extraAuthenticators.putAll(provider.getTokenAuthenticators()));
    Session session = taskDescriptor.getSession().toSession(sessionPropertyManager, taskDescriptor.getExtraCredentials(), extraAuthenticators.build());
    PlanFragment fragment = taskDescriptor.getFragment();
    StageId stageId = new StageId(session.getQueryId(), fragment.getId().getId());
    // Clear the cache if it does not have a broadcast table for the current stageId.
    // We only cache one hash table (HT) at a time; if the stageId changes, the old cached HT is dropped.
    prestoSparkBroadcastTableCacheManager.removeCachedTablesForStagesOtherThan(stageId);
    // TODO: include attemptId in taskId
    TaskId taskId = new TaskId(new StageExecutionId(stageId, 0), partitionId);
    List<TaskSource> taskSources = getTaskSources(serializedTaskSources);
    log.info("Task [%s] received %d splits.", taskId, taskSources.stream().mapToInt(taskSource -> taskSource.getSplits().size()).sum());
    OptionalLong totalSplitSize = computeAllSplitsSize(taskSources);
    if (totalSplitSize.isPresent()) {
        log.info("Total split size: %s bytes.", totalSplitSize.getAsLong());
    }
    // TODO: Remove this once we can display the plan on Spark UI.
    log.info(PlanPrinter.textPlanFragment(fragment, functionAndTypeManager, session, true));
    DataSize maxUserMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryMemoryPerNode().toBytes(), getQueryMaxMemoryPerNode(session).toBytes()), BYTE);
    DataSize maxTotalMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryTotalMemoryPerNode().toBytes(), getQueryMaxTotalMemoryPerNode(session).toBytes()), BYTE);
    DataSize maxBroadcastMemory = getSparkBroadcastJoinMaxMemoryOverride(session);
    if (maxBroadcastMemory == null) {
        maxBroadcastMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryBroadcastMemory().toBytes(), getQueryMaxBroadcastMemory(session).toBytes()), BYTE);
    }
    MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("spark-executor-memory-pool"), maxTotalMemory);
    SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(maxQuerySpillPerNode);
    QueryContext queryContext = new QueryContext(session.getQueryId(), maxUserMemory, maxTotalMemory, maxBroadcastMemory, maxRevocableMemory, memoryPool, new TestingGcMonitor(), notificationExecutor, yieldExecutor, maxQuerySpillPerNode, spillSpaceTracker, memoryReservationSummaryJsonCodec);
    queryContext.setVerboseExceededMemoryLimitErrorsEnabled(isVerboseExceededMemoryLimitErrorsEnabled(session));
    queryContext.setHeapDumpOnExceededMemoryLimitEnabled(isHeapDumpOnExceededMemoryLimitEnabled(session));
    String heapDumpFilePath = Paths.get(getHeapDumpFileDirectory(session), format("%s_%s.hprof", session.getQueryId().getId(), stageId.getId())).toString();
    queryContext.setHeapDumpFilePath(heapDumpFilePath);
    TaskStateMachine taskStateMachine = new TaskStateMachine(taskId, notificationExecutor);
    TaskContext taskContext = queryContext.addTaskContext(taskStateMachine, session, // Plan has to be retained only if verbose memory exceeded errors are requested
    isVerboseExceededMemoryLimitErrorsEnabled(session) ? Optional.of(fragment.getRoot()) : Optional.empty(), perOperatorCpuTimerEnabled, cpuTimerEnabled, perOperatorAllocationTrackingEnabled, allocationTrackingEnabled, false);
    final double memoryRevokingThreshold = getMemoryRevokingThreshold(session);
    final double memoryRevokingTarget = getMemoryRevokingTarget(session);
    checkArgument(memoryRevokingTarget <= memoryRevokingThreshold, "memoryRevokingTarget should be less than or equal to memoryRevokingThreshold, but got %s and %s respectively", memoryRevokingTarget, memoryRevokingThreshold);
    if (isSpillEnabled(session)) {
        memoryPool.addListener((pool, queryId, totalMemoryReservationBytes) -> {
            if (totalMemoryReservationBytes > queryContext.getPeakNodeTotalMemory()) {
                queryContext.setPeakNodeTotalMemory(totalMemoryReservationBytes);
            }
            if (totalMemoryReservationBytes > pool.getMaxBytes() * memoryRevokingThreshold && memoryRevokeRequestInProgress.compareAndSet(false, true)) {
                memoryRevocationExecutor.execute(() -> {
                    try {
                        AtomicLong remainingBytesToRevoke = new AtomicLong(totalMemoryReservationBytes - (long) (memoryRevokingTarget * pool.getMaxBytes()));
                        remainingBytesToRevoke.addAndGet(-MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(ImmutableList.of(taskContext), remainingBytesToRevoke.get()));
                        taskContext.accept(new VoidTraversingQueryContextVisitor<AtomicLong>() {

                            @Override
                            public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remainingBytesToRevoke) {
                                if (remainingBytesToRevoke.get() > 0) {
                                    long revokedBytes = operatorContext.requestMemoryRevoking();
                                    if (revokedBytes > 0) {
                                        memoryRevokePending.set(true);
                                        remainingBytesToRevoke.addAndGet(-revokedBytes);
                                    }
                                }
                                return null;
                            }
                        }, remainingBytesToRevoke);
                        memoryRevokeRequestInProgress.set(false);
                    } catch (Exception e) {
                        log.error(e, "Error requesting memory revoking");
                    }
                });
            }
            // Get the latest memory reservation info since it might have changed due to revoke
            long totalReservedMemory = pool.getQueryMemoryReservation(queryId) + pool.getQueryRevocableMemoryReservation(queryId);
            // If total memory usage is over maxTotalMemory and memory revoke request is not pending, fail the query with EXCEEDED_MEMORY_LIMIT error
            if (totalReservedMemory > maxTotalMemory.toBytes() && !memoryRevokeRequestInProgress.get() && !isMemoryRevokePending(taskContext)) {
                throw exceededLocalTotalMemoryLimit(maxTotalMemory, queryContext.getAdditionalFailureInfo(totalReservedMemory, 0) + format("Total reserved memory: %s, Total revocable memory: %s", succinctBytes(pool.getQueryMemoryReservation(queryId)), succinctBytes(pool.getQueryRevocableMemoryReservation(queryId))), isHeapDumpOnExceededMemoryLimitEnabled(session), Optional.ofNullable(heapDumpFilePath));
            }
        });
    }
    ImmutableMap.Builder<PlanNodeId, List<PrestoSparkShuffleInput>> shuffleInputs = ImmutableMap.builder();
    ImmutableMap.Builder<PlanNodeId, List<java.util.Iterator<PrestoSparkSerializedPage>>> pageInputs = ImmutableMap.builder();
    ImmutableMap.Builder<PlanNodeId, List<?>> broadcastInputs = ImmutableMap.builder();
    for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) {
        List<PrestoSparkShuffleInput> remoteSourceRowInputs = new ArrayList<>();
        List<java.util.Iterator<PrestoSparkSerializedPage>> remoteSourcePageInputs = new ArrayList<>();
        List<List<?>> broadcastInputsList = new ArrayList<>();
        for (PlanFragmentId sourceFragmentId : remoteSource.getSourceFragmentIds()) {
            Iterator<Tuple2<MutablePartitionId, PrestoSparkMutableRow>> shuffleInput = inputs.getShuffleInputs().get(sourceFragmentId.toString());
            Broadcast<?> broadcastInput = inputs.getBroadcastInputs().get(sourceFragmentId.toString());
            List<PrestoSparkSerializedPage> inMemoryInput = inputs.getInMemoryInputs().get(sourceFragmentId.toString());
            if (shuffleInput != null) {
                checkArgument(broadcastInput == null, "single remote source is not expected to accept different kinds of inputs");
                checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kinds of inputs");
                remoteSourceRowInputs.add(new PrestoSparkShuffleInput(sourceFragmentId.getId(), shuffleInput));
                continue;
            }
            if (broadcastInput != null) {
                checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kinds of inputs");
                // TODO: Enable NullifyingIterator once migrated to one task per JVM model
                // NullifyingIterator removes element from the list upon return
                // This allows GC to gradually reclaim memory
                // remoteSourcePageInputs.add(getNullifyingIterator(broadcastInput.value()));
                broadcastInputsList.add((List<?>) broadcastInput.value());
                continue;
            }
            if (inMemoryInput != null) {
                // for in-memory inputs, pages can be released incrementally to save memory
                remoteSourcePageInputs.add(getNullifyingIterator(inMemoryInput));
                continue;
            }
            throw new IllegalArgumentException("Input not found for sourceFragmentId: " + sourceFragmentId);
        }
        if (!remoteSourceRowInputs.isEmpty()) {
            shuffleInputs.put(remoteSource.getId(), remoteSourceRowInputs);
        }
        if (!remoteSourcePageInputs.isEmpty()) {
            pageInputs.put(remoteSource.getId(), remoteSourcePageInputs);
        }
        if (!broadcastInputsList.isEmpty()) {
            broadcastInputs.put(remoteSource.getId(), broadcastInputsList);
        }
    }
    OutputBufferMemoryManager memoryManager = new OutputBufferMemoryManager(sinkMaxBufferSize.toBytes(), () -> queryContext.getTaskContextByTaskId(taskId).localSystemMemoryContext(), notificationExecutor);
    Optional<OutputPartitioning> preDeterminedPartition = Optional.empty();
    if (fragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION)) {
        int partitionCount = getHashPartitionCount(session);
        preDeterminedPartition = Optional.of(new OutputPartitioning(new PreDeterminedPartitionFunction(partitionId % partitionCount, partitionCount), ImmutableList.of(), ImmutableList.of(), false, OptionalInt.empty()));
    }
    TempDataOperationContext tempDataOperationContext = new TempDataOperationContext(session.getSource(), session.getQueryId().getId(), session.getClientInfo(), Optional.of(session.getClientTags()), session.getIdentity());
    TempStorage tempStorage = tempStorageManager.getTempStorage(storageBasedBroadcastJoinStorage);
    Output<T> output = configureOutput(outputType, blockEncodingManager, memoryManager, getShuffleOutputTargetAverageRowSize(session), preDeterminedPartition, tempStorage, tempDataOperationContext, getStorageBasedBroadcastJoinWriteBufferSize(session));
    PrestoSparkOutputBuffer<?> outputBuffer = output.getOutputBuffer();
    LocalExecutionPlan localExecutionPlan = localExecutionPlanner.plan(taskContext, fragment.getRoot(), fragment.getPartitioningScheme(), fragment.getStageExecutionDescriptor(), fragment.getTableScanSchedulingOrder(), output.getOutputFactory(), new PrestoSparkRemoteSourceFactory(blockEncodingManager, shuffleInputs.build(), pageInputs.build(), broadcastInputs.build(), partitionId, shuffleStatsCollector, tempStorage, tempDataOperationContext, prestoSparkBroadcastTableCacheManager, stageId), taskDescriptor.getTableWriteInfo(), true);
    taskStateMachine.addStateChangeListener(state -> {
        if (state.isDone()) {
            outputBuffer.setNoMoreRows();
        }
    });
    PrestoSparkTaskExecution taskExecution = new PrestoSparkTaskExecution(taskStateMachine, taskContext, localExecutionPlan, taskExecutor, splitMonitor, notificationExecutor, memoryUpdateExecutor);
    taskExecution.start(taskSources);
    return new PrestoSparkTaskExecutor<>(taskContext, taskStateMachine, output.getOutputSupplier(), taskInfoCodec, taskInfoCollector, shuffleStatsCollector, executionExceptionFactory, output.getOutputBufferType(), outputBuffer, tempStorage, tempDataOperationContext);
}
Also used : StageId(com.facebook.presto.execution.StageId) ArrayList(java.util.ArrayList) PlanFragment(com.facebook.presto.sql.planner.PlanFragment) TaskStateMachine(com.facebook.presto.execution.TaskStateMachine) PlanNodeId(com.facebook.presto.spi.plan.PlanNodeId) RemoteSourceNode(com.facebook.presto.sql.planner.plan.RemoteSourceNode) DataSize(io.airlift.units.DataSize) OperatorContext(com.facebook.presto.operator.OperatorContext) OutputBufferMemoryManager(com.facebook.presto.execution.buffer.OutputBufferMemoryManager) ArrayList(java.util.ArrayList) List(java.util.List) ImmutableList(com.google.common.collect.ImmutableList) TempDataOperationContext(com.facebook.presto.spi.storage.TempDataOperationContext) PrestoSparkSessionProperties.getSparkBroadcastJoinMaxMemoryOverride(com.facebook.presto.spark.PrestoSparkSessionProperties.getSparkBroadcastJoinMaxMemoryOverride) PreDeterminedPartitionFunction(com.facebook.presto.spark.execution.PrestoSparkRowOutputOperator.PreDeterminedPartitionFunction) IPrestoSparkTaskExecutor(com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutor) ImmutableMap(com.google.common.collect.ImmutableMap) TokenAuthenticator(com.facebook.presto.spi.security.TokenAuthenticator) SerializedPrestoSparkTaskDescriptor(com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor) PrestoSparkTaskDescriptor(com.facebook.presto.spark.PrestoSparkTaskDescriptor) TaskId(com.facebook.presto.execution.TaskId) StageExecutionId(com.facebook.presto.execution.StageExecutionId) PrestoSparkUtils.getNullifyingIterator(com.facebook.presto.spark.util.PrestoSparkUtils.getNullifyingIterator) AbstractIterator(scala.collection.AbstractIterator) Iterator(scala.collection.Iterator) TestingGcMonitor(com.facebook.airlift.stats.TestingGcMonitor) PlanFragmentId(com.facebook.presto.sql.planner.plan.PlanFragmentId) MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId) SpillSpaceTracker(com.facebook.presto.spiller.SpillSpaceTracker) TaskContext(com.facebook.presto.operator.TaskContext) QueryContext(com.facebook.presto.memory.QueryContext) UncheckedIOException(java.io.UncheckedIOException) IOException(java.io.IOException) NoSuchElementException(java.util.NoSuchElementException) PrestoSparkUtils.toPrestoSparkSerializedPage(com.facebook.presto.spark.util.PrestoSparkUtils.toPrestoSparkSerializedPage) PrestoSparkSerializedPage(com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage) LocalExecutionPlan(com.facebook.presto.sql.planner.LocalExecutionPlanner.LocalExecutionPlan) AtomicLong(java.util.concurrent.atomic.AtomicLong) TempStorage(com.facebook.presto.spi.storage.TempStorage) Tuple2(scala.Tuple2) OptionalLong(java.util.OptionalLong) SerializedPrestoSparkTaskSource(com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskSource) TaskSource(com.facebook.presto.execution.TaskSource) OutputPartitioning(com.facebook.presto.sql.planner.OutputPartitioning) Session(com.facebook.presto.Session) MemoryPool(com.facebook.presto.memory.MemoryPool)
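The memoryPool.addListener callback above is what ties spill triggering to pool pressure on the Spark executor. Below is a minimal sketch of the same listener pattern in isolation, assuming only the MemoryPool calls that appear in doCreate; the class name, pool name, and threshold constant are illustrative, not Presto's.

import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.spi.memory.MemoryPoolId;
import io.airlift.units.DataSize;

import java.util.concurrent.atomic.AtomicBoolean;

import static io.airlift.units.DataSize.Unit.GIGABYTE;

public class MemoryPoolListenerSketch {
    // Illustrative threshold: consider revoking once 90% of the pool is reserved
    private static final double REVOKING_THRESHOLD = 0.9;

    public static MemoryPool createMonitoredPool() {
        MemoryPool pool = new MemoryPool(new MemoryPoolId("sketch-pool"), new DataSize(1, GIGABYTE));
        AtomicBoolean revokeRequestInProgress = new AtomicBoolean();
        // Same listener shape as in doCreate: (pool, queryId, totalMemoryReservationBytes)
        pool.addListener((memoryPool, queryId, totalMemoryReservationBytes) -> {
            if (totalMemoryReservationBytes > memoryPool.getMaxBytes() * REVOKING_THRESHOLD
                    && revokeRequestInProgress.compareAndSet(false, true)) {
                // The real code schedules asynchronous memory revocation across the
                // task's operator contexts here; this sketch only resets the flag
                revokeRequestInProgress.set(false);
            }
        });
        return pool;
    }
}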

Example 2 with MemoryPool

use of com.facebook.presto.memory.MemoryPool in project presto by prestodb.

the class GroupByHashYieldAssertion method finishOperatorWithYieldingGroupByHash.

/**
 * @param operatorFactory creates an Operator that should directly or indirectly contain GroupByHash
 * @param getHashCapacity returns the hash table capacity for the input operator
 * @param additionalMemoryInBytes the memory used in addition to the GroupByHash in the operator (e.g., aggregator)
 */
public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<Page> input, Type hashKeyType, OperatorFactory operatorFactory, Function<Operator, Integer> getHashCapacity, long additionalMemoryInBytes) {
    assertLessThan(additionalMemoryInBytes, 1L << 21, "additionalMemoryInBytes should be a relatively small number");
    List<Page> result = new LinkedList<>();
    // mock an adjustable memory pool
    QueryId queryId1 = new QueryId("test_query1");
    QueryId queryId2 = new QueryId("test_query2");
    MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(1, GIGABYTE));
    QueryContext queryContext = new QueryContext(queryId2, new DataSize(512, MEGABYTE), new DataSize(1024, MEGABYTE), new DataSize(512, MEGABYTE), new DataSize(1, GIGABYTE), memoryPool, new TestingGcMonitor(), EXECUTOR, SCHEDULED_EXECUTOR, new DataSize(512, MEGABYTE), new SpillSpaceTracker(new DataSize(512, MEGABYTE)), listJsonCodec(TaskMemoryReservationSummary.class));
    DriverContext driverContext = createTaskContext(queryContext, EXECUTOR, TEST_SESSION).addPipelineContext(0, true, true, false).addDriverContext();
    Operator operator = operatorFactory.createOperator(driverContext);
    // run operator
    int yieldCount = 0;
    long expectedReservedExtraBytes = 0;
    for (Page page : input) {
        // unblocked
        assertTrue(operator.needsInput());
        // saturate the pool, leaving only a tiny amount of memory free
        long reservedMemoryInBytes = memoryPool.getFreeBytes() - additionalMemoryInBytes;
        memoryPool.reserve(queryId1, "test", reservedMemoryInBytes);
        long oldMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
        int oldCapacity = getHashCapacity.apply(operator);
        // add a page and verify different behaviors
        operator.addInput(page);
        // get output to consume the input
        Page output = operator.getOutput();
        if (output != null) {
            result.add(output);
        }
        long newMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
        // between rehash and memory used by aggregator
        if (newMemoryUsage < new DataSize(4, MEGABYTE).toBytes()) {
            // free the pool for the next iteration
            memoryPool.free(queryId1, "test", reservedMemoryInBytes);
            // this is required in case the input is blocked
            operator.getOutput();
            continue;
        }
        long actualIncreasedMemory = newMemoryUsage - oldMemoryUsage;
        if (operator.needsInput()) {
            // We have successfully added a page
            // Assert we are not blocked
            assertTrue(operator.getOperatorContext().isWaitingForMemory().isDone());
            // assert the hash capacity is not changed; otherwise, we should have yielded
            assertTrue(oldCapacity == getHashCapacity.apply(operator));
            // We are not going to rehash; therefore, assert the memory increase only comes from the aggregator
            assertLessThan(actualIncreasedMemory, additionalMemoryInBytes);
            // free the pool for the next iteration
            memoryPool.free(queryId1, "test", reservedMemoryInBytes);
        } else {
            // We failed to finish the page processing i.e. we yielded
            yieldCount++;
            // Assert we are blocked
            assertFalse(operator.getOperatorContext().isWaitingForMemory().isDone());
            // Hash table capacity should not change
            assertEquals(oldCapacity, (long) getHashCapacity.apply(operator));
            // Increased memory is no smaller than the hash table size and no greater than the hash table size + the memory used by aggregator
            if (hashKeyType == BIGINT) {
                // groupIds and values double by hashCapacity; while valuesByGroupId double by maxFill = hashCapacity / 0.75
                expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES) + page.getRetainedSizeInBytes();
            } else {
                // groupAddressByHash, groupIdsByHash, and rawHashByHashPosition double by hashCapacity; while groupAddressByGroupId double by maxFill = hashCapacity / 0.75
                expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES + Byte.BYTES) + page.getRetainedSizeInBytes();
            }
            assertBetweenInclusive(actualIncreasedMemory, expectedReservedExtraBytes, expectedReservedExtraBytes + additionalMemoryInBytes);
            // Output should be blocked as well
            assertNull(operator.getOutput());
            // Free the pool to unblock
            memoryPool.free(queryId1, "test", reservedMemoryInBytes);
            // Trigger a process through getOutput() or needsInput()
            output = operator.getOutput();
            if (output != null) {
                result.add(output);
            }
            assertTrue(operator.needsInput());
            // Hash table capacity has increased
            assertGreaterThan(getHashCapacity.apply(operator), oldCapacity);
            // Assert the estimated reserved memory before rehash is very close to the one after rehash
            long rehashedMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
            assertBetweenInclusive(rehashedMemoryUsage * 1.0 / newMemoryUsage, 0.99, 1.01);
            // unblocked
            assertTrue(operator.needsInput());
        }
    }
    result.addAll(finishOperator(operator));
    return new GroupByHashYieldResult(yieldCount, expectedReservedExtraBytes, result);
}
Also used : OperatorAssertion.finishOperator(com.facebook.presto.operator.OperatorAssertion.finishOperator) SpillSpaceTracker(com.facebook.presto.spiller.SpillSpaceTracker) QueryId(com.facebook.presto.spi.QueryId) Page(com.facebook.presto.common.Page) QueryContext(com.facebook.presto.memory.QueryContext) LinkedList(java.util.LinkedList) DataSize(io.airlift.units.DataSize) TestingGcMonitor(com.facebook.airlift.stats.TestingGcMonitor) MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId) MemoryPool(com.facebook.presto.memory.MemoryPool)
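The core trick in this helper is saturating the pool from a second, filler query so the operator under test is left with almost no headroom, then freeing that reservation before the next iteration. Here is a stripped-down sketch of that reserve/free cycle, using only the MemoryPool calls shown above; the query name, reservation tag, and slack size are illustrative.

import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.memory.MemoryPoolId;
import io.airlift.units.DataSize;

import static io.airlift.units.DataSize.Unit.GIGABYTE;

public class PoolSaturationSketch {
    public static void main(String[] args) {
        MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(1, GIGABYTE));
        QueryId fillerQuery = new QueryId("filler_query");
        long slackBytes = 1024; // leave only a tiny amount of memory free

        // reserve nearly everything on behalf of a different query ...
        long reservedBytes = memoryPool.getFreeBytes() - slackBytes;
        memoryPool.reserve(fillerQuery, "test", reservedBytes);

        // ... exercise the operator under memory pressure here ...

        // ... then release the filler reservation so the next iteration starts with a full pool
        memoryPool.free(fillerQuery, "test", reservedBytes);
    }
}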

Example 3 with MemoryPool

use of com.facebook.presto.memory.MemoryPool in project presto by prestodb.

the class MemoryLocalQueryRunner method execute.

public List<Page> execute(@Language("SQL") String query) {
    MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(2, GIGABYTE));
    SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(new DataSize(1, GIGABYTE));
    QueryContext queryContext = new QueryContext(new QueryId("test"), new DataSize(1, GIGABYTE), new DataSize(2, GIGABYTE), new DataSize(1, GIGABYTE), new DataSize(2, GIGABYTE), memoryPool, new TestingGcMonitor(), localQueryRunner.getExecutor(), localQueryRunner.getScheduler(), new DataSize(4, GIGABYTE), spillSpaceTracker, listJsonCodec(TaskMemoryReservationSummary.class));
    TaskContext taskContext = queryContext.addTaskContext(new TaskStateMachine(new TaskId("query", 0, 0, 0), localQueryRunner.getExecutor()), localQueryRunner.getDefaultSession(), Optional.empty(), false, false, false, false, false);
    // Collect output pages directly to avoid copying results and affecting benchmark results
    ImmutableList.Builder<Page> output = ImmutableList.builder();
    List<Driver> drivers = localQueryRunner.createDrivers(query, new PageConsumerOperator.PageConsumerOutputFactory(types -> output::add), taskContext);
    boolean done = false;
    while (!done) {
        boolean processed = false;
        for (Driver driver : drivers) {
            if (!driver.isFinished()) {
                driver.process();
                processed = true;
            }
        }
        done = !processed;
    }
    return output.build();
}
Also used : TaskMemoryReservationSummary(com.facebook.presto.operator.TaskMemoryReservationSummary) Page(com.facebook.presto.common.Page) PageConsumerOperator(com.facebook.presto.testing.PageConsumerOperator) JsonCodec.listJsonCodec(com.facebook.airlift.json.JsonCodec.listJsonCodec) MemoryConnectorFactory(com.facebook.presto.plugin.memory.MemoryConnectorFactory) SpillSpaceTracker(com.facebook.presto.spiller.SpillSpaceTracker) GIGABYTE(io.airlift.units.DataSize.Unit.GIGABYTE) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) QualifiedObjectName(com.facebook.presto.common.QualifiedObjectName) QueryContext(com.facebook.presto.memory.QueryContext) TableHandle(com.facebook.presto.spi.TableHandle) MemoryPool(com.facebook.presto.memory.MemoryPool) LocalQueryRunner(com.facebook.presto.testing.LocalQueryRunner) TaskContext(com.facebook.presto.operator.TaskContext) ImmutableMap(com.google.common.collect.ImmutableMap) Language(org.intellij.lang.annotations.Language) Session(com.facebook.presto.Session) TpchConnectorFactory(com.facebook.presto.tpch.TpchConnectorFactory) TestingGcMonitor(com.facebook.airlift.stats.TestingGcMonitor) Driver(com.facebook.presto.operator.Driver) TestingSession.testSessionBuilder(com.facebook.presto.testing.TestingSession.testSessionBuilder) Plugin(com.facebook.presto.spi.Plugin) MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId) DataSize(io.airlift.units.DataSize) List(java.util.List) TaskId(com.facebook.presto.execution.TaskId) QueryId(com.facebook.presto.spi.QueryId) Optional(java.util.Optional) Assert.assertTrue(org.testng.Assert.assertTrue) TaskStateMachine(com.facebook.presto.execution.TaskStateMachine) Metadata(com.facebook.presto.metadata.Metadata)
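The while-loop at the end of execute is the usual way to drive a set of Driver objects to completion without a task executor: keep processing any unfinished driver until a full pass makes no progress. The same pattern as a small helper, shown here only as a sketch; the class and method names are mine, not Presto's.

import com.facebook.presto.operator.Driver;

import java.util.List;

public final class DriverLoopSketch {
    private DriverLoopSketch() {}

    // Process drivers round-robin until none of them makes progress anymore
    public static void runToCompletion(List<Driver> drivers) {
        boolean done = false;
        while (!done) {
            boolean processed = false;
            for (Driver driver : drivers) {
                if (!driver.isFinished()) {
                    driver.process();
                    processed = true;
                }
            }
            done = !processed;
        }
    }
}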

Example 4 with MemoryPool

use of com.facebook.presto.memory.MemoryPool in project presto by prestodb.

the class TestSqlTask method createInitialTask.

public SqlTask createInitialTask() {
    TaskId taskId = new TaskId("query", 0, 0, nextTaskId.incrementAndGet());
    URI location = URI.create("fake://task/" + taskId);
    QueryContext queryContext = new QueryContext(new QueryId("query"), new DataSize(1, MEGABYTE), new DataSize(2, MEGABYTE), new DataSize(1, MEGABYTE), new DataSize(1, GIGABYTE), new MemoryPool(new MemoryPoolId("test"), new DataSize(1, GIGABYTE)), new TestingGcMonitor(), taskNotificationExecutor, driverYieldExecutor, new DataSize(1, MEGABYTE), new SpillSpaceTracker(new DataSize(1, GIGABYTE)), listJsonCodec(TaskMemoryReservationSummary.class));
    queryContext.addTaskContext(new TaskStateMachine(taskId, taskNotificationExecutor), testSessionBuilder().build(), Optional.of(PLAN_FRAGMENT.getRoot()), false, false, false, false, false);
    return createSqlTask(taskId, location, "fake", queryContext, sqlTaskExecutionFactory, new MockExchangeClientSupplier(), taskNotificationExecutor, Functions.identity(), new DataSize(32, MEGABYTE), new CounterStat(), new SpoolingOutputBufferFactory(new FeaturesConfig()));
}
Also used : TaskMemoryReservationSummary(com.facebook.presto.operator.TaskMemoryReservationSummary) SpillSpaceTracker(com.facebook.presto.spiller.SpillSpaceTracker) MockExchangeClientSupplier(com.facebook.presto.execution.TestSqlTaskManager.MockExchangeClientSupplier) CounterStat(com.facebook.airlift.stats.CounterStat) FeaturesConfig(com.facebook.presto.sql.analyzer.FeaturesConfig) QueryId(com.facebook.presto.spi.QueryId) QueryContext(com.facebook.presto.memory.QueryContext) URI(java.net.URI) DataSize(io.airlift.units.DataSize) TestingGcMonitor(com.facebook.airlift.stats.TestingGcMonitor) SpoolingOutputBufferFactory(com.facebook.presto.execution.buffer.SpoolingOutputBufferFactory) MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId) MemoryPool(com.facebook.presto.memory.MemoryPool)
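Each task created here gets its own small MemoryPool, so reservations from other tests cannot interfere with it. The following is a minimal sketch of creating such a per-test pool and spill tracker and checking that the pool starts out completely free, using only constructors and accessors that appear in these examples; the sanity check itself is illustrative.

import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.spi.memory.MemoryPoolId;
import com.facebook.presto.spiller.SpillSpaceTracker;
import io.airlift.units.DataSize;

import static io.airlift.units.DataSize.Unit.GIGABYTE;

public class PerTestMemoryPoolSketch {
    public static void main(String[] args) {
        MemoryPool pool = new MemoryPool(new MemoryPoolId("test"), new DataSize(1, GIGABYTE));
        SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(new DataSize(1, GIGABYTE));

        // A freshly created pool has no reservations, so free space equals the maximum
        if (pool.getFreeBytes() != pool.getMaxBytes()) {
            throw new AssertionError("expected an empty pool");
        }
        // The pool and spill tracker would then be handed to the QueryContext
        // constructor, as in createInitialTask above
    }
}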

Example 5 with MemoryPool

use of com.facebook.presto.memory.MemoryPool in project presto by prestodb.

the class TestMemoryRevokingScheduler method testQueryMemoryRevoking.

@Test
public void testQueryMemoryRevoking() throws Exception {
    // The various tasks created here use a small amount of system memory independent of what's set explicitly
    // in this test. Triggering spilling based on differences of thousands of bytes rather than hundreds
    // makes the test resilient to any noise that this creates.
    // There can still be a race condition where some of these allocations are made when the total memory is above
    // the spill threshold, but in revokeMemory() some memory is reduced between when we get the total memory usage
    // and when we get the task memory usage.  This can cause some extra spilling.
    // To prevent flakiness in the test, we reset revoke memory requested for all operators, even if only one spilled.
    QueryId queryId = new QueryId("query");
    // use a larger memory pool so that we don't trigger spilling due to filling the memory pool
    MemoryPool queryLimitMemoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(100, GIGABYTE));
    SqlTask sqlTask1 = newSqlTask(queryId, queryLimitMemoryPool);
    TestOperatorContext operatorContext11 = createTestingOperatorContexts(sqlTask1, "operator11");
    TestOperatorContext operatorContext12 = createTestingOperatorContexts(sqlTask1, "operator12");
    SqlTask sqlTask2 = newSqlTask(queryId, queryLimitMemoryPool);
    TestOperatorContext operatorContext2 = createTestingOperatorContexts(sqlTask2, "operator2");
    allOperatorContexts = ImmutableSet.of(operatorContext11, operatorContext12, operatorContext2);
    List<SqlTask> tasks = ImmutableList.of(sqlTask1, sqlTask2);
    MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(singletonList(queryLimitMemoryPool), () -> tasks, queryContexts::get, 1.0, 1.0, ORDER_BY_REVOCABLE_BYTES, true);
    try {
        scheduler.start();
        assertMemoryRevokingNotRequested();
        operatorContext11.localRevocableMemoryContext().setBytes(150_000);
        operatorContext2.localRevocableMemoryContext().setBytes(100_000);
        // at this point, Task1 = 150k total bytes, Task2 = 100k total bytes
        // this ensures that we are waiting for the memory revocation listener and not using polling-based revoking
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingNotRequested();
        operatorContext12.localRevocableMemoryContext().setBytes(300_000);
        // at this point, Task1 =  450k total bytes, Task2 = 100k total bytes
        scheduler.awaitAsynchronousCallbacksRun();
        // only operator11 should revoke since we need to revoke only 50k bytes
        // (task1 + task2) - limit => (450k + 100k) - 500k = 50k bytes to revoke
        assertMemoryRevokingRequestedFor(operatorContext11);
        // revoke all bytes in operator11
        operatorContext11.localRevocableMemoryContext().setBytes(0);
        // at this point, Task1 = 300k total bytes, Task2 = 100k total bytes
        scheduler.awaitAsynchronousCallbacksRun();
        operatorContext11.resetMemoryRevokingRequested();
        operatorContext12.resetMemoryRevokingRequested();
        operatorContext2.resetMemoryRevokingRequested();
        assertMemoryRevokingNotRequested();
        operatorContext11.localRevocableMemoryContext().setBytes(20_000);
        // at this point, Task1 = 320,000 total bytes (oc11 - 20k, oc12 - 300k), Task2 = 100k total bytes
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingNotRequested();
        operatorContext2.localSystemMemoryContext().setBytes(150_000);
        // at this point, Task1 = 320K total bytes, Task2 = 250K total bytes
        // both operator11 and operator12 are revoking since we revoke in order of operator creation within the task until we are below the memory revoking threshold
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingRequestedFor(operatorContext11, operatorContext12);
        operatorContext11.localRevocableMemoryContext().setBytes(0);
        operatorContext12.localRevocableMemoryContext().setBytes(0);
        scheduler.awaitAsynchronousCallbacksRun();
        operatorContext11.resetMemoryRevokingRequested();
        operatorContext12.resetMemoryRevokingRequested();
        operatorContext2.resetMemoryRevokingRequested();
        assertMemoryRevokingNotRequested();
        operatorContext11.localRevocableMemoryContext().setBytes(50_000);
        operatorContext12.localRevocableMemoryContext().setBytes(50_000);
        operatorContext2.localSystemMemoryContext().setBytes(150_000);
        operatorContext2.localRevocableMemoryContext().setBytes(150_000);
        scheduler.awaitAsynchronousCallbacksRun();
        // no need to revoke
        assertMemoryRevokingNotRequested();
        // at this point, Task1 = 75k total bytes, Task2 = 300k total bytes (150k revocable, 150k system)
        operatorContext12.localUserMemoryContext().setBytes(300_000);
        // at this point, Task1 = 400K total bytes (100k revocable, 300k user), Task2 = 300k total bytes (150k revocable, 150k system)
        scheduler.awaitAsynchronousCallbacksRun();
        assertMemoryRevokingRequestedFor(operatorContext2, operatorContext11);
    } finally {
        scheduler.stop();
    }
}
Also used : SqlTask.createSqlTask(com.facebook.presto.execution.SqlTask.createSqlTask) QueryId(com.facebook.presto.spi.QueryId) DataSize(io.airlift.units.DataSize) MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId) MemoryPool(com.facebook.presto.memory.MemoryPool) Test(org.testng.annotations.Test)
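The expectations in this test reduce to simple arithmetic: once the tasks' combined reservation exceeds the limit, operators are asked to revoke, in order, until the excess is covered. Below is a standalone sketch of that selection rule under that assumed simplification; this is plain Java, not the actual MemoryRevokingScheduler implementation, and the 500k limit is taken from the test's own comments.

import java.util.ArrayList;
import java.util.List;

public final class RevocationSelectionSketch {
    private RevocationSelectionSketch() {}

    // Given per-operator revocable bytes (in visit order), the total memory in use,
    // and the memory limit, return the indexes of operators asked to revoke.
    public static List<Integer> operatorsToRevoke(long[] revocableBytes, long totalBytes, long limitBytes) {
        List<Integer> selected = new ArrayList<>();
        long remainingToRevoke = totalBytes - limitBytes;
        for (int i = 0; i < revocableBytes.length && remainingToRevoke > 0; i++) {
            if (revocableBytes[i] > 0) {
                selected.add(i);
                remainingToRevoke -= revocableBytes[i];
            }
        }
        return selected;
    }

    public static void main(String[] args) {
        // Mirrors the first assertion above: operator11 = 150k, operator12 = 300k, operator2 = 100k
        // revocable bytes; 550k total against a 500k limit leaves 50k to revoke, and operator11's
        // 150k alone covers it, so only index 0 (operator11) is selected.
        System.out.println(operatorsToRevoke(new long[] {150_000, 300_000, 100_000}, 550_000, 500_000));
    }
}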

Aggregations

MemoryPool (com.facebook.presto.memory.MemoryPool)12 DataSize (io.airlift.units.DataSize)10 QueryId (com.facebook.presto.spi.QueryId)9 MemoryPoolId (com.facebook.presto.spi.memory.MemoryPoolId)9 QueryContext (com.facebook.presto.memory.QueryContext)6 TestingGcMonitor (com.facebook.airlift.stats.TestingGcMonitor)5 SpillSpaceTracker (com.facebook.presto.spiller.SpillSpaceTracker)5 TaskContext (com.facebook.presto.operator.TaskContext)4 Session (com.facebook.presto.Session)3 SqlTask.createSqlTask (com.facebook.presto.execution.SqlTask.createSqlTask)3 TaskId (com.facebook.presto.execution.TaskId)3 TaskStateMachine (com.facebook.presto.execution.TaskStateMachine)3 OperatorContext (com.facebook.presto.operator.OperatorContext)3 Page (com.facebook.presto.common.Page)2 TaskMemoryReservationSummary (com.facebook.presto.operator.TaskMemoryReservationSummary)2 ImmutableList (com.google.common.collect.ImmutableList)2 ImmutableMap (com.google.common.collect.ImmutableMap)2 List (java.util.List)2 Test (org.testng.annotations.Test)2 JsonCodec.listJsonCodec (com.facebook.airlift.json.JsonCodec.listJsonCodec)1