Search in sources :

Example 1 with QueryContext

use of io.trino.memory.QueryContext in project trino by trinodb.

the class TestSqlTask method createInitialTask.

/**
 * Creates a {@link SqlTask} for the query named "query", using the next task id from
 * {@code nextTaskId} and a fake location URI derived from that id.
 * <p>
 * A task context is registered on the freshly built {@link QueryContext} before the
 * task itself is created.
 */
private SqlTask createInitialTask() {
    StageId stageId = new StageId("query", 0);
    TaskId taskId = new TaskId(stageId, nextTaskId.incrementAndGet(), 0);
    URI location = URI.create("fake://task/" + taskId);

    // Collaborators of the query context, built in the same order as they are passed below
    QueryId queryId = new QueryId("query");
    MemoryPool memoryPool = new MemoryPool(DataSize.of(1, GIGABYTE));
    TestingGcMonitor gcMonitor = new TestingGcMonitor();
    SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(DataSize.of(1, GIGABYTE));
    QueryContext queryContext = new QueryContext(
            queryId,
            DataSize.of(1, MEGABYTE),
            memoryPool,
            gcMonitor,
            taskNotificationExecutor,
            driverYieldExecutor,
            DataSize.of(1, MEGABYTE),
            spillSpaceTracker);

    // Register a task context for this task; the callback is intentionally a no-op
    TaskStateMachine taskStateMachine = new TaskStateMachine(taskId, taskNotificationExecutor);
    queryContext.addTaskContext(
            taskStateMachine,
            testSessionBuilder().build(),
            () -> {
            },
            false,
            false);

    return createSqlTask(
            taskId,
            location,
            "fake",
            queryContext,
            sqlTaskExecutionFactory,
            taskNotificationExecutor,
            sqlTask -> {
            },
            DataSize.of(32, MEGABYTE),
            DataSize.of(200, MEGABYTE),
            new ExchangeManagerRegistry(new ExchangeHandleResolver()),
            new CounterStat());
}
Also used : SpillSpaceTracker(io.trino.spiller.SpillSpaceTracker) CounterStat(io.airlift.stats.CounterStat) QueryId(io.trino.spi.QueryId) TestingGcMonitor(io.airlift.stats.TestingGcMonitor) QueryContext(io.trino.memory.QueryContext) URI(java.net.URI) ExchangeManagerRegistry(io.trino.exchange.ExchangeManagerRegistry) ExchangeHandleResolver(io.trino.metadata.ExchangeHandleResolver) MemoryPool(io.trino.memory.MemoryPool)

Example 2 with QueryContext

use of io.trino.memory.QueryContext in project trino by trinodb.

the class TestMemoryRevokingScheduler method testScheduleMemoryRevoking.

/**
 * Walks a {@link MemoryRevokingScheduler} through a sequence of revocable-memory
 * allocations spread over two queries and verifies, after each step, both the free
 * bytes reported by {@code memoryPool} and the set of operators for which memory
 * revoking has been requested.
 * <p>
 * Assertions use the TestNG argument order {@code assertEquals(actual, expected)} so
 * that failure messages report the values the right way round.
 */
