Search in sources:

Example 1 with MemoryPoolId

Use of com.facebook.presto.spi.memory.MemoryPoolId in the presto project by prestodb.

In class TestMemoryPools, method testBlocking.

/**
 * Verifies that a query blocks when its memory pool is exhausted and resumes once memory is released.
 *
 * <p>A fake query first reserves the whole 10MB pool. The real query's drivers are then stepped until
 * they report waiting for memory. After the fake reservation is freed, every scheduling round must
 * make progress until all drivers finish.
 */
@Test
public void testBlocking() throws Exception {
    Session session = testSessionBuilder()
            .setCatalog("tpch")
            .setSchema("tiny")
            .setSystemProperty("task_default_concurrency", "1")
            .build();
    LocalQueryRunner runner = queryRunnerWithInitialTransaction(session);
    // register the tpch connector so the query below can resolve its tables
    runner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of());

    // a 10MB user pool that a fake query fills completely, leaving nothing for the real query
    MemoryPool userPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(10, MEGABYTE));
    QueryId blockerQueryId = new QueryId("fake");
    assertTrue(userPool.tryReserve(blockerQueryId, TEN_MEGABYTES));
    MemoryPool systemPool = new MemoryPool(new MemoryPoolId("testSystem"), new DataSize(10, MEGABYTE));
    QueryContext queryContext = new QueryContext(new QueryId("query"), new DataSize(10, MEGABYTE), userPool, systemPool, runner.getExecutor());

    // every produced page is simply dropped
    OutputFactory discardingOutput = new PageConsumerOutputFactory(types -> (page -> {
    }));
    TaskContext taskContext = createTaskContext(queryContext, runner.getExecutor(), session);
    List<Driver> drivers = runner.createDrivers("SELECT COUNT(*) FROM orders JOIN lineitem USING (orderkey)", discardingOutput, taskContext);

    // phase 1: step the drivers until they report waiting on the exhausted pool
    while (!isWaitingForMemory(drivers)) {
        drivers.forEach(driver -> driver.process());
    }

    // a blocked driver must not have completed
    drivers.forEach(driver -> assertFalse(driver.isFinished()));
    assertTrue(userPool.getFreeBytes() <= 0);

    // phase 2: release the fake reservation; from now on every round must make progress
    userPool.free(blockerQueryId, TEN_MEGABYTES);
    do {
        assertFalse(isWaitingForMemory(drivers));
        boolean anyUnblocked = false;
        for (Driver driver : drivers) {
            ListenableFuture<?> future = driver.process();
            anyUnblocked |= future.isDone();
        }
        // the query must never stall once memory is available again
        assertTrue(anyUnblocked);
    } while (!drivers.stream().allMatch(Driver::isFinished));
}
Also used : LocalQueryRunner(com.facebook.presto.testing.LocalQueryRunner) PageConsumerOutputFactory(com.facebook.presto.testing.PageConsumerOperator.PageConsumerOutputFactory) TaskContext(com.facebook.presto.operator.TaskContext) ListenableFuture(com.google.common.util.concurrent.ListenableFuture) TestingTaskContext.createTaskContext(com.facebook.presto.testing.TestingTaskContext.createTaskContext) ImmutableMap(com.google.common.collect.ImmutableMap) Session(com.facebook.presto.Session) MEGABYTE(io.airlift.units.DataSize.Unit.MEGABYTE) TpchConnectorFactory(com.facebook.presto.tpch.TpchConnectorFactory) OutputFactory(com.facebook.presto.operator.OutputFactory) Test(org.testng.annotations.Test) Driver(com.facebook.presto.operator.Driver) TestingSession.testSessionBuilder(com.facebook.presto.testing.TestingSession.testSessionBuilder) MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId) DataSize(io.airlift.units.DataSize) List(java.util.List) QueryId(com.facebook.presto.spi.QueryId) Assert.assertTrue(org.testng.Assert.assertTrue) LocalQueryRunner.queryRunnerWithInitialTransaction(com.facebook.presto.testing.LocalQueryRunner.queryRunnerWithInitialTransaction) OperatorContext(com.facebook.presto.operator.OperatorContext) Assert.assertFalse(org.testng.Assert.assertFalse) TpchConnectorFactory(com.facebook.presto.tpch.TpchConnectorFactory) TaskContext(com.facebook.presto.operator.TaskContext) TestingTaskContext.createTaskContext(com.facebook.presto.testing.TestingTaskContext.createTaskContext) PageConsumerOutputFactory(com.facebook.presto.testing.PageConsumerOperator.PageConsumerOutputFactory) QueryId(com.facebook.presto.spi.QueryId) Driver(com.facebook.presto.operator.Driver) LocalQueryRunner(com.facebook.presto.testing.LocalQueryRunner) DataSize(io.airlift.units.DataSize) PageConsumerOutputFactory(com.facebook.presto.testing.PageConsumerOperator.PageConsumerOutputFactory) OutputFactory(com.facebook.presto.operator.OutputFactory) 
MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId) Session(com.facebook.presto.Session) Test(org.testng.annotations.Test)

Example 2 with MemoryPoolId

Use of com.facebook.presto.spi.memory.MemoryPoolId in the presto project by prestodb.

In class ClusterMemoryManager, method process.

