Usage of com.facebook.presto.spi.plan.PlanNodeId in the prestodb/presto project.
Class PrestoSparkTaskExecutorFactory, method doCreate:
/**
 * Creates and starts the executor for a single Presto-on-Spark task.
 *
 * <p>Processing steps:
 * <ol>
 * <li>Deserializes the task descriptor into a {@link Session} and a {@link PlanFragment}.</li>
 * <li>Builds a dedicated {@link QueryContext} whose user, total and broadcast memory limits are the
 * stricter of the node configuration and the session properties.</li>
 * <li>When spilling is enabled, registers a memory pool listener that requests memory revoking once the
 * reservation crosses the revoking threshold.</li>
 * <li>Groups the shuffle, broadcast and in-memory inputs per {@link RemoteSourceNode}.</li>
 * <li>Plans the fragment locally and starts execution with the supplied task sources.</li>
 * </ol>
 *
 * @param partitionId Spark partition id, used as the partition component of the task id
 * @param attemptNumber Spark task attempt number (not used in this method; see the TODO on the task id)
 * @param serializedTaskDescriptor JSON-serialized {@link PrestoSparkTaskDescriptor}
 * @param serializedTaskSources serialized task sources (splits) assigned to this partition
 * @param inputs shuffle, broadcast and in-memory inputs keyed by source fragment id
 * @param taskInfoCollector accumulator handed to the returned executor
 * @param shuffleStatsCollector accumulator for shuffle statistics
 * @param outputType class of the output produced by the returned executor
 * @return an executor exposing the task output
 * @throws IllegalArgumentException if the memory revoking target exceeds the revoking threshold, if a
 *         single remote source receives more than one kind of input, or if no input exists for a
 *         source fragment
 */
