Usage of org.apache.spark.util.CollectionAccumulator in the prestodb/presto project.
Class PrestoSparkQueryExecutionFactory, method create.
/**
 * Creates the Presto-on-Spark execution for a single query.
 *
 * <p>The SQL text is taken from {@code sqlText}, or read from {@code sqlLocation} through the
 * metadata storage. Exactly one of the two must be present. A query id and session are then
 * created, a transaction is started and the query is prepared. Data definition statements are
 * returned as a {@code PrestoSparkDataDefinitionExecution}. Every other statement is planned and
 * fragmented, and is returned as a {@code PrestoSparkQueryExecution}.
 *
 * <p>If anything fails from the start of the transaction onward, the transaction is rolled back
 * on a best-effort basis. A query-completed event is then published and, when
 * {@code queryStatusInfoOutputLocation} is present, the failed status info is written there.
 * Finally the failure is rethrown.
 *
 * @param sparkContext Spark context to run in (checked via {@code PrestoSparkConfInitializer.checkInitialized})
 * @param prestoSparkSession session information used to build the Presto session
 * @param sqlText inline query text; mutually exclusive with {@code sqlLocation}
 * @param sqlLocation metadata-storage location of a UTF-8 query file; mutually exclusive with {@code sqlText}
 * @param sqlFileHexHash optional expected SHA-512 hex digest of the query file (case-insensitive)
 * @param sqlFileSizeInBytes optional expected size in bytes of the query file
 * @param sparkQueueName optional Spark queue name recorded in the query info
 * @param executorFactoryProvider task executor factory provider handed to the execution
 * @param queryStatusInfoOutputLocation optional location where query status info is written
 * @param queryDataOutputLocation optional location handed to the execution for query data output
 * @return the execution object for the query
 * @throws IllegalArgumentException if neither or both of {@code sqlText} and {@code sqlLocation} are set
 * @throws PrestoException with {@code MALFORMED_QUERY_FILE} if the query file fails size or hash validation
 */