/**
 * Runs one round of cluster-wide memory and CPU enforcement over the given queries.
 *
 * <p>Does nothing unless the manager is enabled. Otherwise it performs these steps in order:
 * <ol>
 *   <li>Fails queries that exceed their memory limit. The limit is the smaller of the configured
 *       maximum and the session maximum. Queries that requested resource overcommit are instead
 *       failed only when the cluster is out of memory.</li>
 *   <li>Publishes the summed total memory reservation of all queries.</li>
 *   <li>If kill-on-out-of-memory is enabled and the cluster has been out of memory for longer than
 *       the configured delay, kills the general-pool query with the largest reservation. This only
 *       happens when no query was already killed in this round and the previously killed query no
 *       longer holds memory in the general pool.</li>
 *   <li>Updates pool assignments and node state.</li>
 *   <li>Fails queries whose CPU time exceeds the smaller of the configured and session CPU limits.</li>
 * </ol>
 *
 * @param queries the queries to inspect; iterated several times
 */
public synchronized void process(Iterable<QueryExecution> queries) {
    if (!enabled) {
        return;
    }
    boolean outOfMemory = isClusterOutOfMemory();
    if (!outOfMemory) {
        lastTimeNotOutOfMemory = System.nanoTime();
    }
    boolean queryKilled = false;
    long totalBytes = 0;
    for (QueryExecution query : queries) {
        long bytes = query.getTotalMemoryReservation();
        DataSize sessionMaxQueryMemory = getQueryMaxMemory(query.getSession());
        long queryMemoryLimit = Math.min(maxQueryMemory.toBytes(), sessionMaxQueryMemory.toBytes());
        totalBytes += bytes;
        if (resourceOvercommit(query.getSession()) && outOfMemory) {
            // If a query has requested resource overcommit, only kill it if the cluster has run out of memory
            DataSize memory = succinctBytes(bytes);
            query.fail(new PrestoException(CLUSTER_OUT_OF_MEMORY, format("The cluster is out of memory and %s=true, so this query was killed. It was using %s of memory", RESOURCE_OVERCOMMIT, memory)));
            queryKilled = true;
        }
        if (!resourceOvercommit(query.getSession()) && bytes > queryMemoryLimit) {
            DataSize maxMemory = succinctBytes(queryMemoryLimit);
            query.fail(exceededGlobalLimit(maxMemory));
            queryKilled = true;
        }
    }
    clusterMemoryUsageBytes.set(totalBytes);
    if (killOnOutOfMemory) {
        boolean shouldKillQuery = nanosSince(lastTimeNotOutOfMemory).compareTo(killOnOutOfMemoryDelay) > 0 && outOfMemory;
        boolean lastKilledQueryIsGone = (lastKilledQuery == null);
        if (!lastKilledQueryIsGone) {
            ClusterMemoryPool generalPool = pools.get(GENERAL_POOL);
            if (generalPool != null) {
                // The previous victim is only gone once the general pool no longer holds a reservation
                // for it. Without the negation a still-running victim would be treated as gone (allowing
                // a second kill too early), and a finished one would block all future kills.
                lastKilledQueryIsGone = !generalPool.getQueryMemoryReservations().containsKey(lastKilledQuery);
            }
        }
        if (shouldKillQuery && lastKilledQueryIsGone && !queryKilled) {
            // Kill the biggest query in the general pool
            QueryExecution biggestQuery = null;
            long maxMemory = -1;
            for (QueryExecution query : queries) {
                long bytesUsed = query.getTotalMemoryReservation();
                if (bytesUsed > maxMemory && query.getMemoryPool().getId().equals(GENERAL_POOL)) {
                    biggestQuery = query;
                    maxMemory = bytesUsed;
                }
            }
            if (biggestQuery != null) {
                biggestQuery.fail(new PrestoException(CLUSTER_OUT_OF_MEMORY, "The cluster is out of memory, and your query was killed. Please try again in a few minutes."));
                queriesKilledDueToOutOfMemory.incrementAndGet();
                lastKilledQuery = biggestQuery.getQueryId();
            }
        }
    }
    // number of queries currently assigned to each memory pool
    Map<MemoryPoolId, Integer> countByPool = new HashMap<>();
    for (QueryExecution query : queries) {
        countByPool.merge(query.getMemoryPool().getId(), 1, Integer::sum);
    }
    updatePools(countByPool);
    updateNodes(updateAssignments(queries));
    // check if CPU usage is over limit
    for (QueryExecution query : queries) {
        Duration cpuTime = query.getTotalCpuTime();
        Duration sessionLimit = getQueryMaxCpuTime(query.getSession());
        Duration limit = maxQueryCpuTime.compareTo(sessionLimit) < 0 ? maxQueryCpuTime : sessionLimit;
        if (cpuTime.compareTo(limit) > 0) {
            query.fail(new ExceededCpuLimitException(limit));
        }
    }
}
Also used : HashMap(java.util.HashMap) DataSize(io.airlift.units.DataSize) PrestoException(com.facebook.presto.spi.PrestoException) Duration(io.airlift.units.Duration) ExceededCpuLimitException(com.facebook.presto.ExceededCpuLimitException) QueryExecution(com.facebook.presto.execution.QueryExecution) MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId)