public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> doCreate(int partitionId, int attemptNumber, SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor, Iterator<SerializedPrestoSparkTaskSource> serializedTaskSources, PrestoSparkTaskInputs inputs, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector, Class<T> outputType) {
    PrestoSparkTaskDescriptor taskDescriptor = taskDescriptorJsonCodec.fromJson(serializedTaskDescriptor.getBytes());
    ImmutableMap.Builder<String, TokenAuthenticator> extraAuthenticators = ImmutableMap.builder();
    authenticatorProviders.forEach(provider -> extraAuthenticators.putAll(provider.getTokenAuthenticators()));
    Session session = taskDescriptor.getSession().toSession(sessionPropertyManager, taskDescriptor.getExtraCredentials(), extraAuthenticators.build());
    PlanFragment fragment = taskDescriptor.getFragment();
    StageId stageId = new StageId(session.getQueryId(), fragment.getId().getId());
    // Clear the cache if the cache does not have broadcast table for current stageId.
    // We will only cache 1 HT at any time. If the stageId changes, we will drop the old cached HT
    prestoSparkBroadcastTableCacheManager.removeCachedTablesForStagesOtherThan(stageId);
    // TODO: include attemptId in taskId
    TaskId taskId = new TaskId(new StageExecutionId(stageId, 0), partitionId);
    List<TaskSource> taskSources = getTaskSources(serializedTaskSources);
    log.info("Task [%s] received %d splits.", taskId, taskSources.stream().mapToInt(taskSource -> taskSource.getSplits().size()).sum());
    OptionalLong totalSplitSize = computeAllSplitsSize(taskSources);
    if (totalSplitSize.isPresent()) {
        log.info("Total split size: %s bytes.", totalSplitSize.getAsLong());
    }
    // TODO: Remove this once we can display the plan on Spark UI.
    log.info(PlanPrinter.textPlanFragment(fragment, functionAndTypeManager, session, true));

    // Effective limits are the stricter (smaller) of the node configuration and the session properties
    DataSize maxUserMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryMemoryPerNode().toBytes(), getQueryMaxMemoryPerNode(session).toBytes()), BYTE);
    DataSize maxTotalMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryTotalMemoryPerNode().toBytes(), getQueryMaxTotalMemoryPerNode(session).toBytes()), BYTE);
    // An explicit Spark broadcast-join override, when present, takes precedence over the computed limit
    DataSize maxBroadcastMemory = getSparkBroadcastJoinMaxMemoryOverride(session);
    if (maxBroadcastMemory == null) {
        maxBroadcastMemory = new DataSize(min(nodeMemoryConfig.getMaxQueryBroadcastMemory().toBytes(), getQueryMaxBroadcastMemory(session).toBytes()), BYTE);
    }
    MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("spark-executor-memory-pool"), maxTotalMemory);
    SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(maxQuerySpillPerNode);
    QueryContext queryContext = new QueryContext(session.getQueryId(), maxUserMemory, maxTotalMemory, maxBroadcastMemory, maxRevocableMemory, memoryPool, new TestingGcMonitor(), notificationExecutor, yieldExecutor, maxQuerySpillPerNode, spillSpaceTracker, memoryReservationSummaryJsonCodec);
    queryContext.setVerboseExceededMemoryLimitErrorsEnabled(isVerboseExceededMemoryLimitErrorsEnabled(session));
    queryContext.setHeapDumpOnExceededMemoryLimitEnabled(isHeapDumpOnExceededMemoryLimitEnabled(session));
    String heapDumpFilePath = Paths.get(getHeapDumpFileDirectory(session), format("%s_%s.hprof", session.getQueryId().getId(), stageId.getId())).toString();
    queryContext.setHeapDumpFilePath(heapDumpFilePath);
    TaskStateMachine taskStateMachine = new TaskStateMachine(taskId, notificationExecutor);
    TaskContext taskContext = queryContext.addTaskContext(
            taskStateMachine,
            session,
            // Plan has to be retained only if verbose memory exceeded errors are requested
            isVerboseExceededMemoryLimitErrorsEnabled(session) ? Optional.of(fragment.getRoot()) : Optional.empty(),
            perOperatorCpuTimerEnabled,
            cpuTimerEnabled,
            perOperatorAllocationTrackingEnabled,
            allocationTrackingEnabled,
            false);
    final double memoryRevokingThreshold = getMemoryRevokingThreshold(session);
    final double memoryRevokingTarget = getMemoryRevokingTarget(session);
    checkArgument(memoryRevokingTarget <= memoryRevokingThreshold, "memoryRevokingTarget should be less than or equal memoryRevokingThreshold, but got %s and %s respectively", memoryRevokingTarget, memoryRevokingThreshold);
    if (isSpillEnabled(session)) {
        memoryPool.addListener((pool, queryId, totalMemoryReservationBytes) -> {
            if (totalMemoryReservationBytes > queryContext.getPeakNodeTotalMemory()) {
                queryContext.setPeakNodeTotalMemory(totalMemoryReservationBytes);
            }
            // compareAndSet ensures at most one revoke request is submitted at a time
            if (totalMemoryReservationBytes > pool.getMaxBytes() * memoryRevokingThreshold && memoryRevokeRequestInProgress.compareAndSet(false, true)) {
                memoryRevocationExecutor.execute(() -> {
                    try {
                        // Aim to bring the reservation down to memoryRevokingTarget * maxBytes,
                        // discounting memory that is already being revoked
                        AtomicLong remainingBytesToRevoke = new AtomicLong(totalMemoryReservationBytes - (long) (memoryRevokingTarget * pool.getMaxBytes()));
                        remainingBytesToRevoke.addAndGet(-MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(ImmutableList.of(taskContext), remainingBytesToRevoke.get()));
                        taskContext.accept(new VoidTraversingQueryContextVisitor<AtomicLong>() {
                            @Override
                            public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remainingBytesToRevoke) {
                                if (remainingBytesToRevoke.get() > 0) {
                                    long revokedBytes = operatorContext.requestMemoryRevoking();
                                    if (revokedBytes > 0) {
                                        memoryRevokePending.set(true);
                                        remainingBytesToRevoke.addAndGet(-revokedBytes);
                                    }
                                }
                                return null;
                            }
                        }, remainingBytesToRevoke);
                    }
                    catch (Exception e) {
                        log.error(e, "Error requesting memory revoking");
                    }
                    finally {
                        // Release the in-progress flag even on failure; otherwise no further revoke
                        // requests could be submitted and the memory-limit check below would never fire.
                        memoryRevokeRequestInProgress.set(false);
                    }
                });
            }
            // Get the latest memory reservation info since it might have changed due to revoke
            long totalReservedMemory = pool.getQueryMemoryReservation(queryId) + pool.getQueryRevocableMemoryReservation(queryId);
            // If total memory usage is over maxTotalMemory and memory revoke request is not pending, fail the query with EXCEEDED_MEMORY_LIMIT error
            if (totalReservedMemory > maxTotalMemory.toBytes() && !memoryRevokeRequestInProgress.get() && !isMemoryRevokePending(taskContext)) {
                throw exceededLocalTotalMemoryLimit(maxTotalMemory, queryContext.getAdditionalFailureInfo(totalReservedMemory, 0) + format("Total reserved memory: %s, Total revocable memory: %s", succinctBytes(pool.getQueryMemoryReservation(queryId)), succinctBytes(pool.getQueryRevocableMemoryReservation(queryId))), isHeapDumpOnExceededMemoryLimitEnabled(session), Optional.ofNullable(heapDumpFilePath));
            }
        });
    }

    // Group inputs per remote source node; each source fragment must supply exactly one kind of input
    ImmutableMap.Builder<PlanNodeId, List<PrestoSparkShuffleInput>> shuffleInputs = ImmutableMap.builder();
    ImmutableMap.Builder<PlanNodeId, List<java.util.Iterator<PrestoSparkSerializedPage>>> pageInputs = ImmutableMap.builder();
    ImmutableMap.Builder<PlanNodeId, List<?>> broadcastInputs = ImmutableMap.builder();
    for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) {
        List<PrestoSparkShuffleInput> remoteSourceRowInputs = new ArrayList<>();
        List<java.util.Iterator<PrestoSparkSerializedPage>> remoteSourcePageInputs = new ArrayList<>();
        List<List<?>> broadcastInputsList = new ArrayList<>();
        for (PlanFragmentId sourceFragmentId : remoteSource.getSourceFragmentIds()) {
            Iterator<Tuple2<MutablePartitionId, PrestoSparkMutableRow>> shuffleInput = inputs.getShuffleInputs().get(sourceFragmentId.toString());
            Broadcast<?> broadcastInput = inputs.getBroadcastInputs().get(sourceFragmentId.toString());
            List<PrestoSparkSerializedPage> inMemoryInput = inputs.getInMemoryInputs().get(sourceFragmentId.toString());
            if (shuffleInput != null) {
                checkArgument(broadcastInput == null, "single remote source is not expected to accept different kind of inputs");
                checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
                remoteSourceRowInputs.add(new PrestoSparkShuffleInput(sourceFragmentId.getId(), shuffleInput));
                continue;
            }
            if (broadcastInput != null) {
                checkArgument(inMemoryInput == null, "single remote source is not expected to accept different kind of inputs");
                // TODO: Enable NullifyingIterator once migrated to one task per JVM model
                // NullifyingIterator removes element from the list upon return
                // This allows GC to gradually reclaim memory
                // remoteSourcePageInputs.add(getNullifyingIterator(broadcastInput.value()));
                broadcastInputsList.add((List<?>) broadcastInput.value());
                continue;
            }
            if (inMemoryInput != null) {
                // for inmemory inputs pages can be released incrementally to save memory
                remoteSourcePageInputs.add(getNullifyingIterator(inMemoryInput));
                continue;
            }
            throw new IllegalArgumentException("Input not found for sourceFragmentId: " + sourceFragmentId);
        }
        if (!remoteSourceRowInputs.isEmpty()) {
            shuffleInputs.put(remoteSource.getId(), remoteSourceRowInputs);
        }
        if (!remoteSourcePageInputs.isEmpty()) {
            pageInputs.put(remoteSource.getId(), remoteSourcePageInputs);
        }
        if (!broadcastInputsList.isEmpty()) {
            broadcastInputs.put(remoteSource.getId(), broadcastInputsList);
        }
    }

    OutputBufferMemoryManager memoryManager = new OutputBufferMemoryManager(sinkMaxBufferSize.toBytes(), () -> queryContext.getTaskContextByTaskId(taskId).localSystemMemoryContext(), notificationExecutor);
    Optional<OutputPartitioning> preDeterminedPartition = Optional.empty();
    if (fragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION)) {
        int partitionCount = getHashPartitionCount(session);
        preDeterminedPartition = Optional.of(new OutputPartitioning(new PreDeterminedPartitionFunction(partitionId % partitionCount, partitionCount), ImmutableList.of(), ImmutableList.of(), false, OptionalInt.empty()));
    }
    TempDataOperationContext tempDataOperationContext = new TempDataOperationContext(session.getSource(), session.getQueryId().getId(), session.getClientInfo(), Optional.of(session.getClientTags()), session.getIdentity());
    TempStorage tempStorage = tempStorageManager.getTempStorage(storageBasedBroadcastJoinStorage);
    Output<T> output = configureOutput(outputType, blockEncodingManager, memoryManager, getShuffleOutputTargetAverageRowSize(session), preDeterminedPartition, tempStorage, tempDataOperationContext, getStorageBasedBroadcastJoinWriteBufferSize(session));
    PrestoSparkOutputBuffer<?> outputBuffer = output.getOutputBuffer();
    LocalExecutionPlan localExecutionPlan = localExecutionPlanner.plan(taskContext, fragment.getRoot(), fragment.getPartitioningScheme(), fragment.getStageExecutionDescriptor(), fragment.getTableScanSchedulingOrder(), output.getOutputFactory(), new PrestoSparkRemoteSourceFactory(blockEncodingManager, shuffleInputs.build(), pageInputs.build(), broadcastInputs.build(), partitionId, shuffleStatsCollector, tempStorage, tempDataOperationContext, prestoSparkBroadcastTableCacheManager, stageId), taskDescriptor.getTableWriteInfo(), true);
    // Signal the output buffer that no more rows will arrive once the task reaches a terminal state
    taskStateMachine.addStateChangeListener(state -> {
        if (state.isDone()) {
            outputBuffer.setNoMoreRows();
        }
    });
    PrestoSparkTaskExecution taskExecution = new PrestoSparkTaskExecution(taskStateMachine, taskContext, localExecutionPlan, taskExecutor, splitMonitor, notificationExecutor, memoryUpdateExecutor);
    taskExecution.start(taskSources);
    return new PrestoSparkTaskExecutor<>(taskContext, taskStateMachine, output.getOutputSupplier(), taskInfoCodec, taskInfoCollector, shuffleStatsCollector, executionExceptionFactory, output.getOutputBufferType(), outputBuffer, tempStorage, tempDataOperationContext);
}
Usage of com.facebook.presto.spi.plan.PlanNodeId in the prestodb/presto project.
Class PrestoSparkRddFactory, method createRdd:
/**
 * Builds the Spark RDD that executes every task of one plan fragment.
 *
 * <p>All shuffle inputs must share the same number of partitions. The task source RDD is chosen as follows:
 * <ul>
 * <li>If the fragment contains table scans, the task source RDD is built from their splits.</li>
 * <li>If the fragment has neither table scans nor shuffle inputs (it must then be
 * SINGLE_DISTRIBUTION), a one-partition RDD with no task sources is used so that a task still gets
 * scheduled.</li>
 * <li>Otherwise no task source RDD is used.</li>
 * </ul>
 *
 * @throws IllegalArgumentException if shuffle inputs disagree on partition count, or if a fragment
 *         without inputs is not SINGLE_DISTRIBUTION
 */