@Test
public void testScheduleMemoryRevoking() throws Exception {
    QueryContext q1 = getOrCreateQueryContext(new QueryId("q1"));
    QueryContext q2 = getOrCreateQueryContext(new QueryId("q2"));
    SqlTask sqlTask1 = newSqlTask(q1.getQueryId());
    SqlTask sqlTask2 = newSqlTask(q2.getQueryId());
    // task 1: one pipeline with two drivers; operators 1 and 2 on the first driver, operator 3 on the second
    TaskContext taskContext1 = getOrCreateTaskContext(sqlTask1);
    PipelineContext pipelineContext11 = taskContext1.addPipelineContext(0, false, false, false);
    DriverContext driverContext111 = pipelineContext11.addDriverContext();
    OperatorContext operatorContext1 = driverContext111.addOperatorContext(1, new PlanNodeId("na"), "na");
    OperatorContext operatorContext2 = driverContext111.addOperatorContext(2, new PlanNodeId("na"), "na");
    DriverContext driverContext112 = pipelineContext11.addDriverContext();
    OperatorContext operatorContext3 = driverContext112.addOperatorContext(3, new PlanNodeId("na"), "na");
    // task 2: one pipeline with a single driver holding operators 4 and 5
    TaskContext taskContext2 = getOrCreateTaskContext(sqlTask2);
    PipelineContext pipelineContext21 = taskContext2.addPipelineContext(1, false, false, false);
    DriverContext driverContext211 = pipelineContext21.addDriverContext();
    OperatorContext operatorContext4 = driverContext211.addOperatorContext(4, new PlanNodeId("na"), "na");
    OperatorContext operatorContext5 = driverContext211.addOperatorContext(5, new PlanNodeId("na"), "na");
    Collection<SqlTask> tasks = ImmutableList.of(sqlTask1, sqlTask2);
    MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(memoryPool, () -> tasks, executor, 1.0, 1.0);
    allOperatorContexts = ImmutableSet.of(operatorContext1, operatorContext2, operatorContext3, operatorContext4, operatorContext5);
    // nothing allocated yet: the pool is expected to report 10 free bytes and no revoking
    assertMemoryRevokingNotRequested();
    requestMemoryRevoking(scheduler);
    assertEquals(memoryPool.getFreeBytes(), 10);
    assertMemoryRevokingNotRequested();
    LocalMemoryContext revocableMemory1 = operatorContext1.localRevocableMemoryContext();
    LocalMemoryContext revocableMemory3 = operatorContext3.localRevocableMemoryContext();
    LocalMemoryContext revocableMemory4 = operatorContext4.localRevocableMemoryContext();
    LocalMemoryContext revocableMemory5 = operatorContext5.localRevocableMemoryContext();
    revocableMemory1.setBytes(3);
    revocableMemory3.setBytes(6);
    assertEquals(memoryPool.getFreeBytes(), 1);
    requestMemoryRevoking(scheduler);
    // we are still good - no revoking needed
    assertMemoryRevokingNotRequested();
    revocableMemory4.setBytes(7);
    assertEquals(memoryPool.getFreeBytes(), -6);
    requestMemoryRevoking(scheduler);
    // we need to revoke 3 and 6
    assertMemoryRevokingRequestedFor(operatorContext1, operatorContext3);
    // yet another revoking request should not change anything
    requestMemoryRevoking(scheduler);
    assertMemoryRevokingRequestedFor(operatorContext1, operatorContext3);
    // let's revoke some bytes
    revocableMemory1.setBytes(0);
    operatorContext1.resetMemoryRevokingRequested();
    requestMemoryRevoking(scheduler);
    assertMemoryRevokingRequestedFor(operatorContext3);
    assertEquals(memoryPool.getFreeBytes(), -3);
    // and allocate some more
    revocableMemory5.setBytes(3);
    assertEquals(memoryPool.getFreeBytes(), -6);
    requestMemoryRevoking(scheduler);
    // we are still good with just OC3 in process of revoking
    assertMemoryRevokingRequestedFor(operatorContext3);
    // and allocate some more
    revocableMemory5.setBytes(4);
    assertEquals(memoryPool.getFreeBytes(), -7);
    requestMemoryRevoking(scheduler);
    // now we have to trigger revoking for OC4
    assertMemoryRevokingRequestedFor(operatorContext3, operatorContext4);
}
Also used : PlanNodeId(io.trino.sql.planner.plan.PlanNodeId) SqlTask.createSqlTask(io.trino.execution.SqlTask.createSqlTask) DriverContext(io.trino.operator.DriverContext) LocalMemoryContext(io.trino.memory.context.LocalMemoryContext) TaskContext(io.trino.operator.TaskContext) PipelineContext(io.trino.operator.PipelineContext) QueryId(io.trino.spi.QueryId) OperatorContext(io.trino.operator.OperatorContext) QueryContext(io.trino.memory.QueryContext) Test(org.testng.annotations.Test)

Example 3 with QueryContext

use of io.trino.memory.QueryContext in project trino by trinodb.

the class GroupByHashYieldAssertion method finishOperatorWithYieldingGroupByHash.