Example 3 with MemoryPoolId

Use of com.facebook.presto.spi.memory.MemoryPoolId in the presto project by prestodb.

In class PrestoSparkQueryExecutionFactory, method createQueryInfo.

/**
 * Builds a {@link QueryInfo} snapshot for a query executed on Spark.
 *
 * <p>Memory statistics are aggregated over the tasks of the latest execution attempt of every
 * stage under {@code rootStage}:
 * <ul>
 *   <li>The query-level "peak" user and total reservations are the sums of the per-task peaks.</li>
 *   <li>The per-task and per-node peaks are maxima.</li>
 *   <li>The peak running task count is simply the total number of tasks.</li>
 * </ul>
 * The memory pool id is the fixed value {@code "spark-memory-pool"}, and the self URI is a
 * placeholder under {@code http://fake.invalid}.
 *
 * @throws IllegalArgumentException if {@code queryState} is FAILED but no failure info is given
 */
private static QueryInfo createQueryInfo(Session session, String query, QueryState queryState, Optional<PlanAndMore> planAndMore, Optional<String> sparkQueueName, Optional<ExecutionFailureInfo> failureInfo, QueryStateTimer queryStateTimer, Optional<StageInfo> rootStage, WarningCollector warningCollector) {
    checkArgument(failureInfo.isPresent() || queryState != FAILED, "unexpected query state: %s", queryState);

    int taskCount = 0;
    long userMemorySum = 0;
    long totalMemorySum = 0;
    long maxTaskUserMemory = 0;
    long maxTaskTotalMemory = 0;
    long maxNodeTotalMemory = 0;
    for (StageInfo stage : getAllStages(rootStage)) {
        StageExecutionInfo latestAttempt = stage.getLatestAttemptExecutionInfo();
        for (TaskInfo task : latestAttempt.getTasks()) {
            // Spark gives no visibility into task concurrency, so every task is counted
            // as if all of them were running at the same time
            taskCount++;
            long userPeak = task.getStats().getPeakUserMemoryInBytes();
            long totalPeak = task.getStats().getPeakTotalMemoryInBytes();
            userMemorySum += userPeak;
            totalMemorySum += totalPeak;
            maxTaskUserMemory = max(maxTaskUserMemory, userPeak);
            maxTaskTotalMemory = max(maxTaskTotalMemory, totalPeak);
            maxNodeTotalMemory = max(task.getStats().getPeakNodeTotalMemoryInBytes(), maxNodeTotalMemory);
        }
    }

    QueryStats queryStats = QueryStats.create(
            queryStateTimer,
            rootStage,
            taskCount,
            succinctBytes(userMemorySum),
            succinctBytes(totalMemorySum),
            succinctBytes(maxTaskUserMemory),
            succinctBytes(maxTaskTotalMemory),
            succinctBytes(maxNodeTotalMemory),
            session.getRuntimeStats());

    return new QueryInfo(
            session.getQueryId(),
            session.toSessionRepresentation(),
            queryState,
            new MemoryPoolId("spark-memory-pool"),
            queryStats.isScheduled(),
            URI.create("http://fake.invalid/query/" + session.getQueryId()),
            planAndMore.map(PlanAndMore::getFieldNames).orElse(ImmutableList.of()),
            query,
            Optional.empty(),
            Optional.empty(),
            queryStats,
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            ImmutableSet.of(),
            ImmutableMap.of(),
            ImmutableMap.of(),
            ImmutableSet.of(),
            Optional.empty(),
            false,
            planAndMore.flatMap(PlanAndMore::getUpdateType).orElse(null),
            rootStage,
            failureInfo.orElse(null),
            failureInfo.map(ExecutionFailureInfo::getErrorCode).orElse(null),
            warningCollector.getWarnings(),
            planAndMore.map(PlanAndMore::getInputs).orElse(ImmutableSet.of()),
            planAndMore.flatMap(PlanAndMore::getOutput),
            true,
            sparkQueueName.map(ResourceGroupId::new),
            planAndMore.flatMap(PlanAndMore::getQueryType),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            ImmutableSet.of());
}
Also used : TaskInfo(com.facebook.presto.execution.TaskInfo) SerializedTaskInfo(com.facebook.presto.spark.classloader_interface.SerializedTaskInfo) PlanAndMore(com.facebook.presto.spark.planner.PrestoSparkQueryPlanner.PlanAndMore) QueryStats(com.facebook.presto.execution.QueryStats) StageInfo(com.facebook.presto.execution.StageInfo) StageExecutionInfo(com.facebook.presto.execution.StageExecutionInfo) BasicQueryInfo(com.facebook.presto.server.BasicQueryInfo) QueryInfo(com.facebook.presto.execution.QueryInfo) MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId) ExecutionFailureInfo(com.facebook.presto.execution.ExecutionFailureInfo)

Example 4 with MemoryPoolId

Use of com.facebook.presto.spi.memory.MemoryPoolId in the presto project by prestodb.

In class PrestoSparkTaskExecutorFactory, method doCreate.

