use of co.cask.cdap.etl.planner.StageInfo in project cdap by caskdata.
the class TransformExecutorFactory method addTransformation.
/**
 * Registers the {@link PipeTransformDetail} for a single stage that is not a sink of the phase.
 *
 * <p>Three cases are handled, in this order:
 * <ul>
 *   <li>a connector that is a source of the phase keeps the stage name on its records and emits through a
 *       {@code ConnectorSourceEmitter};</li>
 *   <li>a {@code BatchJoiner} stage, while {@code isMapPhase} is set, keeps the stage name on its records;</li>
 *   <li>every other stage strips the stage name and is flagged as an error consumer when it is an
 *       {@code ErrorTransform}.</li>
 * </ul>
 *
 * @param pipeline the pipeline phase containing the stage
 * @param stageName name of the stage to register
 * @param transformations map the new detail is stored into, keyed by stage name
 * @param transformErrorSinkMap per-stage error writers; a stage without an entry gets a {@code null} writer
 * @throws Exception propagated from building the stage's transformation (e.g. {@code getTransformation})
 */
private void addTransformation(PipelinePhase pipeline, String stageName,
                               Map<String, PipeTransformDetail> transformations,
                               Map<String, ErrorOutputWriter<Object, Object>> transformErrorSinkMap)
  throws Exception {
  StageInfo info = pipeline.getStage(stageName);
  String type = info.getPluginType();
  // Map#get already yields null for stages that have no error writer registered.
  ErrorOutputWriter<Object, Object> errorWriter = transformErrorSinkMap.get(stageName);

  // A connector acting as a source emits records together with their stage name.
  if (pipeline.getSources().contains(stageName) && type.equals(Constants.CONNECTOR_TYPE)) {
    transformations.put(stageName,
                        new PipeTransformDetail(stageName, true, false, getTransformation(info),
                                                new ConnectorSourceEmitter(stageName)));
    return;
  }

  // Only in the map phase does a joiner need the stage name kept on each record.
  if (type.equals(BatchJoiner.PLUGIN_TYPE) && isMapPhase) {
    transformations.put(stageName,
                        new PipeTransformDetail(stageName, false, false, getTransformation(info),
                                                new TransformEmitter(stageName, errorWriter)));
    return;
  }

  boolean consumesErrors = ErrorTransform.PLUGIN_TYPE.equals(type);
  transformations.put(stageName,
                      new PipeTransformDetail(stageName, true, consumesErrors, getTransformation(info),
                                              new TransformEmitter(stageName, errorWriter)));
}
use of co.cask.cdap.etl.planner.StageInfo in project cdap by caskdata.
the class SparkPipelineRunner method runPipeline.
/**
 * Builds the Spark computation for one pipeline phase by visiting its stages in topological order.
 *
 * <p>For every stage this method derives a {@code SparkCollection} of the stage's output records and, for stages
 * that can emit errors and feed an {@code ErrorTransform}, a separate collection of error records. Both are
 * recorded per stage so that downstream stages can read them. Sink stages store their input. Their entry in
 * the data map is the input collection they stored.
 *
 * @param pipelinePhase the phase to run; its DAG must be non-null
 * @param sourcePluginType plugin type that identifies the source stages of this phase (the only stages allowed
 *                         to have no input)
 * @param sec Spark execution context; supplies the workflow token, runtime arguments, logical start time and
 *            namespace for macro evaluation, and is passed to each stage's {@code PluginFunctionContext}
 * @param stagePartitions optional number of partitions per stage name, read for aggregator and joiner stages;
 *                        a missing entry means the partition-less overloads are used
 * @param pluginContext used to instantiate SparkCompute, SparkSink, BatchJoiner and Windower plugins
 * @throws IllegalStateException if the phase has no DAG, a non-source stage has no input, a join stage ends up
 *                               with no inputs, or a stage has an unsupported plugin type
 * @throws Exception propagated from plugin instantiation or initialization
 */
