use of co.cask.cdap.etl.spark.function.PluginFunctionContext in project cdap by caskdata.
the class SparkPipelineRunner method runPipeline.
/**
 * Builds the Spark computation for every stage of a pipeline phase.
 *
 * Stages are visited in the order returned by the phase DAG's {@code getTopologicalOrder()}. For each stage the
 * method looks up the collections already produced by the stage's inputs, dispatches on the stage's plugin type
 * to derive the stage's own output collection (and, when applicable, its error collection), and records both
 * in maps keyed by stage name so that later stages can consume them.
 *
 * @param pipelinePhase the phase whose stages should be run
 * @param sourcePluginType the plugin type that identifies a source stage in this phase
 * @param sec the Spark execution context, used for macro evaluation and plugin function contexts
 * @param stagePartitions stage name to number of partitions; a missing entry yields {@code null}, in which case
 *                        the join branch uses the join overloads without a partition count
 * @param pluginContext used to instantiate SparkCompute, SparkSink, BatchJoiner and Windower plugins
 * @throws IllegalStateException if the phase has no DAG, if a stage without input data is not of the source
 *                               plugin type, if a join stage ends up with no inputs, or if the plugin type
 *                               is not handled by any branch
 * @throws Exception declared by the signature; propagated from the calls made while building the stages
 */