private <T extends PrestoSparkTaskOutput> JavaPairRDD<MutablePartitionId, T> createRdd(JavaSparkContext sparkContext, Session session, PlanFragment fragment, PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider, CollectionAccumulator<SerializedTaskInfo> taskInfoCollector, CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector, TableWriteInfo tableWriteInfo, Map<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInputs, Map<PlanFragmentId, Broadcast<?>> broadcastInputs, Class<T> outputType) {
    checkInputs(fragment.getRemoteSourceNodes(), rddInputs, broadcastInputs);
    SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor = new SerializedPrestoSparkTaskDescriptor(
            taskDescriptorJsonCodec.toJsonBytes(
                    new PrestoSparkTaskDescriptor(session.toSessionRepresentation(), session.getIdentity().getExtraCredentials(), fragment, tableWriteInfo)));

    // Collect the underlying shuffle RDDs keyed by fragment id and verify they agree on partition count
    Optional<Integer> shufflePartitionCount = Optional.empty();
    Map<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRddMap = new HashMap<>();
    for (Map.Entry<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInput : rddInputs.entrySet()) {
        RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>> inputRdd = rddInput.getValue().rdd();
        shuffleInputRddMap.put(rddInput.getKey().toString(), inputRdd);
        if (shufflePartitionCount.isPresent()) {
            checkArgument(shufflePartitionCount.get() == inputRdd.getNumPartitions(), "Incompatible number of input partitions: %s != %s", shufflePartitionCount.get(), inputRdd.getNumPartitions());
        }
        else {
            shufflePartitionCount = Optional.of(inputRdd.getNumPartitions());
        }
    }

    PrestoSparkTaskProcessor<T> taskProcessor = new PrestoSparkTaskProcessor<>(executorFactoryProvider, serializedTaskDescriptor, taskInfoCollector, shuffleStatsCollector, toTaskProcessorBroadcastInputs(broadcastInputs), outputType);

    List<TableScanNode> tableScans = findTableScanNodes(fragment.getRoot());
    Optional<PrestoSparkTaskSourceRdd> taskSourceRdd;
    if (tableScans.isEmpty() && rddInputs.isEmpty()) {
        checkArgument(fragment.getPartitioning().equals(SINGLE_DISTRIBUTION), "SINGLE_DISTRIBUTION partitioning is expected: %s", fragment.getPartitioning());
        // A fragment with no inputs (e.g. a ValuesNode) may still produce rows, so a task must be scheduled.
        // A single partition with an empty list of task sources forces exactly one task.
        taskSourceRdd = Optional.of(new PrestoSparkTaskSourceRdd(sparkContext.sc(), ImmutableList.of(ImmutableList.of())));
    }
    else if (tableScans.isEmpty()) {
        taskSourceRdd = Optional.empty();
    }
    else {
        try (CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager::getSplits)) {
            SplitSourceFactory splitSourceFactory = new SplitSourceFactory(splitSourceProvider, WarningCollector.NOOP);
            Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(fragment, session, tableWriteInfo);
            taskSourceRdd = Optional.of(createTaskSourcesRdd(fragment.getId(), sparkContext, session, fragment.getPartitioning(), tableScans, splitSources, shufflePartitionCount));
        }
    }

    return JavaPairRDD.fromRDD(PrestoSparkTaskRdd.create(sparkContext.sc(), taskSourceRdd, shuffleInputRddMap, taskProcessor), classTag(MutablePartitionId.class), classTag(outputType));
}
Usage of com.facebook.presto.spi.plan.PlanNodeId in the prestodb/presto project.
Class PrestoSparkRddFactory, method createTaskSourcesRdd:
/**
 * Assigns the splits of every table scan in a fragment to Spark partitions and wraps the
 * serialized task sources in a {@link PrestoSparkTaskSourceRdd}.
 *
 * <p>When the fragment also has shuffle inputs ({@code numberOfShufflePartitions} present), the
 * result has exactly one entry per shuffle partition, in partition order, and partitions without
 * splits get an empty list. Otherwise one entry is produced per partition id that received splits.
 *
 * @throws NullPointerException if a table scan has no split source
 * @throws IllegalArgumentException if splits were assigned to partition ids outside
 *         {@code [0, numberOfShufflePartitions)}; previously such splits were silently dropped
 */