public void runPipeline(PipelinePhase pipelinePhase, String sourcePluginType, JavaSparkExecutionContext sec, Map<String, Integer> stagePartitions, PipelinePluginContext pluginContext) throws Exception {
MacroEvaluator macroEvaluator = new DefaultMacroEvaluator(sec.getWorkflowToken(), sec.getRuntimeArguments(), sec.getLogicalStartTime(), sec, sec.getNamespace());
// Output records and error records produced by each stage already visited, keyed by stage name.
Map<String, SparkCollection<Object>> stageDataCollections = new HashMap<>();
Map<String, SparkCollection<ErrorRecord<Object>>> stageErrorCollections = new HashMap<>();
// should never happen, but removes warning
if (pipelinePhase.getDag() == null) {
throw new IllegalStateException("Pipeline phase has no connections.");
}
// Topological order guarantees every input of a stage has been processed before the stage itself.
for (String stageName : pipelinePhase.getDag().getTopologicalOrder()) {
StageInfo stageInfo = pipelinePhase.getStage(stageName);
//noinspection ConstantConditions
String pluginType = stageInfo.getPluginType();
// don't want to do an additional filter for stages that can emit errors,
// but aren't connected to an ErrorTransform
// hasErrorOutput becomes true when at least one direct output of this stage is an ErrorTransform.
boolean hasErrorOutput = false;
Set<String> outputs = pipelinePhase.getStageOutputs(stageInfo.getName());
for (String output : outputs) {
//noinspection ConstantConditions
if (ErrorTransform.PLUGIN_TYPE.equals(pipelinePhase.getStage(output).getPluginType())) {
hasErrorOutput = true;
break;
}
}
// stageData stays null for stages without inputs; such stages must be sources (checked below).
SparkCollection<Object> stageData = null;
Map<String, SparkCollection<Object>> inputDataCollections = new HashMap<>();
Set<String> stageInputs = stageInfo.getInputs();
for (String inputStageName : stageInputs) {
inputDataCollections.put(inputStageName, stageDataCollections.get(inputStageName));
}
// initialize the stageRDD as the union of all input RDDs.
if (!inputDataCollections.isEmpty()) {
Iterator<SparkCollection<Object>> inputCollectionIter = inputDataCollections.values().iterator();
stageData = inputCollectionIter.next();
// don't union inputs records if we're joining or if we're processing errors
// (for those plugin types stageData is just the first input; joiners read inputDataCollections directly)
while (!BatchJoiner.PLUGIN_TYPE.equals(pluginType) && !ErrorTransform.PLUGIN_TYPE.equals(pluginType) && inputCollectionIter.hasNext()) {
stageData = stageData.union(inputCollectionIter.next());
}
}
SparkCollection<ErrorRecord<Object>> stageErrors = null;
PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageInfo, sec);
// Stages that may emit errors produce a combined collection of Tuple2<Boolean, Object>. The OutputFilter
// flatMap yields the error records (assigned to stageErrors). The ErrorFilter flatMap yields the regular
// output (assigned to stageData).
if (stageData == null) {
// null in the other else-if conditions
if (sourcePluginType.equals(pluginType)) {
SparkCollection<Tuple2<Boolean, Object>> combinedData = getSource(stageInfo);
if (hasErrorOutput) {
// need to cache, otherwise the stage can be computed twice, once for output and once for errors.
combinedData.cache();
stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
}
stageData = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
} else {
throw new IllegalStateException(String.format("Stage '%s' has no input and is not a source.", stageName));
}
} else if (BatchSink.PLUGIN_TYPE.equals(pluginType)) {
// Sinks leave stageData untouched, so the stored input is what gets recorded for this stage.
stageData.store(stageInfo, Compat.convert(new BatchSinkFunction(pluginFunctionContext)));
} else if (Transform.PLUGIN_TYPE.equals(pluginType)) {
SparkCollection<Tuple2<Boolean, Object>> combinedData = stageData.transform(stageInfo);
if (hasErrorOutput) {
// need to cache, otherwise the stage can be computed twice, once for output and once for errors.
combinedData.cache();
stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
}
stageData = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
} else if (ErrorTransform.PLUGIN_TYPE.equals(pluginType)) {
// union all the errors coming into this stage
SparkCollection<ErrorRecord<Object>> inputErrors = null;
for (String inputStage : stageInputs) {
SparkCollection<ErrorRecord<Object>> inputErrorsFromStage = stageErrorCollections.get(inputStage);
if (inputErrorsFromStage == null) {
continue;
}
if (inputErrors == null) {
inputErrors = inputErrorsFromStage;
} else {
inputErrors = inputErrors.union(inputErrorsFromStage);
}
}
// NOTE(review): when no input recorded an error collection, inputErrors stays null. stageData is then left as
// the first input's regular data and published as this stage's output. Confirm pipeline validation makes
// this unreachable.
if (inputErrors != null) {
SparkCollection<Tuple2<Boolean, Object>> combinedData = inputErrors.flatMap(stageInfo, Compat.convert(new ErrorTransformFunction<>(pluginFunctionContext)));
if (hasErrorOutput) {
// need to cache, otherwise the stage can be computed twice, once for output and once for errors.
combinedData.cache();
stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
}
stageData = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
}
} else if (SparkCompute.PLUGIN_TYPE.equals(pluginType)) {
SparkCompute<Object, Object> sparkCompute = pluginContext.newPluginInstance(stageName, macroEvaluator);
stageData = stageData.compute(stageInfo, sparkCompute);
} else if (SparkSink.PLUGIN_TYPE.equals(pluginType)) {
SparkSink<Object> sparkSink = pluginContext.newPluginInstance(stageName, macroEvaluator);
stageData.store(stageInfo, sparkSink);
} else if (BatchAggregator.PLUGIN_TYPE.equals(pluginType)) {
// partitions may be null when no partition count was configured for this stage.
Integer partitions = stagePartitions.get(stageName);
SparkCollection<Tuple2<Boolean, Object>> combinedData = stageData.aggregate(stageInfo, partitions);
if (hasErrorOutput) {
// need to cache, otherwise the stage can be computed twice, once for output and once for errors.
combinedData.cache();
stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
}
stageData = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
} else if (BatchJoiner.PLUGIN_TYPE.equals(pluginType)) {
BatchJoiner<Object, Object, Object> joiner = pluginContext.newPluginInstance(stageName, macroEvaluator);
BatchJoinerRuntimeContext joinerRuntimeContext = pluginFunctionContext.createBatchRuntimeContext();
joiner.initialize(joinerRuntimeContext);
// Key every input collection by its join key, remembering which input stage it came from.
Map<String, SparkPairCollection<Object, Object>> preJoinStreams = new HashMap<>();
for (Map.Entry<String, SparkCollection<Object>> inputStreamEntry : inputDataCollections.entrySet()) {
String inputStage = inputStreamEntry.getKey();
SparkCollection<Object> inputStream = inputStreamEntry.getValue();
preJoinStreams.put(inputStage, addJoinKey(stageInfo, inputStage, inputStream));
}
Set<String> remainingInputs = new HashSet<>();
remainingInputs.addAll(inputDataCollections.keySet());
// A null partition count selects the join overloads that take no partition argument.
Integer numPartitions = stagePartitions.get(stageName);
SparkPairCollection<Object, List<JoinElement<Object>>> joinedInputs = null;
// inner join on required inputs
for (final String inputStageName : joiner.getJoinConfig().getRequiredInputs()) {
// NOTE(review): this is null if a required input is not an actual input of the stage. The next line would
// then throw NullPointerException. Presumably validation prevents this; confirm.
SparkPairCollection<Object, Object> preJoinCollection = preJoinStreams.get(inputStageName);
if (joinedInputs == null) {
joinedInputs = preJoinCollection.mapValues(new InitialJoinFunction<>(inputStageName));
} else {
JoinFlattenFunction<Object> joinFlattenFunction = new JoinFlattenFunction<>(inputStageName);
joinedInputs = numPartitions == null ? joinedInputs.join(preJoinCollection).mapValues(joinFlattenFunction) : joinedInputs.join(preJoinCollection, numPartitions).mapValues(joinFlattenFunction);
}
remainingInputs.remove(inputStageName);
}
// outer join on non-required inputs
// With no required inputs every input is joined with a full outer join. Otherwise the remaining inputs are
// left-outer-joined onto the inner-joined result.
boolean isFullOuter = joinedInputs == null;
for (final String inputStageName : remainingInputs) {
SparkPairCollection<Object, Object> preJoinStream = preJoinStreams.get(inputStageName);
if (joinedInputs == null) {
joinedInputs = preJoinStream.mapValues(new InitialJoinFunction<>(inputStageName));
} else {
if (isFullOuter) {
OuterJoinFlattenFunction<Object> flattenFunction = new OuterJoinFlattenFunction<>(inputStageName);
joinedInputs = numPartitions == null ? joinedInputs.fullOuterJoin(preJoinStream).mapValues(flattenFunction) : joinedInputs.fullOuterJoin(preJoinStream, numPartitions).mapValues(flattenFunction);
} else {
LeftJoinFlattenFunction<Object> flattenFunction = new LeftJoinFlattenFunction<>(inputStageName);
joinedInputs = numPartitions == null ? joinedInputs.leftOuterJoin(preJoinStream).mapValues(flattenFunction) : joinedInputs.leftOuterJoin(preJoinStream, numPartitions).mapValues(flattenFunction);
}
}
}
// should never happen, but removes warnings
if (joinedInputs == null) {
throw new IllegalStateException("There are no inputs into join stage " + stageName);
}
stageData = mergeJoinResults(stageInfo, joinedInputs).cache();
} else if (Windower.PLUGIN_TYPE.equals(pluginType)) {
Windower windower = pluginContext.newPluginInstance(stageName, macroEvaluator);
stageData = stageData.window(stageInfo, windower);
} else {
throw new IllegalStateException(String.format("Stage %s is of unsupported plugin type %s.", stageName, pluginType));
}
// Cache the stage's output (and errors, if any) when shouldCache says so for this stage.
if (shouldCache(pipelinePhase, stageInfo)) {
stageData = stageData.cache();
if (stageErrors != null) {
stageErrors = stageErrors.cache();
}
}
// Publish results for downstream stages; stageErrors is null for stages that produced no error collection.
stageDataCollections.put(stageName, stageData);
stageErrorCollections.put(stageName, stageErrors);
}
}
use of co.cask.cdap.etl.planner.StageInfo in project cdap by caskdata.
the class TransformExecutorFactory method create.
/**
 * Creates a transform executor for the given pipeline phase, instantiating and initializing all sources,
 * transforms and sinks in it.
 *
 * <p>The output schema and input schemas of every stage are recorded first. Then, starting from each source,
 * a {@link PipeTransformDetail} is built for every reachable stage.
 *
 * @param pipeline the pipeline phase to create a transform executor for
 * @param outputWriter writer handed to the emitters of the phase's sink stages
 * @param transformErrorSinkMap per-stage error writers, keyed by stage name
 * @param <KEY_OUT> key type accepted by {@code outputWriter}
 * @param <VAL_OUT> value type accepted by {@code outputWriter}
 * @return executor whose starting points are {@code sourceStageName} when it is set, otherwise every source of
 *         the phase
 * @throws InstantiationException if there was an error instantiating a plugin
 * @throws Exception if there was an error initializing a plugin
 */