public void runPipeline(PipelinePhase pipelinePhase, String sourcePluginType, JavaSparkExecutionContext sec, Map<String, Integer> stagePartitions, PipelinePluginContext pluginContext) throws Exception {
MacroEvaluator macroEvaluator = new DefaultMacroEvaluator(sec.getWorkflowToken(), sec.getRuntimeArguments(), sec.getLogicalStartTime(), sec, sec.getNamespace());
// output and error collections of the stages processed so far, keyed by stage name
Map<String, SparkCollection<Object>> stageDataCollections = new HashMap<>();
Map<String, SparkCollection<ErrorRecord<Object>>> stageErrorCollections = new HashMap<>();
// should never happen, but removes warning
if (pipelinePhase.getDag() == null) {
throw new IllegalStateException("Pipeline phase has no connections.");
}
for (String stageName : pipelinePhase.getDag().getTopologicalOrder()) {
StageInfo stageInfo = pipelinePhase.getStage(stageName);
//noinspection ConstantConditions
String pluginType = stageInfo.getPluginType();
// don't want to do an additional filter for stages that can emit errors,
// but aren't connected to an ErrorTransform
boolean hasErrorOutput = false;
Set<String> outputs = pipelinePhase.getStageOutputs(stageInfo.getName());
for (String output : outputs) {
//noinspection ConstantConditions
if (ErrorTransform.PLUGIN_TYPE.equals(pipelinePhase.getStage(output).getPluginType())) {
hasErrorOutput = true;
break;
}
}
// look up the collection recorded for each input stage of this stage
SparkCollection<Object> stageData = null;
Map<String, SparkCollection<Object>> inputDataCollections = new HashMap<>();
Set<String> stageInputs = stageInfo.getInputs();
for (String inputStageName : stageInputs) {
inputDataCollections.put(inputStageName, stageDataCollections.get(inputStageName));
}
// initialize the stageRDD as the union of all input RDDs.
if (!inputDataCollections.isEmpty()) {
Iterator<SparkCollection<Object>> inputCollectionIter = inputDataCollections.values().iterator();
stageData = inputCollectionIter.next();
// don't union inputs records if we're joining or if we're processing errors
while (!BatchJoiner.PLUGIN_TYPE.equals(pluginType) && !ErrorTransform.PLUGIN_TYPE.equals(pluginType) && inputCollectionIter.hasNext()) {
stageData = stageData.union(inputCollectionIter.next());
}
}
// stays null unless the stage has an ErrorTransform output and its branch below produces errors
SparkCollection<ErrorRecord<Object>> stageErrors = null;
PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageInfo, sec);
if (stageData == null) {
// the source check is nested inside this null check (instead of being one more else-if) so that
// stageData is not flagged as possibly null in the other else-if conditions
if (sourcePluginType.equals(pluginType)) {
// combinedData carries Tuple2<Boolean, Object> records that are split into errors and output below.
// NOTE(review): presumably OutputFilter drops regular output (keeping errors) and ErrorFilter drops
// errors (keeping output) -- confirm against the filter implementations
SparkCollection<Tuple2<Boolean, Object>> combinedData = getSource(stageInfo);
if (hasErrorOutput) {
// need to cache, otherwise the stage can be computed twice, once for output and once for errors.
combinedData.cache();
stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
}
stageData = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
} else {
throw new IllegalStateException(String.format("Stage '%s' has no input and is not a source.", stageName));
}
} else if (BatchSink.PLUGIN_TYPE.equals(pluginType)) {
// sinks store their input; stageData is left as the (unioned) input collection
stageData.store(stageInfo, Compat.convert(new BatchSinkFunction(pluginFunctionContext)));
} else if (Transform.PLUGIN_TYPE.equals(pluginType)) {
SparkCollection<Tuple2<Boolean, Object>> combinedData = stageData.transform(stageInfo);
if (hasErrorOutput) {
// need to cache, otherwise the stage can be computed twice, once for output and once for errors.
combinedData.cache();
stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
}
stageData = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
} else if (ErrorTransform.PLUGIN_TYPE.equals(pluginType)) {
// union all the errors coming into this stage
SparkCollection<ErrorRecord<Object>> inputErrors = null;
for (String inputStage : stageInputs) {
SparkCollection<ErrorRecord<Object>> inputErrorsFromStage = stageErrorCollections.get(inputStage);
// input stages that recorded no error collection contribute nothing
if (inputErrorsFromStage == null) {
continue;
}
if (inputErrors == null) {
inputErrors = inputErrorsFromStage;
} else {
inputErrors = inputErrors.union(inputErrorsFromStage);
}
}
// NOTE(review): when no input stage recorded errors, stageData keeps the value assigned above (the first
// input's regular output) and is stored as this stage's output -- confirm this is intended
if (inputErrors != null) {
SparkCollection<Tuple2<Boolean, Object>> combinedData = inputErrors.flatMap(stageInfo, Compat.convert(new ErrorTransformFunction<>(pluginFunctionContext)));
if (hasErrorOutput) {
// need to cache, otherwise the stage can be computed twice, once for output and once for errors.
combinedData.cache();
stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
}
stageData = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
}
} else if (SparkCompute.PLUGIN_TYPE.equals(pluginType)) {
SparkCompute<Object, Object> sparkCompute = pluginContext.newPluginInstance(stageName, macroEvaluator);
stageData = stageData.compute(stageInfo, sparkCompute);
} else if (SparkSink.PLUGIN_TYPE.equals(pluginType)) {
SparkSink<Object> sparkSink = pluginContext.newPluginInstance(stageName, macroEvaluator);
stageData.store(stageInfo, sparkSink);
} else if (BatchAggregator.PLUGIN_TYPE.equals(pluginType)) {
// partitions is null when stagePartitions has no entry for this stage
Integer partitions = stagePartitions.get(stageName);
SparkCollection<Tuple2<Boolean, Object>> combinedData = stageData.aggregate(stageInfo, partitions);
if (hasErrorOutput) {
// need to cache, otherwise the stage can be computed twice, once for output and once for errors.
combinedData.cache();
stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
}
stageData = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
} else if (BatchJoiner.PLUGIN_TYPE.equals(pluginType)) {
BatchJoiner<Object, Object, Object> joiner = pluginContext.newPluginInstance(stageName, macroEvaluator);
BatchJoinerRuntimeContext joinerRuntimeContext = pluginFunctionContext.createBatchRuntimeContext();
joiner.initialize(joinerRuntimeContext);
// key each input collection separately; inputs were not unioned for join stages
Map<String, SparkPairCollection<Object, Object>> preJoinStreams = new HashMap<>();
for (Map.Entry<String, SparkCollection<Object>> inputStreamEntry : inputDataCollections.entrySet()) {
String inputStage = inputStreamEntry.getKey();
SparkCollection<Object> inputStream = inputStreamEntry.getValue();
preJoinStreams.put(inputStage, addJoinKey(stageInfo, inputStage, inputStream));
}
// inputs that are not consumed by the required-input loop below
Set<String> remainingInputs = new HashSet<>();
remainingInputs.addAll(inputDataCollections.keySet());
Integer numPartitions = stagePartitions.get(stageName);
SparkPairCollection<Object, List<JoinElement<Object>>> joinedInputs = null;
// inner join on required inputs
for (final String inputStageName : joiner.getJoinConfig().getRequiredInputs()) {
SparkPairCollection<Object, Object> preJoinCollection = preJoinStreams.get(inputStageName);
if (joinedInputs == null) {
joinedInputs = preJoinCollection.mapValues(new InitialJoinFunction<>(inputStageName));
} else {
JoinFlattenFunction<Object> joinFlattenFunction = new JoinFlattenFunction<>(inputStageName);
joinedInputs = numPartitions == null ? joinedInputs.join(preJoinCollection).mapValues(joinFlattenFunction) : joinedInputs.join(preJoinCollection, numPartitions).mapValues(joinFlattenFunction);
}
remainingInputs.remove(inputStageName);
}
// outer join on non-required inputs
// if nothing was joined above there were no required inputs, so the remaining inputs are full-outer joined;
// otherwise they are left-outer joined onto the required inputs
boolean isFullOuter = joinedInputs == null;
for (final String inputStageName : remainingInputs) {
SparkPairCollection<Object, Object> preJoinStream = preJoinStreams.get(inputStageName);
if (joinedInputs == null) {
joinedInputs = preJoinStream.mapValues(new InitialJoinFunction<>(inputStageName));
} else {
if (isFullOuter) {
OuterJoinFlattenFunction<Object> flattenFunction = new OuterJoinFlattenFunction<>(inputStageName);
joinedInputs = numPartitions == null ? joinedInputs.fullOuterJoin(preJoinStream).mapValues(flattenFunction) : joinedInputs.fullOuterJoin(preJoinStream, numPartitions).mapValues(flattenFunction);
} else {
LeftJoinFlattenFunction<Object> flattenFunction = new LeftJoinFlattenFunction<>(inputStageName);
joinedInputs = numPartitions == null ? joinedInputs.leftOuterJoin(preJoinStream).mapValues(flattenFunction) : joinedInputs.leftOuterJoin(preJoinStream, numPartitions).mapValues(flattenFunction);
}
}
}
// should never happen, but removes warnings
if (joinedInputs == null) {
throw new IllegalStateException("There are no inputs into join stage " + stageName);
}
stageData = mergeJoinResults(stageInfo, joinedInputs).cache();
} else if (Windower.PLUGIN_TYPE.equals(pluginType)) {
Windower windower = pluginContext.newPluginInstance(stageName, macroEvaluator);
stageData = stageData.window(stageInfo, windower);
} else {
throw new IllegalStateException(String.format("Stage %s is of unsupported plugin type %s.", stageName, pluginType));
}
if (shouldCache(pipelinePhase, stageInfo)) {
stageData = stageData.cache();
if (stageErrors != null) {
stageErrors = stageErrors.cache();
}
}
// record the results so downstream stages can look them up by this stage's name
stageDataCollections.put(stageName, stageData);
stageErrorCollections.put(stageName, stageErrors);
}
}
use of co.cask.cdap.etl.spark.function.PluginFunctionContext in project cdap by caskdata.
the class SparkStreamingPipelineRunner method getSource.
/**
 * Instantiates the streaming source plugin of the given stage and exposes its stream as a collection of
 * {@code RecordInfo} records.
 *
 * @param stageSpec the spec of the source stage
 * @param collector statistics collector handed to the plugin function context when checkpoints are disabled
 * @return a {@code DStreamCollection} over the counted and wrapped source stream
 * @throws Exception declared by the signature; propagated from plugin creation or stream setup
 */