/**
 * Feeds {@code input} page by page into an operator created from {@code operatorFactory}
 * while keeping the memory pool almost fully reserved, and checks how the operator behaves
 * when adding a page: either the page is fully consumed without a hash capacity change, or
 * the operator yields (blocks on memory) until the pool is freed and then grows its hash table.
 *
 * @param input pages added to the operator, one per loop iteration
 * @param hashKeyType type of the hash key; selects the formula for the expected memory increase ({@code BIGINT} vs. everything else)
 * @param operatorFactory creates an Operator that should directly or indirectly contain GroupByHash
 * @param getHashCapacity returns the hash table capacity for the input operator
 * @param additionalMemoryInBytes the memory used in addition to the GroupByHash in the operator (e.g., aggregator)
 * @return the number of yields observed, the expected reserved extra bytes computed for the last yield
 * (0 if the operator never yielded), and all output pages produced by the operator
 */
public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<Page> input, Type hashKeyType, OperatorFactory operatorFactory, Function<Operator, Integer> getHashCapacity, long additionalMemoryInBytes) {
    assertLessThan(additionalMemoryInBytes, 1L << 21, "additionalMemoryInBytes should be a relatively small number");
    List<Page> result = new LinkedList<>();
    // mock an adjustable memory pool
    QueryId queryId = new QueryId("test_query");
    MemoryPool memoryPool = new MemoryPool(DataSize.of(1, GIGABYTE));
    QueryContext queryContext = new QueryContext(queryId, DataSize.of(512, MEGABYTE), memoryPool, new TestingGcMonitor(), EXECUTOR, SCHEDULED_EXECUTOR, DataSize.of(512, MEGABYTE), new SpillSpaceTracker(DataSize.of(512, MEGABYTE)));
    DriverContext driverContext = createTaskContext(queryContext, EXECUTOR, TEST_SESSION).addPipelineContext(0, true, true, false).addDriverContext();
    Operator operator = operatorFactory.createOperator(driverContext);
    // run operator
    int yieldCount = 0;
    long expectedReservedExtraBytes = 0;
    for (Page page : input) {
        // unblocked
        assertTrue(operator.needsInput());
        // saturate the pool with a tiny memory left
        long reservedMemoryInBytes = memoryPool.getFreeBytes() - additionalMemoryInBytes;
        memoryPool.reserve(queryId, "test", reservedMemoryInBytes);
        long oldMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
        int oldCapacity = getHashCapacity.apply(operator);
        // add a page and verify different behaviors
        operator.addInput(page);
        // get output to consume the input
        Page output = operator.getOutput();
        if (output != null) {
            result.add(output);
        }
        long newMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
        // Skip the checks while the memory usage is still small (below 4MB); presumably
        // at that size we cannot distinguish between rehash and memory used by aggregator
        if (newMemoryUsage < DataSize.of(4, MEGABYTE).toBytes()) {
            // free the pool for the next iteration
            memoryPool.free(queryId, "test", reservedMemoryInBytes);
            // this required in case input is blocked
            operator.getOutput();
            continue;
        }
        long actualIncreasedMemory = newMemoryUsage - oldMemoryUsage;
        if (operator.needsInput()) {
            // We have successfully added a page
            // Assert we are not blocked
            assertTrue(operator.getOperatorContext().isWaitingForMemory().isDone());
            // assert the hash capacity is not changed; otherwise, we should have yielded
            // (assertEquals rather than assertTrue(a == b) so a failure reports both capacities)
            assertEquals(oldCapacity, (long) getHashCapacity.apply(operator));
            // We are not going to rehash; therefore, assert the memory increase only comes from the aggregator
            assertLessThan(actualIncreasedMemory, additionalMemoryInBytes);
            // free the pool for the next iteration
            memoryPool.free(queryId, "test", reservedMemoryInBytes);
        } else {
            // We failed to finish the page processing i.e. we yielded
            yieldCount++;
            // Assert we are blocked
            assertFalse(operator.getOperatorContext().isWaitingForMemory().isDone());
            // Hash table capacity should not change
            assertEquals(oldCapacity, (long) getHashCapacity.apply(operator));
            // Increased memory is no smaller than the hash table size and no greater than the hash table size + the memory used by aggregator
            if (hashKeyType == BIGINT) {
                // groupIds and values double by hashCapacity; while valuesByGroupId double by maxFill = hashCapacity / 0.75
                expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES) + page.getRetainedSizeInBytes();
            } else {
                // groupAddressByHash, groupIdsByHash, and rawHashByHashPosition double by hashCapacity; while groupAddressByGroupId double by maxFill = hashCapacity / 0.75
                expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES + Byte.BYTES) + page.getRetainedSizeInBytes();
            }
            assertBetweenInclusive(actualIncreasedMemory, expectedReservedExtraBytes, expectedReservedExtraBytes + additionalMemoryInBytes);
            // Output should be blocked as well
            assertNull(operator.getOutput());
            // Free the pool to unblock
            memoryPool.free(queryId, "test", reservedMemoryInBytes);
            // Trigger a process through getOutput() or needsInput()
            output = operator.getOutput();
            if (output != null) {
                result.add(output);
            }
            assertTrue(operator.needsInput());
            // Hash table capacity has increased
            assertGreaterThan(getHashCapacity.apply(operator), oldCapacity);
            // Assert the estimated reserved memory before rehash is very close to the one after rehash
            long rehashedMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
            assertBetweenInclusive(rehashedMemoryUsage * 1.0 / newMemoryUsage, 0.99, 1.01);
            // unblocked
            assertTrue(operator.needsInput());
        }
    }
    // drain whatever the operator still holds after the last page
    result.addAll(finishOperator(operator));
    return new GroupByHashYieldResult(yieldCount, expectedReservedExtraBytes, result);
}
Also used : OperatorAssertion.finishOperator(io.trino.operator.OperatorAssertion.finishOperator) SpillSpaceTracker(io.trino.spiller.SpillSpaceTracker) QueryId(io.trino.spi.QueryId) Page(io.trino.spi.Page) QueryContext(io.trino.memory.QueryContext) LinkedList(java.util.LinkedList) TestingGcMonitor(io.airlift.stats.TestingGcMonitor) MemoryPool(io.trino.memory.MemoryPool)