/**
 * Creates and starts the Presto task executor for one Spark partition.
 *
 * <p>Steps:
 * <ol>
 *   <li>Decode the task descriptor and rebuild the session.</li>
 *   <li>Drop cached broadcast tables that belong to other stages.</li>
 *   <li>Derive the per-task user, total and broadcast memory limits (the minimum of the node config
 *       and the session properties, unless a broadcast override is set).</li>
 *   <li>Create a dedicated memory pool and query context.</li>
 *   <li>When spilling is enabled, register a pool listener that requests memory revocation above the
 *       revoking threshold and fails the task when the total limit is exceeded with no revocation
 *       outstanding.</li>
 *   <li>Wire shuffle, broadcast and in-memory inputs per remote source; each source fragment must
 *       provide exactly one kind of input.</li>
 *   <li>Plan the fragment and start execution.</li>
 * </ol>
 *
 * @throws IllegalArgumentException if the revoking target exceeds the revoking threshold, if a source
 *         fragment provides more than one kind of input, or if no input exists for a source fragment
 */
public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> doCreate(int partitionId, int attemptNumber, SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor, Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources, PrestoSparkTaskInputs inputs, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector, Class<T> outputType) {
    PrestoSparkTaskDescriptor taskDescriptor = taskDescriptorJsonCodec.fromJson(serializedTaskDescriptor.getBytes());
    ImmutableMap.Builder<String, TokenAuthenticator> extraAuthenticators = ImmutableMap.builder();
    authenticatorProviders.forEach(provider -> extraAuthenticators.putAll(provider.getTokenAuthenticators()));
    Session session = taskDescriptor.getSession().toSession(sessionPropertyManager, taskDescriptor.getExtraCredentials(), extraAuthenticators.build());
    PlanFragment fragment = taskDescriptor.getFragment();
    StageId stageId = new StageId(session.getQueryId(), fragment.getId().getId());
    // Clear the cache if the cache does not have broadcast table for current stageId.
    // We will only cache 1 HT at any time. If the stageId changes, we will drop the old cached HT
    prestoSparkBroadcastTableCacheManager.removeCachedTablesForStagesOtherThan(stageId);
    // TODO: include attemptId in taskId
    TaskId taskId = new TaskId(new StageExecutionId(stageId, 0), partitionId);
    List<TaskSource> taskSources = getTaskSources(serializedTaskSources);
    log.info("Task [%s] received %d splits.", taskId, taskSources.stream().mapToInt(taskSource -> taskSource.getSplits().size()).sum());
    OptionalLong totalSplitSize = computeAllSplitsSize(taskSources);
    if (totalSplitSize.isPresent()) {
        log.info("Total split size: %s bytes.", totalSplitSize.getAsLong());
    }
    // TODO: Remove this once we can display the plan on Spark UI.
    log.info(PlanPrinter.textPlanFragment(fragment, functionAndTypeManager, session, true));
    DataSize maxUserMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryMemoryPerNode().toBytes(), getQueryMaxMemoryPerNode(session).toBytes()), BYTE);
    DataSize maxTotalMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryTotalMemoryPerNode().toBytes(), getQueryMaxTotalMemoryPerNode(session).toBytes()), BYTE);
    DataSize maxBroadcastMemory = getSparkBroadcastJoinMaxMemoryOverride(session);
    if (maxBroadcastMemory == null) {
        maxBroadcastMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryBroadcastMemory().toBytes(), getQueryMaxBroadcastMemory(session).toBytes()), BYTE);
    }
    MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("spark-executor-memory-pool"), maxTotalMemory);
    SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(maxQuerySpillPerNode);
    QueryContext queryContext = new QueryContext(session.getQueryId(), maxUserMemory, maxTotalMemory, maxBroadcastMemory, maxRevocableMemory, memoryPool, new TestingGcMonitor(), notificationExecutor, yieldExecutor, maxQuerySpillPerNode, spillSpaceTracker, memoryReservationSummaryJsonCodec);
    queryContext.setVerboseExceededMemoryLimitErrorsEnabled(isVerboseExceededMemoryLimitErrorsEnabled(session));
    queryContext.setHeapDumpOnExceededMemoryLimitEnabled(isHeapDumpOnExceededMemoryLimitEnabled(session));
    String heapDumpFilePath = Paths.get(getHeapDumpFileDirectory(session), format("%s_%s.hprof", session.getQueryId().getId(), stageId.getId())).toString();
    queryContext.setHeapDumpFilePath(heapDumpFilePath);
    TaskStateMachine taskStateMachine = new TaskStateMachine(taskId, notificationExecutor);
    TaskContext taskContext = queryContext.addTaskContext(
            taskStateMachine,
            session,
            // Plan has to be retained only if verbose memory exceeded errors are requested
            isVerboseExceededMemoryLimitErrorsEnabled(session) ? Optional.of(fragment.getRoot()) : Optional.empty(),
            perOperatorCpuTimerEnabled,
            cpuTimerEnabled,
            perOperatorAllocationTrackingEnabled,
            allocationTrackingEnabled,
            false);
    final double memoryRevokingThreshold = getMemoryRevokingThreshold(session);
    final double memoryRevokingTarget = getMemoryRevokingTarget(session);
    checkArgument(memoryRevokingTarget <= memoryRevokingThreshold, "memoryRevokingTarget should be less than or equal memoryRevokingThreshold, but got %s and %s respectively", memoryRevokingTarget, memoryRevokingThreshold);
    if (isSpillEnabled(session)) {
        memoryPool.addListener((pool, queryId, totalMemoryReservationBytes) -> {
            if (totalMemoryReservationBytes > queryContext.getPeakNodeTotalMemory()) {
                queryContext.setPeakNodeTotalMemory(totalMemoryReservationBytes);
            }
            // compareAndSet ensures at most one revocation request is submitted at a time
            if (totalMemoryReservationBytes > pool.getMaxBytes() * memoryRevokingThreshold && memoryRevokeRequestInProgress.compareAndSet(false, true)) {
                memoryRevocationExecutor.execute(() -> {
                    try {
                        // revoke enough to bring the reservation down to memoryRevokingTarget * maxBytes,
                        // minus whatever is already being revoked
                        AtomicLong remainingBytesToRevoke = new AtomicLong(totalMemoryReservationBytes - (long) (memoryRevokingTarget * pool.getMaxBytes()));
                        remainingBytesToRevoke.addAndGet(-MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(ImmutableList.of(taskContext), remainingBytesToRevoke.get()));
                        taskContext.accept(new VoidTraversingQueryContextVisitor<AtomicLong>() {

                            @Override
                            public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remainingBytesToRevoke) {
                                if (remainingBytesToRevoke.get() > 0) {
                                    long revokedBytes = operatorContext.requestMemoryRevoking();
                                    if (revokedBytes > 0) {
                                        memoryRevokePending.set(true);
                                        remainingBytesToRevoke.addAndGet(-revokedBytes);
                                    }
                                }
                                return null;
                            }
                        }, remainingBytesToRevoke);
                    } catch (Exception e) {
                        log.error(e, "Error requesting memory revoking");
                    } finally {
                        // Always clear the in-progress flag. If it stayed set after a failure, no further
                        // revocation would ever be requested and the memory-limit check below would be
                        // skipped for the rest of the task.
                        memoryRevokeRequestInProgress.set(false);
                    }
                });
            }
            // Get the latest memory reservation info since it might have changed due to revoke
            long totalReservedMemory = pool.getQueryMemoryReservation(queryId) + pool.getQueryRevocableMemoryReservation(queryId);
            // If total memory usage is over maxTotalMemory and memory revoke request is not pending, fail the query with EXCEEDED_MEMORY_LIMIT error
            if (totalReservedMemory > maxTotalMemory.toBytes() && !memoryRevokeRequestInProgress.get() && !isMemoryRevokePending(taskContext)) {
                throw exceededLocalTotalMemoryLimit(maxTotalMemory, queryContext.getAdditionalFailureInfo(totalReservedMemory, 0) + format("Total reserved memory: %s, Total revocable memory: %s", succinctBytes(pool.getQueryMemoryReservation(queryId)), succinctBytes(pool.getQueryRevocableMemoryReservation(queryId))), isHeapDumpOnExceededMemoryLimitEnabled(session), Optional.ofNullable(heapDumpFilePath));
            }
        });
    }
    ImmutableMap.Builder<PlanNodeId, List<PrestoSparkShuffleInput>> shuffleInputs = ImmutableMap.builder();
    ImmutableMap.Builder<PlanNodeId, List<java.util.Iterator<PrestoSparkSerializedPage>>> pageInputs = ImmutableMap.builder();
    ImmutableMap.Builder<PlanNodeId, List<?>> broadcastInputs = ImmutableMap.builder();
    for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) {
        List<PrestoSparkShuffleInput> remoteSourceRowInputs = new ArrayList<>();
        List<java.util.Iterator<PrestoSparkSerializedPage>> remoteSourcePageInputs = new ArrayList<>();
        List<List<?>> broadcastInputsList = new ArrayList<>();
        for (PlanFragmentId sourceFragmentId : remoteSource.getSourceFragmentIds()) {
            Iterator<Tuple2<MutablePartitionId, PrestoSparkMutableRow>> shuffleInput = inputs.getShuffleInputs().get(sourceFragmentId.toString());
            Broadcast<?> broadcastInput = inputs.getBroadcastInputs().get(sourceFragmentId.toString());
            List<PrestoSparkSerializedPage> inMemoryInput = inputs.getInMemoryInputs().get(sourceFragmentId.toString());
            if (shuffleInput != null) {
                checkArgument(broadcastInput == null, "single remote source is not expected to accept different kind of inputs");
                checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
                remoteSourceRowInputs.add(new PrestoSparkShuffleInput(sourceFragmentId.getId(), shuffleInput));
                continue;
            }
            if (broadcastInput != null) {
                checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
                // TODO: Enable NullifyingIterator once migrated to one task per JVM model
                // NullifyingIterator removes element from the list upon return
                // This allows GC to gradually reclaim memory
                // remoteSourcePageInputs.add(getNullifyingIterator(broadcastInput.value()));
                broadcastInputsList.add((List<?>) broadcastInput.value());
                continue;
            }
            if (inMemoryInput != null) {
                // for inmemory inputs pages can be released incrementally to save memory
                remoteSourcePageInputs.add(getNullifyingIterator(inMemoryInput));
                continue;
            }
            throw new IllegalArgumentException("Input not found for sourceFragmentId: " + sourceFragmentId);
        }
        if (!remoteSourceRowInputs.isEmpty()) {
            shuffleInputs.put(remoteSource.getId(), remoteSourceRowInputs);
        }
        if (!remoteSourcePageInputs.isEmpty()) {
            pageInputs.put(remoteSource.getId(), remoteSourcePageInputs);
        }
        if (!broadcastInputsList.isEmpty()) {
            broadcastInputs.put(remoteSource.getId(), broadcastInputsList);
        }
    }
    OutputBufferMemoryManager memoryManager = new OutputBufferMemoryManager(sinkMaxBufferSize.toBytes(), () -> queryContext.getTaskContextByTaskId(taskId).localSystemMemoryContext(), notificationExecutor);
    Optional<OutputPartitioning> preDeterminedPartition = Optional.empty();
    if (fragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION)) {
        int partitionCount = getHashPartitionCount(session);
        preDeterminedPartition = Optional.of(new OutputPartitioning(new PreDeterminedPartitionFunction(partitionId % partitionCount, partitionCount), ImmutableList.of(), ImmutableList.of(), false, OptionalInt.empty()));
    }
    TempDataOperationContext tempDataOperationContext = new TempDataOperationContext(session.getSource(), session.getQueryId().getId(), session.getClientInfo(), Optional.of(session.getClientTags()), session.getIdentity());
    TempStorage tempStorage = tempStorageManager.getTempStorage(storageBasedBroadcastJoinStorage);
    Output<T> output = configureOutput(outputType, blockEncodingManager, memoryManager, getShuffleOutputTargetAverageRowSize(session), preDeterminedPartition, tempStorage, tempDataOperationContext, getStorageBasedBroadcastJoinWriteBufferSize(session));
    PrestoSparkOutputBuffer<?> outputBuffer = output.getOutputBuffer();
    LocalExecutionPlan localExecutionPlan = localExecutionPlanner.plan(taskContext, fragment.getRoot(), fragment.getPartitioningScheme(), fragment.getStageExecutionDescriptor(), fragment.getTableScanSchedulingOrder(), output.getOutputFactory(), new PrestoSparkRemoteSourceFactory(blockEncodingManager, shuffleInputs.build(), pageInputs.build(), broadcastInputs.build(), partitionId, shuffleStatsCollector, tempStorage, tempDataOperationContext, prestoSparkBroadcastTableCacheManager, stageId), taskDescriptor.getTableWriteInfo(), true);
    // once the task reaches a terminal state, signal the output buffer that no more rows will arrive
    taskStateMachine.addStateChangeListener(state -> {
        if (state.isDone()) {
            outputBuffer.setNoMoreRows();
        }
    });
    PrestoSparkTaskExecution taskExecution = new PrestoSparkTaskExecution(taskStateMachine, taskContext, localExecutionPlan, taskExecutor, splitMonitor, notificationExecutor, memoryUpdateExecutor);
    taskExecution.start(taskSources);
    return new PrestoSparkTaskExecutor<>(taskContext, taskStateMachine, output.getOutputSupplier(), taskInfoCodec, taskInfoCollector, shuffleStatsCollector, executionExceptionFactory, output.getOutputBufferType(), outputBuffer, tempStorage, tempDataOperationContext);
}
Also used : StageId(com.facebook.presto.execution.StageId) ArrayList(java.util.ArrayList) PlanFragment(com.facebook.presto.sql.planner.PlanFragment) TaskStateMachine(com.facebook.presto.execution.TaskStateMachine) PlanNodeId(com.facebook.presto.spi.plan.PlanNodeId) RemoteSourceNode(com.facebook.presto.sql.planner.plan.RemoteSourceNode) DataSize(io.airlift.units.DataSize) OperatorContext(com.facebook.presto.operator.OperatorContext) OutputBufferMemoryManager(com.facebook.presto.execution.buffer.OutputBufferMemoryManager) ArrayList(java.util.ArrayList) List(java.util.List) ImmutableList(com.google.common.collect.ImmutableList) TempDataOperationContext(com.facebook.presto.spi.storage.TempDataOperationContext) PrestoSparkSessionProperties.getSparkBroadcastJoinMaxMemoryOverride(com.facebook.presto.spark.PrestoSparkSessionProperties.getSparkBroadcastJoinMaxMemoryOverride) PreDeterminedPartitionFunction(com.facebook.presto.spark.execution.PrestoSparkRowOutputOperator.PreDeterminedPartitionFunction) IPrestoSparkTaskExecutor(com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutor) ImmutableMap(com.google.common.collect.ImmutableMap) TokenAuthenticator(com.facebook.presto.spi.security.TokenAuthenticator) SerializedPrestoSparkTaskDescriptor(com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor) PrestoSparkTaskDescriptor(com.facebook.presto.spark.PrestoSparkTaskDescriptor) TaskId(com.facebook.presto.execution.TaskId) StageExecutionId(com.facebook.presto.execution.StageExecutionId) PrestoSparkUtils.getNullifyingIterator(com.facebook.presto.spark.util.PrestoSparkUtils.getNullifyingIterator) AbstractIterator(scala.collection.AbstractIterator) Iterator(scala.collection.Iterator) TestingGcMonitor(com.facebook.airlift.stats.TestingGcMonitor) PlanFragmentId(com.facebook.presto.sql.planner.plan.PlanFragmentId) MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId) SpillSpaceTracker(com.facebook.presto.spiller.SpillSpaceTracker) 
TaskContext(com.facebook.presto.operator.TaskContext) QueryContext(com.facebook.presto.memory.QueryContext) UncheckedIOException(java.io.UncheckedIOException) IOException(java.io.IOException) NoSuchElementException(java.util.NoSuchElementException) PrestoSparkUtils.toPrestoSparkSerializedPage(com.facebook.presto.spark.util.PrestoSparkUtils.toPrestoSparkSerializedPage) PrestoSparkSerializedPage(com.facebook.presto.spark.classloader_interface.PrestoSparkSerializedPage) LocalExecutionPlan(com.facebook.presto.sql.planner.LocalExecutionPlanner.LocalExecutionPlan) AtomicLong(java.util.concurrent.atomic.AtomicLong) TempStorage(com.facebook.presto.spi.storage.TempStorage) Tuple2(scala.Tuple2) OptionalLong(java.util.OptionalLong) SerializedPrestoSparkTaskSource(com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskSource) TaskSource(com.facebook.presto.execution.TaskSource) OutputPartitioning(com.facebook.presto.sql.planner.OutputPartitioning) Session(com.facebook.presto.Session) MemoryPool(com.facebook.presto.memory.MemoryPool)