@Override
protected SparkCollection<RecordInfo<Object>> getSource(StageSpec stageSpec, StageStatisticsCollector collector) throws Exception {
  StreamingSource<Object> source;
  if (!checkpointsDisabled) {
    // With checkpoints enabled, Spark Streaming serializes every InputDStream into the checkpoint and later
    // deserializes it straight from there instead of instantiating it through CDAP. Macros in a source can
    // therefore not be evaluated when it is loaded from a checkpoint. Other stages work around this by
    // instantiating their plugin dynamically inside the DStream functions, but that is not possible for an
    // InputDStream because its constructor registers itself in the context dag. So macros in streaming
    // sources are rejected with an error evaluator.
    // TODO: (HYDRATOR-1030) figure out how to do this at configure time instead of run time
    MacroEvaluator rejectingEvaluator = new ErrorMacroEvaluator("Due to spark limitations, macro evaluation is not allowed in streaming sources when checkpointing " + "is enabled.");
    PluginContext sourcePluginContext = new SparkPipelinePluginContext(sec.getPluginContext(), sec.getMetrics(), spec.isStageLoggingEnabled(), spec.isProcessTimingEnabled());
    source = sourcePluginContext.newPluginInstance(stageSpec.getName(), rejectingEvaluator);
  } else {
    // no checkpoints: the plugin can be created through the regular plugin function context
    source = new PluginFunctionContext(stageSpec, sec, collector).createPlugin();
  }

  DataTracer tracer = sec.getDataTracer(stageSpec.getName());
  StreamingContext contextForSource = new DefaultStreamingContext(stageSpec, sec, streamingContext);
  JavaDStream<Object> stream = source.getStream(contextForSource);

  if (tracer.isEnabled()) {
    // a new function is created for each RDD, so the limit applies per RDD and not to the whole DStream
    stream = stream.transform(new LimitingFunction<>(spec.getNumOfRecordsPreview()));
  }

  // count emitted records, then wrap each record together with the name of the stage that produced it
  JavaDStream<Object> countedStream = stream.transform(new CountingTransformFunction<>(stageSpec.getName(), sec.getMetrics(), "records.out", tracer));
  JavaDStream<RecordInfo<Object>> wrappedStream = countedStream.map(new WrapOutputTransformFunction<>(stageSpec.getName()));
  return new DStreamCollection<>(sec, wrappedStream);
}
use of co.cask.cdap.etl.spark.function.PluginFunctionContext in project cdap by caskdata.
the class StreamingBatchSinkFunction method call.
/**
 * Writes one micro-batch to the batch sink of this stage.
 *
 * An empty batch is skipped. Otherwise the sink is instantiated, {@code prepareRun} is executed through
 * {@code sec.execute}, the batch is written through the sink factory, and {@code onRunFinish(true, ...)} is
 * executed. Any {@link Exception} raised along the way is logged and not rethrown. If the sink was prepared
 * but the write did not complete, {@code onRunFinish(false, ...)} is executed from the {@code finally} block;
 * an exception thrown by that call is not caught here.
 *
 * @param data the RDD of the current batch
 * @param batchTime the time of the batch; its milliseconds are used as the logical start time
 * @return always {@code null}
 * @throws Exception declared by the signature; see above for what is caught
 */