public <KEY_OUT, VAL_OUT> PipeTransformExecutor<T> create(PipelinePhase pipeline,
                                                          OutputWriter<KEY_OUT, VAL_OUT> outputWriter,
                                                          Map<String, ErrorOutputWriter<Object, Object>>
                                                            transformErrorSinkMap) throws Exception {
  Set<String> sourceNames = pipeline.getSources();

  // Remember the schemas of every stage in the phase, visiting stages grouped by plugin type.
  for (String type : pipeline.getPluginTypes()) {
    for (StageInfo info : pipeline.getStagesOfType(type)) {
      String name = info.getName();
      outputSchemas.put(name, info.getOutputSchema());
      perStageInputSchemas.put(name, info.getInputSchemas());
    }
  }

  // Build (recursively) the detail of every stage reachable from each source.
  Map<String, PipeTransformDetail> details = new HashMap<>();
  for (String sourceName : sourceNames) {
    setPipeTransformDetail(pipeline, sourceName, details, transformErrorSinkMap, outputWriter);
  }

  Set<String> startingPoints;
  if (sourceStageName == null) {
    // Reducers have no single source stage, so every source of the phase is an entry point.
    startingPoints = pipeline.getSources();
  } else {
    startingPoints = Sets.newHashSet(sourceStageName);
  }
  return new PipeTransformExecutor<>(details, startingPoints);
}
use of co.cask.cdap.etl.planner.StageInfo in project cdap by caskdata.
the class TransformExecutorFactory method setPipeTransformDetail.
/**
 * Recursively creates the {@link PipeTransformDetail} for {@code stageName} and every stage downstream of it,
 * then links each detail to the details of its outputs.
 *
 * <p>Sinks of the phase get a {@code SinkEmitter} that writes to {@code outputWriter}. Connector and joiner sinks
 * keep the stage name on their records. Every other stage is registered through {@code addTransformation}.
 *
 * <p>Each stage is built at most once. When several paths through the DAG converge on the same stage, the detail
 * built on the first visit is reused. Rebuilding it would re-create the stage's transformation and re-walk its
 * whole downstream sub-graph on every visit (exponential in diamond-shaped DAGs).
 *
 * @param pipeline the pipeline phase being built
 * @param stageName the stage to build
 * @param transformations details built so far, keyed by stage name; updated in place
 * @param transformErrorSinkMap per-stage error writers, keyed by stage name
 * @param outputWriter writer used by the emitters of sink stages
 * @param <KEY_OUT> key type accepted by {@code outputWriter}
 * @param <VAL_OUT> value type accepted by {@code outputWriter}
 * @throws Exception if building a stage fails; failures from non-sink stages are also logged to the pipeline log
 */