@Override
public IPrestoSparkQueryExecution create(SparkContext sparkContext, PrestoSparkSession prestoSparkSession, Optional<String> sqlText, Optional<String> sqlLocation, Optional<String> sqlFileHexHash, Optional<String> sqlFileSizeInBytes, Optional<String> sparkQueueName, PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider, Optional<String> queryStatusInfoOutputLocation, Optional<String> queryDataOutputLocation) {
    PrestoSparkConfInitializer.checkInitialized(sparkContext);

    String sql = resolveSqlText(sqlText, sqlLocation, sqlFileHexHash, sqlFileSizeInBytes);
    log.info("Query: %s", sql);

    QueryStateTimer queryStateTimer = new QueryStateTimer(systemTicker());
    queryStateTimer.beginPlanning();

    QueryId queryId = queryIdGenerator.createNextQueryId();
    log.info("Starting execution for presto query: %s", queryId);
    System.out.printf("Query id: %s\n", queryId);
    sparkContext.conf().set(PRESTO_QUERY_ID_CONFIG, queryId.getId());

    SessionContext sessionContext = PrestoSparkSessionContext.createFromSessionInfo(prestoSparkSession, credentialsProviders, authenticatorProviders);
    Session session = sessionSupplier.createSession(queryId, sessionContext);
    session = sessionPropertyDefaults.newSessionWithDefaultProperties(session, Optional.empty(), Optional.empty());
    WarningCollector warningCollector = warningCollectorFactory.create(getWarningHandlingLevel(session));

    // Kept outside the try so that the failure path can report whatever plan was produced.
    PlanAndMore planAndMore = null;
    try {
        TransactionId transactionId = transactionManager.beginTransaction(true);
        session = session.beginTransactionId(transactionId, transactionManager, accessControl);
        queryMonitor.queryCreatedEvent(new BasicQueryInfo(createQueryInfo(session, sql, PLANNING, Optional.empty(), sparkQueueName, Optional.empty(), queryStateTimer, Optional.empty(), warningCollector)));

        // including queueing time
        Duration queryMaxRunTime = getQueryMaxRunTime(session);
        // excluding queueing time
        Duration queryMaxExecutionTime = getQueryMaxExecutionTime(session);
        // pick a smaller one as we are not tracking queueing for Presto on Spark
        Duration queryTimeout = queryMaxRunTime.compareTo(queryMaxExecutionTime) < 0 ? queryMaxRunTime : queryMaxExecutionTime;
        long queryCompletionDeadline = System.currentTimeMillis() + queryTimeout.toMillis();

        settingsRequirements.verify(sparkContext, session);

        queryStateTimer.beginAnalyzing();
        PreparedQuery preparedQuery = queryPreparer.prepareQuery(session, sql, warningCollector);
        Optional<QueryType> queryType = StatementUtils.getQueryType(preparedQuery.getStatement().getClass());
        if (queryType.isPresent() && (queryType.get() == QueryType.DATA_DEFINITION)) {
            queryStateTimer.endAnalysis();
            Object ddlTask = ddlTasks.get(preparedQuery.getStatement().getClass());
            // Fail with a clear message, rather than a null task or a ClassCastException, when no
            // suitable task is registered. Throwing here keeps the failure inside the try, so the
            // rollback and failure reporting below still run.
            if (!(ddlTask instanceof DDLDefinitionTask)) {
                throw new PrestoException(GENERIC_INTERNAL_ERROR, format("No data definition task is registered for statement %s", preparedQuery.getStatement().getClass().getSimpleName()));
            }
            DDLDefinitionTask<?> task = (DDLDefinitionTask<?>) ddlTask;
            return new PrestoSparkDataDefinitionExecution(task, preparedQuery.getStatement(), transactionManager, accessControl, metadata, session, queryStateTimer, warningCollector);
        }

        planAndMore = queryPlanner.createQueryPlan(session, preparedQuery, warningCollector);
        SubPlan fragmentedPlan = planFragmenter.fragmentQueryPlan(session, planAndMore.getPlan(), warningCollector);
        log.info(textDistributedPlan(fragmentedPlan, metadata.getFunctionAndTypeManager(), session, true));
        fragmentedPlan = configureOutputPartitioning(session, fragmentedPlan);
        TableWriteInfo tableWriteInfo = getTableWriteInfo(session, fragmentedPlan);

        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkContext);
        // Accumulators registered with the Spark context collect task infos and shuffle stats.
        CollectionAccumulator<SerializedTaskInfo> taskInfoCollector = new CollectionAccumulator<>();
        taskInfoCollector.register(sparkContext, Option.empty(), false);
        CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector = new CollectionAccumulator<>();
        shuffleStatsCollector.register(sparkContext, Option.empty(), false);

        TempStorage tempStorage = tempStorageManager.getTempStorage(storageBasedBroadcastJoinStorage);
        queryStateTimer.endAnalysis();

        return new PrestoSparkQueryExecution(javaSparkContext, session, queryMonitor, taskInfoCollector, shuffleStatsCollector, prestoSparkTaskExecutorFactory, executorFactoryProvider, queryStateTimer, warningCollector, sql, planAndMore, fragmentedPlan, sparkQueueName, taskInfoCodec, sparkTaskDescriptorJsonCodec, queryStatusInfoJsonCodec, queryDataJsonCodec, rddFactory, tableWriteInfo, transactionManager, createPagesSerde(blockEncodingManager), executionExceptionFactory, queryTimeout, queryCompletionDeadline, metadataStorage, queryStatusInfoOutputLocation, queryDataOutputLocation, tempStorage, nodeMemoryConfig, waitTimeMetrics);
    } catch (Throwable executionFailure) {
        queryStateTimer.beginFinishing();
        // Best-effort rollback: a rollback failure is logged so it does not mask the original failure.
        try {
            rollback(session, transactionManager);
        } catch (RuntimeException rollbackFailure) {
            log.error(rollbackFailure, "Encountered error when performing rollback");
        }
        queryStateTimer.endQuery();

        // Prefer the failure info carried by a PrestoSparkExecutionException; otherwise derive it from the throwable.
        Optional<ExecutionFailureInfo> failureInfo = Optional.empty();
        if (executionFailure instanceof PrestoSparkExecutionException) {
            failureInfo = executionExceptionFactory.extractExecutionFailureInfo((PrestoSparkExecutionException) executionFailure);
            verify(failureInfo.isPresent());
        }
        if (!failureInfo.isPresent()) {
            failureInfo = Optional.of(toFailure(executionFailure));
        }

        // Event publishing is best-effort; a failure here is logged, not propagated.
        try {
            QueryInfo queryInfo = createQueryInfo(session, sql, FAILED, Optional.ofNullable(planAndMore), sparkQueueName, failureInfo, queryStateTimer, Optional.empty(), warningCollector);
            queryMonitor.queryCompletedEvent(queryInfo);
            if (queryStatusInfoOutputLocation.isPresent()) {
                PrestoSparkQueryStatusInfo prestoSparkQueryStatusInfo = createPrestoSparkQueryInfo(queryInfo, Optional.ofNullable(planAndMore), warningCollector, OptionalLong.empty());
                metadataStorage.write(queryStatusInfoOutputLocation.get(), queryStatusInfoJsonCodec.toJsonBytes(prestoSparkQueryStatusInfo));
            }
        } catch (RuntimeException eventFailure) {
            log.error(eventFailure, "Error publishing query immediate failure event");
        }
        throw failureInfo.get().toFailure();
    }
}