private PrestoSparkTaskSourceRdd createTaskSourcesRdd(PlanFragmentId fragmentId, JavaSparkContext sparkContext, Session session, PartitioningHandle partitioning, List<TableScanNode> tableScans, Map<PlanNodeId, SplitSource> splitSources, Optional<Integer> numberOfShufflePartitions) {
    ListMultimap<Integer, SerializedPrestoSparkTaskSource> taskSourcesMap = ArrayListMultimap.create();
    for (TableScanNode tableScan : tableScans) {
        int totalNumberOfSplits = 0;
        SplitSource splitSource = requireNonNull(splitSources.get(tableScan.getId()), "split source is missing for table scan node with id: " + tableScan.getId());
        try (PrestoSparkSplitAssigner splitAssigner = createSplitAssigner(session, tableScan.getId(), splitSource, partitioning)) {
            while (true) {
                Optional<SetMultimap<Integer, ScheduledSplit>> batch = splitAssigner.getNextBatch();
                if (!batch.isPresent()) {
                    break;
                }
                int numberOfSplitsInCurrentBatch = batch.get().size();
                log.info("Found %s splits for table scan node with id %s", numberOfSplitsInCurrentBatch, tableScan.getId());
                totalNumberOfSplits += numberOfSplitsInCurrentBatch;
                taskSourcesMap.putAll(createTaskSources(tableScan.getId(), batch.get()));
            }
        }
        log.info("Total number of splits for table scan node with id %s: %s", tableScan.getId(), totalNumberOfSplits);
    }
    long allTaskSourcesSerializedSizeInBytes = taskSourcesMap.values().stream().mapToLong(serializedTaskSource -> serializedTaskSource.getBytes().length).sum();
    log.info("Total serialized size of all task sources for fragment %s: %s", fragmentId, DataSize.succinctBytes(allTaskSourcesSerializedSizeInBytes));
    List<List<SerializedPrestoSparkTaskSource>> taskSourcesByPartitionId = new ArrayList<>();
    // If the fragment contains any shuffle inputs, this value will be present
    if (numberOfShufflePartitions.isPresent()) {
        // Table scan partitions must line up one-to-one with the shuffle partitions (e.g. bucketed tables),
        // so an empty partition must be inserted when a bucket has no splits.
        int partitionCount = numberOfShufflePartitions.get();
        for (int partitionId = 0; partitionId < partitionCount; partitionId++) {
            // Eagerly remove task sources from the map to let GC reclaim the memory.
            // Multimap.removeAll returns an empty list (never null) when a partition has no task sources.
            taskSourcesByPartitionId.add(taskSourcesMap.removeAll(partitionId));
        }
        // Anything left belongs to a partition id with no matching shuffle partition; dropping it
        // would silently lose splits, so fail loudly instead.
        checkArgument(taskSourcesMap.isEmpty(), "Task sources assigned to partitions %s exceed the number of shuffle partitions: %s", taskSourcesMap.keySet(), partitionCount);
    }
    else {
        taskSourcesByPartitionId.addAll(Multimaps.asMap(taskSourcesMap).values());
    }
    return new PrestoSparkTaskSourceRdd(sparkContext.sc(), taskSourcesByPartitionId);
}
Usage of com.facebook.presto.spi.plan.PlanNodeId in the prestodb/presto project.
Class TestPrestoSparkSourceDistributionSplitAssigner, method assertSplitAssignment:
/**
 * Asserts that {@link PrestoSparkSourceDistributionSplitAssigner} yields {@code expectedAssignment}.
 *
 * <p>The check runs twice:
 * <ul>
 * <li>once with all splits assigned in a single batch;</li>
 * <li>once per batch size 1, 2, 4, ... below the number of splits, with the splits fed in
 * descending size order.</li>
 * </ul>
 */