private <KEY_OUT, VAL_OUT> void setPipeTransformDetail(PipelinePhase pipeline, String stageName,
                                                       Map<String, PipeTransformDetail> transformations,
                                                       Map<String, ErrorOutputWriter<Object, Object>>
                                                         transformErrorSinkMap,
                                                       OutputWriter<KEY_OUT, VAL_OUT> outputWriter)
  throws Exception {
  if (transformations.containsKey(stageName)) {
    // Already built via another path through the DAG; the caller links to the existing detail.
    return;
  }

  if (pipeline.getSinks().contains(stageName)) {
    StageInfo stageInfo = pipeline.getStage(stageName);
    // If there is a connector sink/ joiner at the end of pipeline, do not remove stage name. This is needed to save
    // stageName along with the record in connector sink and joiner takes input along with stageName
    String pluginType = stageInfo.getPluginType();
    boolean removeStageName =
      !(pluginType.equals(Constants.CONNECTOR_TYPE) || pluginType.equals(BatchJoiner.PLUGIN_TYPE));
    boolean isErrorConsumer = pluginType.equals(ErrorTransform.PLUGIN_TYPE);
    transformations.put(stageName,
                        new PipeTransformDetail(stageName, removeStageName, isErrorConsumer,
                                                getTransformation(stageInfo),
                                                new SinkEmitter<>(stageName, outputWriter)));
    return;
  }

  try {
    addTransformation(pipeline, stageName, transformations, transformErrorSinkMap);
  } catch (Exception e) {
    // Catch the Exception to generate a User Error Log for the Pipeline
    Throwable rootCause = Throwables.getRootCause(e);
    PIPELINE_LOG.error("Failed to start pipeline stage '{}' with the error: {}. Please review your pipeline " +
                         "configuration and check the system logs for more details.",
                       stageName, rootCause.getMessage(), rootCause);
    throw e;
  }

  // The DAG is acyclic, so recursing into outputs cannot replace this stage's entry; look it up once.
  PipeTransformDetail current = transformations.get(stageName);
  for (String output : pipeline.getDag().getNodeOutputs(stageName)) {
    setPipeTransformDetail(pipeline, output, transformations, transformErrorSinkMap, outputWriter);
    current.addTransformation(output, transformations.get(output));
  }
}
use of co.cask.cdap.etl.planner.StageInfo in project cdap by caskdata.
the class ETLWorker method initializeTransforms.
/**
 * Instantiates, wraps and initializes every {@code Transform} stage of the pipeline.
 *
 * <p>Each transform is created through the worker context and wrapped in a {@code WrappedTransform}. It is
 * initialized with a {@code WorkerRealtimeContext` and then registered in {@code transformDetailMap} as a
 * {@code TrackedTransform} together with its output stages. Stages that declare an error dataset are also
 * recorded in {@code tranformIdToDatasetName}, which is re-created on every call.
 *
 * @param context worker context used to create plugin instances, lookups and data tracers
 * @param transformDetailMap map populated with one {@link TransformDetail} per transform stage, keyed by name
 * @param pipeline the pipeline phase whose transform stages are initialized
 * @throws IllegalArgumentException if the pipeline returns no set of transform stages
 * @throws RuntimeException wrapping the {@link InstantiationException} when a transform cannot be instantiated
 * @throws Exception propagated from a transform's {@code initialize}
 */