/**
 * Returns the query text.
 *
 * <p>When {@code sqlText} is present it is returned verbatim. Otherwise the bytes at
 * {@code sqlLocation} are read through {@code metadataStorage}, checked against the optional
 * expected size and SHA-512 hash, and decoded as UTF-8.
 */
private String resolveSqlText(Optional<String> sqlText, Optional<String> sqlLocation, Optional<String> sqlFileHexHash, Optional<String> sqlFileSizeInBytes) {
    if (sqlText.isPresent()) {
        checkArgument(!sqlLocation.isPresent(), "sqlText and sqlLocation should not be set at the same time");
        return sqlText.get();
    }
    checkArgument(sqlLocation.isPresent(), "sqlText or sqlLocation must be present");
    byte[] sqlFileBytes = metadataStorage.read(sqlLocation.get());
    if (sqlFileSizeInBytes.isPresent()) {
        checkSqlFileSize(sqlFileBytes, sqlFileSizeInBytes.get());
    }
    if (sqlFileHexHash.isPresent()) {
        checkSqlFileHash(sqlFileBytes, sqlFileHexHash.get());
    }
    return new String(sqlFileBytes, UTF_8);
}

/**
 * Throws {@code MALFORMED_QUERY_FILE} if {@code expectedSize} is not a number or differs from the file length.
 */
private static void checkSqlFileSize(byte[] sqlFileBytes, String expectedSize) {
    long expected;
    try {
        // Parse as long so an oversized value is reported as a mismatch, not a parse failure.
        expected = Long.parseLong(expectedSize);
    } catch (NumberFormatException e) {
        throw new PrestoException(MALFORMED_QUERY_FILE, format("sqlFileSizeInBytes %s is not a valid number", expectedSize), e);
    }
    if (expected != sqlFileBytes.length) {
        throw new PrestoException(MALFORMED_QUERY_FILE, format("sql file size %s is different from expected sqlFileSizeInBytes %s", sqlFileBytes.length, expectedSize));
    }
}

/**
 * Throws {@code MALFORMED_QUERY_FILE} if the SHA-512 hex digest of the file does not match
 * {@code expectedHexHash}. The comparison ignores case.
 */
private static void checkSqlFileHash(byte[] sqlFileBytes, String expectedHexHash) {
    String actualHexHashCode;
    try {
        MessageDigest md = MessageDigest.getInstance("SHA-512");
        actualHexHashCode = BaseEncoding.base16().lowerCase().encode(md.digest(sqlFileBytes));
    } catch (NoSuchAlgorithmException e) {
        throw new PrestoException(GENERIC_INTERNAL_ERROR, "unsupported hash algorithm", e);
    }
    // Hex digits are case-insensitive, so an upper-case expected digest is also accepted.
    if (!actualHexHashCode.equalsIgnoreCase(expectedHexHash)) {
        throw new PrestoException(MALFORMED_QUERY_FILE, format("actual hash code %s is different from expected sqlFileHexHash %s", actualHexHashCode, expectedHexHash));
    }
}
Aggregations