private static void assertSplitAssignment(boolean autoTuneEnabled, DataSize maxSplitsDataSizePerSparkPartition, int initialPartitionCount, int minSparkInputPartitionCountForAutoTune, int maxSparkInputPartitionCountForAutoTune, List<Long> splitSizes, Map<Integer, List<Long>> expectedAssignment) {
    // One shot: a batch size of Integer.MAX_VALUE hands out every split in the first batch
    PrestoSparkSplitAssigner oneShotAssigner = new PrestoSparkSourceDistributionSplitAssigner(
            new PlanNodeId("test"),
            createSplitSource(splitSizes),
            Integer.MAX_VALUE,
            maxSplitsDataSizePerSparkPartition.toBytes(),
            initialPartitionCount,
            autoTuneEnabled,
            minSparkInputPartitionCountForAutoTune,
            maxSparkInputPartitionCountForAutoTune);
    Optional<SetMultimap<Integer, ScheduledSplit>> firstBatch = oneShotAssigner.getNextBatch();
    if (splitSizes.isEmpty()) {
        assertThat(firstBatch).isNotPresent();
    }
    else {
        assertThat(firstBatch).isPresent();
        assertAssignedSplits(firstBatch.get(), expectedAssignment);
    }

    // Iterative: feed splits in progressively larger batches and merge everything that comes out
    for (int batchSize = 1; batchSize < splitSizes.size(); batchSize *= 2) {
        // descending order makes the incremental assignment match the one-shot assignment
        List<Long> descendingSplitSizes = new ArrayList<>(splitSizes);
        descendingSplitSizes.sort(Comparator.reverseOrder());
        PrestoSparkSplitAssigner batchedAssigner = new PrestoSparkSourceDistributionSplitAssigner(
                new PlanNodeId("test"),
                createSplitSource(descendingSplitSizes),
                batchSize,
                maxSplitsDataSizePerSparkPartition.toBytes(),
                initialPartitionCount,
                autoTuneEnabled,
                minSparkInputPartitionCountForAutoTune,
                maxSparkInputPartitionCountForAutoTune);
        HashMultimap<Integer, ScheduledSplit> mergedAssignment = HashMultimap.create();
        for (Optional<SetMultimap<Integer, ScheduledSplit>> batch = batchedAssigner.getNextBatch(); batch.isPresent(); batch = batchedAssigner.getNextBatch()) {
            mergedAssignment.putAll(batch.get());
        }
        assertAssignedSplits(mergedAssignment, expectedAssignment);
    }
}
Usage of com.facebook.presto.spi.plan.PlanNodeId in the prestodb/presto project.
Class TestHashJoinOperator, method testFullOuterJoinWithEmptyLookupSource:
/**
 * A full outer join against a build side with no rows must emit every probe row
 * (including the null-valued one) paired with a null build column.
 */