@Override
public Void call(JavaRDD<T> data, Time batchTime) throws Exception {
  if (data.isEmpty()) {
    return null;
  }

  final long logicalStartTime = batchTime.milliseconds();
  MacroEvaluator macroEvaluator = new DefaultMacroEvaluator(new BasicArguments(sec), logicalStartTime, sec.getSecureStore(), sec.getNamespace());
  PluginContext sinkPluginContext = new SparkPipelinePluginContext(sec.getPluginContext(), sec.getMetrics(), stageSpec.isStageLoggingEnabled(), stageSpec.isProcessTimingEnabled());
  final SparkBatchSinkFactory sinkFactory = new SparkBatchSinkFactory();
  final String stageName = stageSpec.getName();
  final BatchSink<Object, Object, Object> batchSink = sinkPluginContext.newPluginInstance(stageName, macroEvaluator);
  final PipelineRuntime pipelineRuntime = new SparkPipelineRuntime(sec, logicalStartTime);

  // Reports the end of the run to the sink, using a sink context built inside the executing transaction.
  final class RunFinisher implements TxRunnable {
    private final boolean succeeded;

    RunFinisher(boolean succeeded) {
      this.succeeded = succeeded;
    }

    @Override
    public void run(DatasetContext datasetContext) throws Exception {
      SparkBatchSinkContext sinkContext = new SparkBatchSinkContext(sinkFactory, sec, datasetContext, pipelineRuntime, stageSpec);
      batchSink.onRunFinish(succeeded, sinkContext);
    }
  }

  // true exactly while the sink has been prepared but the write has not completed yet
  boolean failurePending = false;
  try {
    sec.execute(new TxRunnable() {
      @Override
      public void run(DatasetContext datasetContext) throws Exception {
        SparkBatchSinkContext sinkContext = new SparkBatchSinkContext(sinkFactory, sec, datasetContext, pipelineRuntime, stageSpec);
        batchSink.prepareRun(sinkContext);
      }
    });
    failurePending = true;

    PairFlatMapFunc<T, Object, Object> writeFunction = new BatchSinkFunction<T, Object, Object>(
      new PluginFunctionContext(stageSpec, sec, pipelineRuntime.getArguments().asMap(), batchTime.milliseconds(), new NoopStageStatisticsCollector()));
    sinkFactory.writeFromRDD(data.flatMapToPair(Compat.convert(writeFunction)), sec, stageName, Object.class, Object.class);
    failurePending = false;

    sec.execute(new RunFinisher(true));
  } catch (Exception e) {
    LOG.error("Error writing to sink {} for the batch for time {}.", stageName, logicalStartTime, e);
  } finally {
    if (failurePending) {
      sec.execute(new RunFinisher(false));
    }
  }
  return null;
}
use of co.cask.cdap.etl.spark.function.PluginFunctionContext in project cdap by caskdata.
the class RDDCollection method aggregate.
/**
 * Runs the aggregator of the given stage over this collection's RDD.
 *
 * The RDD is flat-mapped to key/record pairs with an {@code AggregatorGroupByFunction}, grouped by key, and
 * each group is flat-mapped with an {@code AggregatorAggregateFunction}. Both functions share one plugin
 * function context.
 *
 * @param stageSpec the spec of the aggregator stage
 * @param partitions number of partitions for the group-by, or {@code null} to call {@code groupByKey()}
 *                   without a partition count
 * @param collector statistics collector passed to the plugin function context
 * @return the wrapped result of the aggregation
 */