Example 5 with MemoryPoolId

Use of com.facebook.presto.spi.memory.MemoryPoolId in the presto project by prestodb.

In class GroupByHashYieldAssertion, method finishOperatorWithYieldingGroupByHash.

/**
 * Feeds {@code input} page by page into an operator built by {@code operatorFactory} while keeping the shared
 * memory pool nearly saturated. It asserts that, when growing the hash table would not fit in the remaining
 * memory, the operator yields (waits for memory) instead of rehashing, and that it rehashes once memory is freed.
 * <p>
 * Before each page, all but {@code additionalMemoryInBytes} of the pool's free bytes are reserved on behalf of a
 * second query ({@code test_query1}). That reservation is released after the page has been processed, or after
 * the operator has been observed blocking.
 *
 * @param input pages fed to the operator, one at a time
 * @param hashKeyType type of the group-by key; {@code BIGINT} and all other types use different formulas to
 * estimate the memory reserved ahead of a rehash
 * @param operatorFactory creates an Operator that should directly or indirectly contain GroupByHash
 * @param getHashCapacity returns the hash table capacity for the input operator
 * @param additionalMemoryInBytes the memory used in addition to the GroupByHash in the operator (e.g., aggregator);
 * must be less than {@code 1 << 21} bytes
 * @return the number of times the operator yielded, the expected reserved extra bytes computed for the most recent
 * yield (0 if the operator never yielded), and all output pages, including those produced when finishing the operator
 */