Example 4 with QueryContext

use of io.trino.memory.QueryContext in project trino by trinodb.

the class SqlTaskManager method doUpdateTask.

/**
 * Applies a task update to the {@link SqlTask} registered for {@code taskId}.
 * <p>
 * If the memory limits of the task's {@link QueryContext} have not been initialized yet,
 * they are initialized here from the session; the per-node query memory limit is capped
 * at {@code queryMaxMemoryPerNode}. A heartbeat is then recorded and the update is
 * delegated to {@link SqlTask#updateTask}.
 *
 * @param session session supplying the memory-related properties
 * @param taskId id used to look up the task in {@code tasks}
 * @param fragment plan fragment forwarded to the task
 * @param splitAssignments split assignments forwarded to the task
 * @param outputBuffers output buffers forwarded to the task
 * @param dynamicFilterDomains dynamic filter domains forwarded to the task
 * @return the {@link TaskInfo} returned by the task update
 * @throws NullPointerException if any argument is null
 */
private TaskInfo doUpdateTask(Session session, TaskId taskId, Optional<PlanFragment> fragment, List<SplitAssignment> splitAssignments, OutputBuffers outputBuffers, Map<DynamicFilterId, Domain> dynamicFilterDomains) {
    requireNonNull(session, "session is null");
    requireNonNull(taskId, "taskId is null");
    requireNonNull(fragment, "fragment is null");
    requireNonNull(splitAssignments, "splitAssignments is null");
    requireNonNull(outputBuffers, "outputBuffers is null");
    // Validate up front, like the other arguments, so a null fails before any task state is touched
    requireNonNull(dynamicFilterDomains, "dynamicFilterDomains is null");
    SqlTask sqlTask = tasks.getUnchecked(taskId);
    QueryContext queryContext = sqlTask.getQueryContext();
    if (!queryContext.isMemoryLimitsInitialized()) {
        long sessionQueryMaxMemoryPerNode = getQueryMaxMemoryPerNode(session).toBytes();
        // Session properties are only allowed to decrease memory limits, not increase them
        queryContext.initializeMemoryLimits(resourceOvercommit(session), min(sessionQueryMaxMemoryPerNode, queryMaxMemoryPerNode));
    }
    sqlTask.recordHeartbeat();
    return sqlTask.updateTask(session, fragment, splitAssignments, outputBuffers, dynamicFilterDomains);
}
Also used : SqlTask.createSqlTask(io.trino.execution.SqlTask.createSqlTask) QueryContext(io.trino.memory.QueryContext)

Example 5 with QueryContext

use of io.trino.memory.QueryContext in project trino by trinodb.

the class TestSqlTaskManager method testSessionPropertyMemoryLimitOverride.