@Override
public SparkCollection<RecordInfo<Object>> aggregate(StageSpec stageSpec, @Nullable Integer partitions, StageStatisticsCollector collector) {
  PluginFunctionContext functionContext = new PluginFunctionContext(stageSpec, sec, collector);

  // step 1: emit (key, record) pairs
  PairFlatMapFunc<T, Object, T> keyingFunction = new AggregatorGroupByFunction<>(functionContext);
  JavaPairRDD<Object, T> keyedRecords = rdd.flatMapToPair(Compat.convert(keyingFunction));

  // step 2: group records that share a key, honoring the requested partition count if there is one
  JavaPairRDD<Object, Iterable<T>> recordsByKey;
  if (partitions != null) {
    recordsByKey = keyedRecords.groupByKey(partitions);
  } else {
    recordsByKey = keyedRecords.groupByKey();
  }

  // step 3: aggregate every group
  FlatMapFunc<Tuple2<Object, Iterable<T>>, RecordInfo<Object>> groupAggregator = new AggregatorAggregateFunction<>(functionContext);
  return wrap(recordsByKey.flatMap(Compat.convert(groupAggregator)));
}
use of co.cask.cdap.etl.spark.function.PluginFunctionContext in project cdap by caskdata.
the class DynamicDriverContext method readExternal.
/**
 * Restores this object from the given input: the serialization version (UTF string), then the stage spec,
 * then the Spark execution context, read in that order. The plugin function context is rebuilt rather
 * than read.
 *
 * @param in the input to read from
 * @throws IOException if reading from the input fails
 * @throws ClassNotFoundException if the class of a serialized object cannot be found
 */
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
  this.serializationVersion = in.readUTF();
  this.stageSpec = (StageSpec) in.readObject();
  this.sec = (JavaSparkExecutionContext) in.readObject();

  // The plugin function context is deliberately left out of the serialized form. Rebuilding it here makes it
  // pick up runtime arguments and logical start time from the JavaSparkExecutionContext. Had it been
  // serialized, the values of the very first pipeline run would be stored in the checkpoint and reused by
  // every later run that loads from it.
  this.pluginFunctionContext = new PluginFunctionContext(this.stageSpec, this.sec, new NoopStageStatisticsCollector());
}
Aggregations