private void initializeTransforms(WorkerContext context, Map<String, TransformDetail> transformDetailMap,
                                  PipelinePhase pipeline) throws Exception {
  Set<StageInfo> transformInfos = pipeline.getStagesOfType(Transform.PLUGIN_TYPE);
  Preconditions.checkArgument(transformInfos != null);
  tranformIdToDatasetName = new HashMap<>(transformInfos.size());
  for (StageInfo transformInfo : transformInfos) {
    String transformName = transformInfo.getName();
    try {
      Transform<?, ?> transform = context.newPluginInstance(transformName);
      transform = new WrappedTransform<>(transform, Caller.DEFAULT);
      WorkerRealtimeContext transformContext =
        new WorkerRealtimeContext(context, metrics, new TxLookupProvider(context), transformInfo);
      LOG.debug("Transform Class : {}", transform.getClass().getName());
      transform.initialize(transformContext);
      StageMetrics stageMetrics = new DefaultStageMetrics(metrics, transformName);
      transformDetailMap.put(transformName,
                             new TransformDetail(new TrackedTransform<>(transform, stageMetrics,
                                                                        context.getDataTracer(transformName)),
                                                 pipeline.getStageOutputs(transformName)));
      if (transformInfo.getErrorDatasetName() != null) {
        tranformIdToDatasetName.put(transformName, transformInfo.getErrorDatasetName());
      }
    } catch (InstantiationException e) {
      LOG.error("Unable to instantiate Transform '{}'", transformName, e);
      // Same effect as the previous Throwables.propagate(e) for this checked exception, but the throw is now
      // explicit to the compiler and avoids the deprecated Guava API.
      throw new RuntimeException(e);
    }
  }
}
Aggregations