public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<Page> input, Type hashKeyType, OperatorFactory operatorFactory, Function<Operator, Integer> getHashCapacity, long additionalMemoryInBytes) {
    assertLessThan(additionalMemoryInBytes, 1L << 21, "additionalMemoryInBytes should be a relatively small number");
    List<Page> result = new LinkedList<>();
    // mock an adjustable memory pool
    // queryId1 only owns the reservation that saturates the pool; the operator itself runs under queryId2
    QueryId queryId1 = new QueryId("test_query1");
    QueryId queryId2 = new QueryId("test_query2");
    MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(1, GIGABYTE));
    QueryContext queryContext = new QueryContext(queryId2, new DataSize(512, MEGABYTE), new DataSize(1024, MEGABYTE), new DataSize(512, MEGABYTE), new DataSize(1, GIGABYTE), memoryPool, new TestingGcMonitor(), EXECUTOR, SCHEDULED_EXECUTOR, new DataSize(512, MEGABYTE), new SpillSpaceTracker(new DataSize(512, MEGABYTE)), listJsonCodec(TaskMemoryReservationSummary.class));
    DriverContext driverContext = createTaskContext(queryContext, EXECUTOR, TEST_SESSION).addPipelineContext(0, true, true, false).addDriverContext();
    Operator operator = operatorFactory.createOperator(driverContext);
    // Below this driver memory usage we cannot distinguish between a rehash and the memory used by the
    // aggregator, so such pages are skipped. The threshold is loop-invariant, so compute it once.
    long minMemoryUsageToCheckYield = new DataSize(4, MEGABYTE).toBytes();
    // run operator
    int yieldCount = 0;
    long expectedReservedExtraBytes = 0;
    for (Page page : input) {
        // unblocked
        assertTrue(operator.needsInput());
        // saturate the pool with a tiny memory left
        long reservedMemoryInBytes = memoryPool.getFreeBytes() - additionalMemoryInBytes;
        memoryPool.reserve(queryId1, "test", reservedMemoryInBytes);
        long oldMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
        int oldCapacity = getHashCapacity.apply(operator);
        // add a page and verify different behaviors
        operator.addInput(page);
        // get output to consume the input
        Page output = operator.getOutput();
        if (output != null) {
            result.add(output);
        }
        long newMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
        // Skip if the memory usage is not large enough, since we cannot distinguish
        // between rehash and memory used by aggregator
        if (newMemoryUsage < minMemoryUsageToCheckYield) {
            // free the pool for the next iteration
            memoryPool.free(queryId1, "test", reservedMemoryInBytes);
            // this is required in case input is blocked
            operator.getOutput();
            continue;
        }
        long actualIncreasedMemory = newMemoryUsage - oldMemoryUsage;
        if (operator.needsInput()) {
            // We have successfully added a page
            // Assert we are not blocked
            assertTrue(operator.getOperatorContext().isWaitingForMemory().isDone());
            // assert the hash capacity is not changed; otherwise, we should have yielded
            // (assertEquals instead of assertTrue(a == b) so a failure reports both capacities)
            assertEquals((long) getHashCapacity.apply(operator), oldCapacity);
            // We are not going to rehash; therefore, assert the memory increase only comes from the aggregator
            assertLessThan(actualIncreasedMemory, additionalMemoryInBytes);
            // free the pool for the next iteration
            memoryPool.free(queryId1, "test", reservedMemoryInBytes);
        }
        else {
            // We failed to finish the page processing i.e. we yielded
            yieldCount++;
            // Assert we are blocked
            assertFalse(operator.getOperatorContext().isWaitingForMemory().isDone());
            // Hash table capacity should not change (TestNG argument order: actual, expected)
            assertEquals((long) getHashCapacity.apply(operator), oldCapacity);
            // Increased memory is no smaller than the hash table size and no greater than the hash table size + the memory used by aggregator
            if (hashKeyType == BIGINT) {
                // groupIds and values double by hashCapacity; while valuesByGroupId double by maxFill = hashCapacity * 0.75
                // (hence Long.BYTES * (1 + 0.75) + Integer.BYTES per slot)
                expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES) + page.getRetainedSizeInBytes();
            }
            else {
                // groupAddressByHash, groupIdsByHash, and rawHashByHashPosition double by hashCapacity; while groupAddressByGroupId double by maxFill = hashCapacity * 0.75
                expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES + Byte.BYTES) + page.getRetainedSizeInBytes();
            }
            assertBetweenInclusive(actualIncreasedMemory, expectedReservedExtraBytes, expectedReservedExtraBytes + additionalMemoryInBytes);
            // Output should be blocked as well
            assertNull(operator.getOutput());
            // Free the pool to unblock
            memoryPool.free(queryId1, "test", reservedMemoryInBytes);
            // Trigger a process through getOutput() or needsInput()
            output = operator.getOutput();
            if (output != null) {
                result.add(output);
            }
            assertTrue(operator.needsInput());
            // Hash table capacity has increased
            assertGreaterThan(getHashCapacity.apply(operator), oldCapacity);
            // Assert the estimated reserved memory before rehash is very close to the one after rehash
            long rehashedMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
            assertBetweenInclusive(rehashedMemoryUsage * 1.0 / newMemoryUsage, 0.99, 1.01);
            // unblocked
            assertTrue(operator.needsInput());
        }
    }
    result.addAll(finishOperator(operator));
    return new GroupByHashYieldResult(yieldCount, expectedReservedExtraBytes, result);
}
Also used : OperatorAssertion.finishOperator(com.facebook.presto.operator.OperatorAssertion.finishOperator) SpillSpaceTracker(com.facebook.presto.spiller.SpillSpaceTracker) QueryId(com.facebook.presto.spi.QueryId) Page(com.facebook.presto.common.Page) QueryContext(com.facebook.presto.memory.QueryContext) LinkedList(java.util.LinkedList) DataSize(io.airlift.units.DataSize) TestingGcMonitor(com.facebook.airlift.stats.TestingGcMonitor) MemoryPoolId(com.facebook.presto.spi.memory.MemoryPoolId) MemoryPool(com.facebook.presto.memory.MemoryPool)