@Test(dataProvider = "hashJoinTestValues")
public void testFullOuterJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) {
    TaskContext taskContext = createTaskContext();

    // Build side: one VARCHAR column and no rows at all
    List<Type> buildTypes = ImmutableList.of(VARCHAR);
    RowPagesBuilder emptyBuildPages = rowPagesBuilder(buildHashEnabled, Ints.asList(0), buildTypes);
    BuildSideSetup buildSide = setupBuildSide(parallelBuild, taskContext, Ints.asList(0), emptyBuildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);

    // Probe side: four VARCHAR rows, one of them null
    List<Type> probeTypes = ImmutableList.of(VARCHAR);
    RowPagesBuilder probeBuilder = rowPagesBuilder(probeHashEnabled, Ints.asList(0), probeTypes);
    List<Page> probeRows = probeBuilder
            .row("a")
            .row("b")
            .row((String) null)
            .row("c")
            .build();
    OperatorFactory fullOuterJoinFactory = new LookupJoinOperators().fullOuterJoin(
            0,
            new PlanNodeId("test"),
            buildSide.getLookupSourceFactoryManager(),
            probeBuilder.getTypes(),
            Ints.asList(0),
            getHashChannelAsInt(probeBuilder),
            Optional.empty(),
            OptionalInt.of(1),
            PARTITIONING_SPILLER_FACTORY);

    // Run the build drivers so the (empty) lookup source is ready before probing
    instantiateBuildDrivers(buildSide, taskContext);
    buildLookupSource(buildSide);

    // Every probe row survives with a null build value
    MaterializedResult expectedRows = MaterializedResult.resultBuilder(taskContext.getSession(), concat(probeTypes, buildTypes))
            .row("a", null)
            .row("b", null)
            .row(null, null)
            .row("c", null)
            .build();
    assertOperatorEquals(
            fullOuterJoinFactory,
            taskContext.addPipelineContext(0, true, true, false).addDriverContext(),
            probeRows,
            expectedRows,
            true,
            getHashChannels(probeBuilder, emptyBuildPages));
}
Aggregations