Example 1 with PlanNodeId

Use of com.facebook.presto.spi.plan.PlanNodeId in project presto by prestodb.

From class PrestoSparkTaskExecutorFactory, method doCreate.

public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> doCreate(int partitionId, int attemptNumber, SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor, Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources, PrestoSparkTaskInputs inputs, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector, Class<T> outputType) {
    PrestoSparkTaskDescriptor taskDescriptor = taskDescriptorJsonCodec.fromJson(serializedTaskDescriptor.getBytes());
    ImmutableMap.Builder<String, TokenAuthenticator> extraAuthenticators = ImmutableMap.builder();
    authenticatorProviders.forEach(provider -> extraAuthenticators.putAll(provider.getTokenAuthenticators()));
    Session session = taskDescriptor.getSession().toSession(sessionPropertyManager, taskDescriptor.getExtraCredentials(), extraAuthenticators.build());
    PlanFragment fragment = taskDescriptor.getFragment();
    StageId stageId = new StageId(session.getQueryId(), fragment.getId().getId());
    // Clear the cache if it does not hold the broadcast table for the current stageId.
    // Only one hash table (HT) is cached at any time; if the stageId changes, the old cached HT is dropped.
    prestoSparkBroadcastTableCacheManager.removeCachedTablesForStagesOtherThan(stageId);
    // TODO: include attemptId in taskId
    TaskId taskId = new TaskId(new StageExecutionId(stageId, 0), partitionId);
    List<TaskSource> taskSources = getTaskSources(serializedTaskSources);
    log.info("Task [%s] received %d splits.", taskId, taskSources.stream().mapToInt(taskSource -> taskSource.getSplits().size()).sum());
    OptionalLong totalSplitSize = computeAllSplitsSize(taskSources);
    if (totalSplitSize.isPresent()) {
        log.info("Total split size: %s bytes.", totalSplitSize.getAsLong());
    }
    // TODO: Remove this once we can display the plan on Spark UI.
    log.info(PlanPrinter.textPlanFragment(fragment, functionAndTypeManager, session, true));
    DataSize maxUserMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryMemoryPerNode().toBytes(), getQueryMaxMemoryPerNode(session).toBytes()), BYTE);
    DataSize maxTotalMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryTotalMemoryPerNode().toBytes(), getQueryMaxTotalMemoryPerNode(session).toBytes()), BYTE);
    DataSize maxBroadcastMemory = getSparkBroadcastJoinMaxMemoryOverride(session);
    if (maxBroadcastMemory == null) {
        maxBroadcastMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryBroadcastMemory().toBytes(), getQueryMaxBroadcastMemory(session).toBytes()), BYTE);
    }
    MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("spark-executor-memory-pool"), maxTotalMemory);
    SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(maxQuerySpillPerNode);
    QueryContext queryContext = new QueryContext(session.getQueryId(), maxUserMemory, maxTotalMemory, maxBroadcastMemory, maxRevocableMemory, memoryPool, new TestingGcMonitor(), notificationExecutor, yieldExecutor, maxQuerySpillPerNode, spillSpaceTracker, memoryReservationSummaryJsonCodec);
    queryContext.setVerboseExceededMemoryLimitErrorsEnabled(isVerboseExceededMemoryLimitErrorsEnabled(session));
    queryContext.setHeapDumpOnExceededMemoryLimitEnabled(isHeapDumpOnExceededMemoryLimitEnabled(session));
    String heapDumpFilePath = Paths.get(getHeapDumpFileDirectory(session), format("%s_%s.hprof", session.getQueryId().getId(), stageId.getId())).toString();
    queryContext.setHeapDumpFilePath(heapDumpFilePath);
    TaskStateMachine taskStateMachine = new TaskStateMachine(taskId, notificationExecutor);
    // The plan has to be retained only if verbose memory exceeded errors are requested
    TaskContext taskContext = queryContext.addTaskContext(
            taskStateMachine,
            session,
            isVerboseExceededMemoryLimitErrorsEnabled(session) ? Optional.of(fragment.getRoot()) : Optional.empty(),
            perOperatorCpuTimerEnabled,
            cpuTimerEnabled,
            perOperatorAllocationTrackingEnabled,
            allocationTrackingEnabled,
            false);
    final double memoryRevokingThreshold = getMemoryRevokingThreshold(session);
    final double memoryRevokingTarget = getMemoryRevokingTarget(session);
    checkArgument(memoryRevokingTarget <= memoryRevokingThreshold, "memoryRevokingTarget should be less than or equal memoryRevokingThreshold, but got %s and %s respectively", memoryRevokingTarget, memoryRevokingThreshold);
    if (isSpillEnabled(session)) {
        memoryPool.addListener((pool, queryId, totalMemoryReservationBytes) -> {
            if (totalMemoryReservationBytes > queryContext.getPeakNodeTotalMemory()) {
                queryContext.setPeakNodeTotalMemory(totalMemoryReservationBytes);
            }
            if (totalMemoryReservationBytes > pool.getMaxBytes() * memoryRevokingThreshold && memoryRevokeRequestInProgress.compareAndSet(false, true)) {
                memoryRevocationExecutor.execute(() -> {
                    try {
                        AtomicLong remainingBytesToRevoke = new AtomicLong(totalMemoryReservationBytes - (long) (memoryRevokingTarget * pool.getMaxBytes()));
                        remainingBytesToRevoke.addAndGet(-MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(ImmutableList.of(taskContext), remainingBytesToRevoke.get()));
                        taskContext.accept(new VoidTraversingQueryContextVisitor<AtomicLong>() {

                            @Override
                            public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remainingBytesToRevoke) {
                                if (remainingBytesToRevoke.get() > 0) {
                                    long revokedBytes = operatorContext.requestMemoryRevoking();
                                    if (revokedBytes > 0) {
                                        memoryRevokePending.set(true);
                                        remainingBytesToRevoke.addAndGet(-revokedBytes);
                                    }
                                }
                                return null;
                            }
                        }, remainingBytesToRevoke);
                        memoryRevokeRequestInProgress.set(false);
                    } catch (Exception e) {
                        log.error(e, "Error requesting memory revoking");
                    }
                });
            }
            // Get the latest memory reservation info since it might have changed due to revoke
            long totalReservedMemory = pool.getQueryMemoryReservation(queryId) + pool.getQueryRevocableMemoryReservation(queryId);
            // If total memory usage is over maxTotalMemory and memory revoke request is not pending, fail the query with EXCEEDED_MEMORY_LIMIT error
            if (totalReservedMemory > maxTotalMemory.toBytes() && !memoryRevokeRequestInProgress.get() && !isMemoryRevokePending(taskContext)) {
                throw exceededLocalTotalMemoryLimit(maxTotalMemory, queryContext.getAdditionalFailureInfo(totalReservedMemory, 0) + format("Total reserved memory: %s, Total revocable memory: %s", succinctBytes(pool.getQueryMemoryReservation(queryId)), succinctBytes(pool.getQueryRevocableMemoryReservation(queryId))), isHeapDumpOnExceededMemoryLimitEnabled(session), Optional.ofNullable(heapDumpFilePath));
            }
        });
    }
    ImmutableMap.Builder<PlanNodeId, List<PrestoSparkShuffleInput>> shuffleInputs = ImmutableMap.builder();
    ImmutableMap.Builder<PlanNodeId, List<java.util.Iterator<PrestoSparkSerializedPage>>> pageInputs = ImmutableMap.builder();
    ImmutableMap.Builder<PlanNodeId, List<?>> broadcastInputs = ImmutableMap.builder();
    for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) {
        List<PrestoSparkShuffleInput> remoteSourceRowInputs = new ArrayList<>();
        List<java.util.Iterator<PrestoSparkSerializedPage>> remoteSourcePageInputs = new ArrayList<>();
        List<List<?>> broadcastInputsList = new ArrayList<>();
        for (PlanFragmentId sourceFragmentId : remoteSource.getSourceFragmentIds()) {
            Iterator<Tuple2<MutablePartitionId, PrestoSparkMutableRow>> shuffleInput = inputs.getShuffleInputs().get(sourceFragmentId.toString());
            Broadcast<?> broadcastInput = inputs.getBroadcastInputs().get(sourceFragmentId.toString());
            List<PrestoSparkSerializedPage> inMemoryInput = inputs.getInMemoryInputs().get(sourceFragmentId.toString());
            if (shuffleInput != null) {
                checkArgument(broadcastInput == null, "single remote source is not expected to accept different kind of inputs");
                checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
                remoteSourceRowInputs.add(new PrestoSparkShuffleInput(sourceFragmentId.getId(), shuffleInput));
                continue;
            }
            if (broadcastInput != null) {
                checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
                // TODO: Enable NullifyingIterator once migrated to one task per JVM model
                // NullifyingIterator removes element from the list upon return
                // This allows GC to gradually reclaim memory
                // remoteSourcePageInputs.add(getNullifyingIterator(broadcastInput.value()));
                broadcastInputsList.add((List<?>) broadcastInput.value());
                continue;
            }
            if (inMemoryInput != null) {
                // for in-memory inputs, pages can be released incrementally to save memory
                remoteSourcePageInputs.add(getNullifyingIterator(inMemoryInput));
                continue;
            }
            throw new IllegalArgumentException("Input not found for sourceFragmentId: " + sourceFragmentId);
        }
        if (!remoteSourceRowInputs.isEmpty()) {
            shuffleInputs.put(remoteSource.getId(), remoteSourceRowInputs);
        }
        if (!remoteSourcePageInputs.isEmpty()) {
            pageInputs.put(remoteSource.getId(), remoteSourcePageInputs);
        }
        if (!broadcastInputsList.isEmpty()) {
            broadcastInputs.put(remoteSource.getId(), broadcastInputsList);
        }
    }
    OutputBufferMemoryManager memoryManager = new OutputBufferMemoryManager(sinkMaxBufferSize.toBytes(), () -> queryContext.getTaskContextByTaskId(taskId).localSystemMemoryContext(), notificationExecutor);
    Optional<OutputPartitioning> preDeterminedPartition = Optional.empty();
    if (fragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION)) {
        int partitionCount = getHashPartitionCount(session);
        preDeterminedPartition = Optional.of(new OutputPartitioning(new PreDeterminedPartitionFunction(partitionId % partitionCount, partitionCount), ImmutableList.of(), ImmutableList.of(), false, OptionalInt.empty()));
    }
    TempDataOperationContext tempDataOperationContext = new TempDataOperationContext(session.getSource(), session.getQueryId().getId(), session.getClientInfo(), Optional.of(session.getClientTags()), session.getIdentity());
    TempStorage tempStorage = tempStorageManager.getTempStorage(storageBasedBroadcastJoinStorage);
    Output<T> output = configureOutput(outputType, blockEncodingManager, memoryManager, getShuffleOutputTargetAverageRowSize(session), preDeterminedPartition, tempStorage, tempDataOperationContext, getStorageBasedBroadcastJoinWriteBufferSize(session));
    PrestoSparkOutputBuffer<?> outputBuffer = output.getOutputBuffer();
    LocalExecutionPlan localExecutionPlan = localExecutionPlanner.plan(taskContext, fragment.getRoot(), fragment.getPartitioningScheme(), fragment.getStageExecutionDescriptor(), fragment.getTableScanSchedulingOrder(), output.getOutputFactory(), new PrestoSparkRemoteSourceFactory(blockEncodingManager, shuffleInputs.build(), pageInputs.build(), broadcastInputs.build(), partitionId, shuffleStatsCollector, tempStorage, tempDataOperationContext, prestoSparkBroadcastTableCacheManager, stageId), taskDescriptor.getTableWriteInfo(), true);
    taskStateMachine.addStateChangeListener(state -> {
        if (state.isDone()) {
            outputBuffer.setNoMoreRows();
        }
    });
    PrestoSparkTaskExecution taskExecution = new PrestoSparkTaskExecution(taskStateMachine, taskContext, localExecutionPlan, taskExecutor, splitMonitor, notificationExecutor, memoryUpdateExecutor);
    taskExecution.start(taskSources);
    return new PrestoSparkTaskExecutor<>(taskContext, taskStateMachine, output.getOutputSupplier(), taskInfoCodec, taskInfoCollector, shuffleStatsCollector, executionExceptionFactory, output.getOutputBufferType(), outputBuffer, tempStorage, tempDataOperationContext);
}
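The routing above keys each remote source's inputs by its PlanNodeId and enforces that a single source fragment supplies exactly one kind of input. Below is a minimal, hedged sketch of that three-way dispatch; the Kind enum, class, and method names are hypothetical stand-ins, not Presto classes.

import static com.google.common.base.Preconditions.checkArgument;

public class RemoteSourceInputDispatch {

    // Hypothetical stand-in for the three Presto-on-Spark input kinds consumed in doCreate.
    enum Kind { SHUFFLE, BROADCAST, IN_MEMORY }

    // Mirrors the dispatch in doCreate: at most one of the three inputs
    // may be non-null for a given source fragment.
    static Kind dispatch(Object shuffleInput, Object broadcastInput, Object inMemoryInput) {
        if (shuffleInput != null) {
            checkArgument(broadcastInput == null && inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
            return Kind.SHUFFLE;
        }
        if (broadcastInput != null) {
            checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
            return Kind.BROADCAST;
        }
        if (inMemoryInput != null) {
            return Kind.IN_MEMORY;
        }
        throw new IllegalArgumentException("Input not found for source fragment");
    }

    public static void main(String[] args) {
        System.out.println(dispatch(new Object(), null, null)); // SHUFFLE
        System.out.println(dispatch(null, null, new Object())); // IN_MEMORY
    }
}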

Example 2 with PlanNodeId

Use of com.facebook.presto.spi.plan.PlanNodeId in project presto by prestodb.

From class PrestoSparkRddFactory, method createRdd.

private <T extends PrestoSparkTaskOutput> JavaPairRDD<MutablePartitionId, T> createRdd(JavaSparkContext sparkContext, Session session, PlanFragment fragment, PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector, TableWriteInfo tableWriteInfo, Map<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInputs, Map<PlanFragmentId, Broadcast<?>> broadcastInputs, Class<T> outputType) {
    checkInputs(fragment.getRemoteSourceNodes(), rddInputs, broadcastInputs);
    PrestoSparkTaskDescriptor taskDescriptor = new PrestoSparkTaskDescriptor(session.toSessionRepresentation(), session.getIdentity().getExtraCredentials(), fragment, tableWriteInfo);
    SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor = new SerializedPrestoSparkTaskDescriptor(taskDescriptorJsonCodec.toJsonBytes(taskDescriptor));
    Optional<Integer> numberOfShufflePartitions = Optional.empty();
    Map<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRddMap = new HashMap<>();
    for (Map.Entry<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> input : rddInputs.entrySet()) {
        RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>> rdd = input.getValue().rdd();
        shuffleInputRddMap.put(input.getKey().toString(), rdd);
        if (!numberOfShufflePartitions.isPresent()) {
            numberOfShufflePartitions = Optional.of(rdd.getNumPartitions());
        } else {
            checkArgument(numberOfShufflePartitions.get() == rdd.getNumPartitions(), "Incompatible number of input partitions: %s != %s", numberOfShufflePartitions.get(), rdd.getNumPartitions());
        }
    }
    PrestoSparkTaskProcessor<T> taskProcessor = new PrestoSparkTaskProcessor<>(executorFactoryProvider, serializedTaskDescriptor, taskInfoCollector, shuffleStatsCollector, toTaskProcessorBroadcastInputs(broadcastInputs), outputType);
    Optional<PrestoSparkTaskSourceRdd> taskSourceRdd;
    List<TableScanNode> tableScans = findTableScanNodes(fragment.getRoot());
    if (!tableScans.isEmpty()) {
        try (CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager::getSplits)) {
            SplitSourceFactory splitSourceFactory = new SplitSourceFactory(splitSourceProvider, WarningCollector.NOOP);
            Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(fragment, session, tableWriteInfo);
            taskSourceRdd = Optional.of(createTaskSourcesRdd(fragment.getId(), sparkContext, session, fragment.getPartitioning(), tableScans, splitSources, numberOfShufflePartitions));
        }
    } else if (rddInputs.size() == 0) {
        checkArgument(fragment.getPartitioning().equals(SINGLE_DISTRIBUTION), "SINGLE_DISTRIBUTION partitioning is expected: %s", fragment.getPartitioning());
        // In case of no inputs we still need to schedule a task.
        // A task with no inputs may still produce results (e.g., a ValuesNode).
        // To force the task to be scheduled, we create a PrestoSparkTaskSourceRdd that contains exactly one partition.
        // Since there are also no table scans in the fragment, the list of TaskSources for this partition is empty.
        taskSourceRdd = Optional.of(new PrestoSparkTaskSourceRdd(sparkContext.sc(), ImmutableList.of(ImmutableList.of())));
    } else {
        taskSourceRdd = Optional.empty();
    }
    return JavaPairRDD.fromRDD(PrestoSparkTaskRdd.create(sparkContext.sc(), taskSourceRdd, shuffleInputRddMap, taskProcessor), classTag(MutablePartitionId.class), classTag(outputType));
}
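A key invariant in createRdd is that every shuffle-input RDD exposes the same partition count; the loop above fails fast on a mismatch. Here is a self-contained sketch of that reconciliation, assuming Guava on the classpath (the class and method names are ours, not Presto's):

import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;

public class PartitionCountReconciliation {

    // Returns the common partition count, failing on a mismatch,
    // exactly as the loop over rddInputs does in createRdd.
    static Optional<Integer> reconcile(List<Integer> partitionCounts) {
        Optional<Integer> result = Optional.empty();
        for (int count : partitionCounts) {
            if (!result.isPresent()) {
                result = Optional.of(count);
            }
            else {
                checkArgument(result.get() == count, "Incompatible number of input partitions: %s != %s", result.get(), count);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(reconcile(ImmutableList.of(8, 8, 8))); // Optional[8]
        // reconcile(ImmutableList.of(8, 4)) would throw IllegalArgumentException
    }
}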

Example 3 with PlanNodeId

Use of com.facebook.presto.spi.plan.PlanNodeId in project presto by prestodb.

From class PrestoSparkRddFactory, method createTaskSourcesRdd.

private PrestoSparkTaskSourceRdd createTaskSourcesRdd(PlanFragmentId fragmentId, JavaSparkContext sparkContext, Session session, PartitioningHandle partitioning, List<TableScanNode> tableScans, Map<PlanNodeId, SplitSource> splitSources, Optional<Integer> numberOfShufflePartitions) {
    ListMultimap<Integer, SerializedPrestoSparkTaskSource> taskSourcesMap = ArrayListMultimap.create();
    for (TableScanNode tableScan : tableScans) {
        int totalNumberOfSplits = 0;
        SplitSource splitSource = requireNonNull(splitSources.get(tableScan.getId()), "split source is missing for table scan node with id: " + tableScan.getId());
        try (PrestoSparkSplitAssigner splitAssigner = createSplitAssigner(session, tableScan.getId(), splitSource, partitioning)) {
            while (true) {
                Optional<SetMultimap<Integer, ScheduledSplit>> batch = splitAssigner.getNextBatch();
                if (!batch.isPresent()) {
                    break;
                }
                int numberOfSplitsInCurrentBatch = batch.get().size();
                log.info("Found %s splits for table scan node with id %s", numberOfSplitsInCurrentBatch, tableScan.getId());
                totalNumberOfSplits += numberOfSplitsInCurrentBatch;
                taskSourcesMap.putAll(createTaskSources(tableScan.getId(), batch.get()));
            }
        }
        log.info("Total number of splits for table scan node with id %s: %s", tableScan.getId(), totalNumberOfSplits);
    }
    long allTaskSourcesSerializedSizeInBytes = taskSourcesMap.values().stream().mapToLong(serializedTaskSource -> serializedTaskSource.getBytes().length).sum();
    log.info("Total serialized size of all task sources for fragment %s: %s", fragmentId, DataSize.succinctBytes(allTaskSourcesSerializedSizeInBytes));
    List<List<SerializedPrestoSparkTaskSource>> taskSourcesByPartitionId = new ArrayList<>();
    // If the fragment contains any shuffle inputs, this value will be present
    if (numberOfShufflePartitions.isPresent()) {
        // To make sure the number of partitions for bucketed and non-bucketed tables match, an empty partition must be inserted if a bucket is missing.
        for (int partitionId = 0; partitionId < numberOfShufflePartitions.get(); partitionId++) {
            // Eagerly remove task sources from the map to let GC reclaim the memory
            // If task sources are missing for a partition, removeAll returns an empty list
            taskSourcesByPartitionId.add(requireNonNull(taskSourcesMap.removeAll(partitionId), "taskSources is null"));
        }
    } else {
        taskSourcesByPartitionId.addAll(Multimaps.asMap(taskSourcesMap).values());
    }
    return new PrestoSparkTaskSourceRdd(sparkContext.sc(), taskSourcesByPartitionId);
}
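When shuffle inputs are present, the partitionId-to-task-sources multimap is flattened into a dense list with one (possibly empty) entry per shuffle partition. A standalone sketch of that densification (class and method names are ours):

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;

import java.util.ArrayList;
import java.util.List;

public class TaskSourceDensification {

    // removeAll returns an empty (never null) list for absent keys, so every
    // partition index below partitionCount gets an entry; removing entries also
    // lets GC reclaim the backing lists eagerly, as in createTaskSourcesRdd.
    static <T> List<List<T>> densify(ListMultimap<Integer, T> byPartition, int partitionCount) {
        List<List<T>> result = new ArrayList<>(partitionCount);
        for (int partitionId = 0; partitionId < partitionCount; partitionId++) {
            result.add(new ArrayList<>(byPartition.removeAll(partitionId)));
        }
        return result;
    }

    public static void main(String[] args) {
        ListMultimap<Integer, String> sources = ArrayListMultimap.create();
        sources.put(0, "source-a");
        sources.put(2, "source-b");
        System.out.println(densify(sources, 3)); // [[source-a], [], [source-b]]
    }
}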

Example 4 with PlanNodeId

Use of com.facebook.presto.spi.plan.PlanNodeId in project presto by prestodb.

From class TestPrestoSparkSourceDistributionSplitAssigner, method assertSplitAssignment.

private static void assertSplitAssignment(boolean autoTuneEnabled, DataSize maxSplitsDataSizePerSparkPartition, int initialPartitionCount, int minSparkInputPartitionCountForAutoTune, int maxSparkInputPartitionCountForAutoTune, List<Long> splitSizes, Map<Integer, List<Long>> expectedAssignment) {
    // assign splits in one shot
    {
        PrestoSparkSplitAssigner assigner = new PrestoSparkSourceDistributionSplitAssigner(new PlanNodeId("test"), createSplitSource(splitSizes), Integer.MAX_VALUE, maxSplitsDataSizePerSparkPartition.toBytes(), initialPartitionCount, autoTuneEnabled, minSparkInputPartitionCountForAutoTune, maxSparkInputPartitionCountForAutoTune);
        Optional<SetMultimap<Integer, ScheduledSplit>> actualAssignment = assigner.getNextBatch();
        if (!splitSizes.isEmpty()) {
            assertThat(actualAssignment).isPresent();
            assertAssignedSplits(actualAssignment.get(), expectedAssignment);
        } else {
            assertThat(actualAssignment).isNotPresent();
        }
    }
    // assign splits iteratively
    for (int splitBatchSize = 1; splitBatchSize < splitSizes.size(); splitBatchSize *= 2) {
        HashMultimap<Integer, ScheduledSplit> actualAssignment = HashMultimap.create();
        // sort splits to make assignment match the assignment done in one shot
        List<Long> sortedSplits = new ArrayList<>(splitSizes);
        sortedSplits.sort(Comparator.<Long>naturalOrder().reversed());
        PrestoSparkSplitAssigner assigner = new PrestoSparkSourceDistributionSplitAssigner(new PlanNodeId("test"), createSplitSource(sortedSplits), splitBatchSize, maxSplitsDataSizePerSparkPartition.toBytes(), initialPartitionCount, autoTuneEnabled, minSparkInputPartitionCountForAutoTune, maxSparkInputPartitionCountForAutoTune);
        while (true) {
            Optional<SetMultimap<Integer, ScheduledSplit>> assignment = assigner.getNextBatch();
            if (!assignment.isPresent()) {
                break;
            }
            actualAssignment.putAll(assignment.get());
        }
        assertAssignedSplits(actualAssignment, expectedAssignment);
    }
}
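For reference, a hypothetical invocation of this helper, assuming the test class's existing imports (DataSize, KILOBYTE, ImmutableList, ImmutableMap). Every value below is an illustrative assumption rather than a case from the actual test class, and the expected assignment simply assumes the 2 KB size cap packs two 1 KB splits into each of the two initial partitions.

// Illustrative only: split sizes and the expected partition-to-splits mapping are assumed.
assertSplitAssignment(
        false,                          // autoTuneEnabled
        new DataSize(2, KILOBYTE),      // maxSplitsDataSizePerSparkPartition
        2,                              // initialPartitionCount
        1,                              // minSparkInputPartitionCountForAutoTune
        4,                              // maxSparkInputPartitionCountForAutoTune
        ImmutableList.of(1024L, 1024L, 1024L, 1024L),
        ImmutableMap.of(
                0, ImmutableList.of(1024L, 1024L),
                1, ImmutableList.of(1024L, 1024L)));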

Example 5 with PlanNodeId

Use of com.facebook.presto.spi.plan.PlanNodeId in project presto by prestodb.

From class TestHashJoinOperator, method testFullOuterJoinWithEmptyLookupSource.

@Test(dataProvider = "hashJoinTestValues")
public void testFullOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) {
    TaskContext taskContext = createTaskContext();
    // build factory
    List<Type> buildTypes = ImmutableList.of(VARCHAR);
    RowPagesBuilder buildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes);
    BuildSideSetup buildSideSetup = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
    JoinBridgeManager<PartitionedLookupSourceFactory> lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager();
    // probe factory
    List<Type> probeTypes = ImmutableList.of(VARCHAR);
    RowPagesBuilder probePages = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes);
    List<Page> probeInput = probePages.row("a").row("b").row((String) null).row("c").build();
    OperatorFactory joinOperatorFactory = new LookupJoinOperators().fullOuterJoin(0, new PlanNodeId("test"), lookupSourceFactoryManager, probePages.getTypes(), Ints.asList(0), getHashChannelAsInt(probePages), Optional.empty(), OptionalInt.of(1), PARTITIONING_SPILLER_FACTORY);
    // build drivers and operators
    instantiateBuildDrivers(buildSideSetup, taskContext);
    buildLookupSource(buildSideSetup);
    // expected
    MaterializedResult expected = MaterializedResult.resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes)).row("a", null).row("b", null).row(null, null).row("c", null).build();
    assertOperatorEquals(joinOperatorFactory, taskContext.addPipelineContext(0, true, true, false).addDriverContext(), probeInput, expected, true, getHashChannels(probePages, buildPages));
}
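Across all five examples, PlanNodeId is a thin value type wrapping a string, which is why it serves both as a tag on operator factories and as an ImmutableMap key. A minimal sketch, assuming only the value-based equality its map-key usage above implies:

import com.facebook.presto.spi.plan.PlanNodeId;

public class PlanNodeIdDemo {

    public static void main(String[] args) {
        PlanNodeId left = new PlanNodeId("test");
        PlanNodeId right = new PlanNodeId("test");
        // Equality is value-based on the wrapped id string,
        // so two instances built from "test" compare equal.
        System.out.println(left.equals(right)); // true
        System.out.println(left); // toString() yields the raw id: "test"
    }
}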

Aggregations

PlanNodeId (com.facebook.presto.spi.plan.PlanNodeId): 204 uses
Test (org.testng.annotations.Test): 123 uses
Page (com.facebook.presto.common.Page): 83 uses
MaterializedResult (com.facebook.presto.testing.MaterializedResult): 52 uses
Type (com.facebook.presto.common.type.Type): 47 uses
VariableReferenceExpression (com.facebook.presto.spi.relation.VariableReferenceExpression): 43 uses
ImmutableList (com.google.common.collect.ImmutableList): 43 uses
RowPagesBuilder (com.facebook.presto.RowPagesBuilder): 39 uses
DataSize (io.airlift.units.DataSize): 39 uses
Optional (java.util.Optional): 35 uses
ImmutableMap (com.google.common.collect.ImmutableMap): 34 uses
JoinNode (com.facebook.presto.sql.planner.plan.JoinNode): 25 uses
BIGINT (com.facebook.presto.common.type.BigintType.BIGINT): 23 uses
VariableStatsEstimate (com.facebook.presto.cost.VariableStatsEstimate): 23 uses
Split (com.facebook.presto.metadata.Split): 23 uses
OperatorFactory (com.facebook.presto.operator.OperatorFactory): 23 uses
PlanNodeStatsEstimate (com.facebook.presto.cost.PlanNodeStatsEstimate): 22 uses
RowExpression (com.facebook.presto.spi.relation.RowExpression): 21 uses
PlanMatchPattern.values (com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values): 21 uses
JOIN_DISTRIBUTION_TYPE (com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE): 20 uses