/**
 * Checks that the {@code QUERY_MAX_MEMORY_PER_NODE} session property can lower the
 * per-query user memory limit of a {@link QueryContext} below the configured node
 * limit (3 bytes here), but cannot raise it above that limit.
 */
@Test
public void testSessionPropertyMemoryLimitOverride() {
    NodeMemoryConfig nodeMemoryConfig = new NodeMemoryConfig().setMaxQueryMemoryPerNode(DataSize.ofBytes(3));
    try (SqlTaskManager taskManager = createSqlTaskManager(new TaskManagerConfig(), nodeMemoryConfig)) {
        TaskId loweringTaskId = new TaskId(new StageId("q1", 0), 1, 0);
        TaskId raisingTaskId = new TaskId(new StageId("q2", 0), 1, 0);
        QueryContext loweredContext = taskManager.getQueryContext(loweringTaskId.getQueryId());
        QueryContext raiseAttemptContext = taskManager.getQueryContext(raisingTaskId.getQueryId());

        // Before any task update: limits are uninitialized and equal to the configured value
        assertFalse(loweredContext.isMemoryLimitsInitialized());
        assertEquals(loweredContext.getMaxUserMemory(), nodeMemoryConfig.getMaxQueryMemoryPerNode().toBytes());
        assertFalse(raiseAttemptContext.isMemoryLimitsInitialized());
        assertEquals(raiseAttemptContext.getMaxUserMemory(), nodeMemoryConfig.getMaxQueryMemoryPerNode().toBytes());

        // A session asking for 1B (below the configured 3B) lowers the limit
        taskManager.updateTask(
                testSessionBuilder()
                        .setSystemProperty(QUERY_MAX_MEMORY_PER_NODE, "1B")
                        .build(),
                loweringTaskId,
                Optional.of(PLAN_FRAGMENT),
                ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)),
                createInitialEmptyOutputBuffers(PARTITIONED)
                        .withBuffer(OUT, 0)
                        .withNoMoreBufferIds(),
                ImmutableMap.of());
        assertTrue(loweredContext.isMemoryLimitsInitialized());
        assertEquals(loweredContext.getMaxUserMemory(), 1);

        // A session asking for 10B (above the configured 3B) leaves the limit at the configured value
        taskManager.updateTask(
                testSessionBuilder()
                        .setSystemProperty(QUERY_MAX_MEMORY_PER_NODE, "10B")
                        .build(),
                raisingTaskId,
                Optional.of(PLAN_FRAGMENT),
                ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)),
                createInitialEmptyOutputBuffers(PARTITIONED)
                        .withBuffer(OUT, 0)
                        .withNoMoreBufferIds(),
                ImmutableMap.of());
        assertTrue(raiseAttemptContext.isMemoryLimitsInitialized());
        assertEquals(raiseAttemptContext.getMaxUserMemory(), nodeMemoryConfig.getMaxQueryMemoryPerNode().toBytes());
    }
}
Also used : NodeMemoryConfig(io.trino.memory.NodeMemoryConfig) QueryContext(io.trino.memory.QueryContext) Test(org.testng.annotations.Test)

Aggregations

QueryContext (io.trino.memory.QueryContext)8 QueryId (io.trino.spi.QueryId)5 TestingGcMonitor (io.airlift.stats.TestingGcMonitor)4 MemoryPool (io.trino.memory.MemoryPool)4 SpillSpaceTracker (io.trino.spiller.SpillSpaceTracker)4 TaskContext (io.trino.operator.TaskContext)3 CounterStat (io.airlift.stats.CounterStat)2 Session (io.trino.Session)2 ExchangeManagerRegistry (io.trino.exchange.ExchangeManagerRegistry)2 SqlTask.createSqlTask (io.trino.execution.SqlTask.createSqlTask)2 StageId (io.trino.execution.StageId)2 TaskId (io.trino.execution.TaskId)2 TaskStateMachine (io.trino.execution.TaskStateMachine)2 ExchangeHandleResolver (io.trino.metadata.ExchangeHandleResolver)2 Page (io.trino.spi.Page)2 URI (java.net.URI)2 Test (org.testng.annotations.Test)2 ImmutableList (com.google.common.collect.ImmutableList)1 ImmutableMap (com.google.common.collect.ImmutableMap)1 CpuTimer (io.airlift.stats.CpuTimer)1