Aggregations

MemoryPoolId (com.facebook.presto.spi.memory.MemoryPoolId)23 DataSize (io.airlift.units.DataSize)17 QueryId (com.facebook.presto.spi.QueryId)16 MemoryPool (com.facebook.presto.memory.MemoryPool)9 TestingGcMonitor (com.facebook.airlift.stats.TestingGcMonitor)7 ImmutableMap (com.google.common.collect.ImmutableMap)7 Test (org.testng.annotations.Test)7 QueryContext (com.facebook.presto.memory.QueryContext)6 Session (com.facebook.presto.Session)5 SpillSpaceTracker (com.facebook.presto.spiller.SpillSpaceTracker)5 TaskId (com.facebook.presto.execution.TaskId)4 TaskStateMachine (com.facebook.presto.execution.TaskStateMachine)4 TaskContext (com.facebook.presto.operator.TaskContext)4 ImmutableList (com.google.common.collect.ImmutableList)4 HashMap (java.util.HashMap)4 AtomicLong (java.util.concurrent.atomic.AtomicLong)4 SqlTask.createSqlTask (com.facebook.presto.execution.SqlTask.createSqlTask)3 OperatorContext (com.facebook.presto.operator.OperatorContext)3 TaskMemoryReservationSummary (com.facebook.presto.operator.TaskMemoryReservationSummary)3 TpchConnectorFactory (com.facebook.presto.tpch.TpchConnectorFactory)3