Use of co.cask.cdap.api.spark.SparkExecutionContext in project cdap by caskdata.
The class SparkRuntimeUtils, method initSparkMain.
/**
 * Initialize a Spark main() method. This is the first method to be called from the main() method of any
 * spark program.
 * <p>
 * This method replaces the calling thread's context class loader with the program invocation class loader of the
 * {@link SparkRuntimeContext}. It then starts a driver {@link Service}:
 * <ul>
 *   <li>a {@code SparkDriverService} (used for heartbeating and token updates) when the environment variable named
 *   by {@code CDAP_SPARK_EXECUTION_SERVICE_URI} is set (distributed mode);</li>
 *   <li>a no-op service otherwise (local mode).</li>
 * </ul>
 * When that service terminates or fails, the {@link SparkExecutionContext} is closed (if it is
 * {@link AutoCloseable}). If the stop is observed on a thread other than the calling thread, the calling thread is
 * interrupted first.
 * </p>
 * If initialization fails before this method returns, the calling thread's original context class loader is
 * restored before the exception propagates.
 *
 * @return a {@link Cancellable} for releasing resources. Calling {@link Cancellable#cancel()} from the thread that
 *         called this method restores that thread's original context class loader. Calling it from any thread stops
 *         the driver service and waits for it to stop.
 */
public static Cancellable initSparkMain() {
  final Thread mainThread = Thread.currentThread();
  SparkClassLoader sparkClassLoader;
  try {
    sparkClassLoader = SparkClassLoader.findFromContext();
  } catch (IllegalStateException e) {
    sparkClassLoader = SparkClassLoader.create();
  }

  final ClassLoader oldClassLoader = ClassLoaders.setContextClassLoader(
    sparkClassLoader.getRuntimeContext().getProgramInvocationClassLoader());

  try {
    final SparkExecutionContext sec = sparkClassLoader.getSparkExecutionContext(true);
    final SparkRuntimeContext runtimeContext = sparkClassLoader.getRuntimeContext();

    String executorServiceURI = System.getenv(CDAP_SPARK_EXECUTION_SERVICE_URI);
    final Service driverService;
    if (executorServiceURI != null) {
      // Creates the SparkDriverService in distributed mode for heartbeating and tokens update
      driverService = new SparkDriverService(URI.create(executorServiceURI), runtimeContext);
    } else {
      // In local mode, just create a no-op service for state transition.
      driverService = new AbstractService() {
        @Override
        protected void doStart() {
          notifyStarted();
        }

        @Override
        protected void doStop() {
          notifyStopped();
        }
      };
    }

    // Watch for stopping of the driver service.
    // It can happen when a user program finished such that the Cancellable.cancel() returned by this method is
    // called, or it can happen when it received a stop command (distributed mode) in the SparkDriverService via
    // heartbeat.
    // In local mode, the LocalSparkSubmitter calls the Cancellable.cancel() returned by this method directly
    // (via SparkMainWrapper).
    // We use a service listener so that it can handle all cases.
    // The latch is released by cancel() when it is invoked from the main thread.
    final CountDownLatch mainThreadCallLatch = new CountDownLatch(1);
    driverService.addListener(new ServiceListenerAdapter() {
      @Override
      public void terminated(Service.State from) {
        handleStopped();
      }

      @Override
      public void failed(Service.State from, Throwable failure) {
        handleStopped();
      }

      private void handleStopped() {
        // Avoid interrupt/join on the current thread
        if (Thread.currentThread() != mainThread) {
          mainThread.interrupt();
          // If it is spark streaming, wait for the user class call returns from the main thread.
          if (SparkRuntimeEnv.getStreamingContext().isDefined()) {
            Uninterruptibles.awaitUninterruptibly(mainThreadCallLatch);
          }
        }

        // Close the SparkExecutionContext (it will stop all the SparkContext and release all resources)
        if (sec instanceof AutoCloseable) {
          try {
            ((AutoCloseable) sec).close();
          } catch (Exception e) {
            // Just log. It shouldn't throw, and if it does (due to bug), nothing much can be done
            LOG.warn("Exception raised when calling {}.close() for program run {}.", sec.getClass().getName(), runtimeContext.getProgramRunId(), e);
          }
        }
      }
    }, Threads.SAME_THREAD_EXECUTOR);

    driverService.startAndWait();

    return new Cancellable() {
      @Override
      public void cancel() {
        // If the cancel call is from the main thread, the call to the user class has returned,
        // since calling cancel() is the last thing that the Spark main method would do.
        if (Thread.currentThread() == mainThread) {
          mainThreadCallLatch.countDown();
          mainThread.setContextClassLoader(oldClassLoader);
        }
        driverService.stopAndWait();
      }
    };
  } catch (RuntimeException e) {
    // No Cancellable is handed out on failure, so nothing else would ever restore the context class loader
    // installed above. Put the original one back before propagating.
    mainThread.setContextClassLoader(oldClassLoader);
    throw e;
  }
}
Aggregations