
Example 1 with PluginFunctionContext

use of co.cask.cdap.etl.spark.function.PluginFunctionContext in project cdap by caskdata.

the class SparkPipelineRunner method runPipeline.

public void runPipeline(PipelinePhase pipelinePhase, String sourcePluginType, JavaSparkExecutionContext sec, Map<String, Integer> stagePartitions, PipelinePluginContext pluginContext) throws Exception {
    MacroEvaluator macroEvaluator = new DefaultMacroEvaluator(sec.getWorkflowToken(), sec.getRuntimeArguments(), sec.getLogicalStartTime(), sec, sec.getNamespace());
    Map<String, SparkCollection<Object>> stageDataCollections = new HashMap<>();
    Map<String, SparkCollection<ErrorRecord<Object>>> stageErrorCollections = new HashMap<>();
    // should never happen, but removes warning
    if (pipelinePhase.getDag() == null) {
        throw new IllegalStateException("Pipeline phase has no connections.");
    }
    for (String stageName : pipelinePhase.getDag().getTopologicalOrder()) {
        StageInfo stageInfo = pipelinePhase.getStage(stageName);
        //noinspection ConstantConditions
        String pluginType = stageInfo.getPluginType();
        // don't want to do an additional filter for stages that can emit errors,
        // but aren't connected to an ErrorTransform
        boolean hasErrorOutput = false;
        Set<String> outputs = pipelinePhase.getStageOutputs(stageInfo.getName());
        for (String output : outputs) {
            //noinspection ConstantConditions
            if (ErrorTransform.PLUGIN_TYPE.equals(pipelinePhase.getStage(output).getPluginType())) {
                hasErrorOutput = true;
                break;
            }
        }
        SparkCollection<Object> stageData = null;
        Map<String, SparkCollection<Object>> inputDataCollections = new HashMap<>();
        Set<String> stageInputs = stageInfo.getInputs();
        for (String inputStageName : stageInputs) {
            inputDataCollections.put(inputStageName, stageDataCollections.get(inputStageName));
        }
        // initialize the stageRDD as the union of all input RDDs.
        if (!inputDataCollections.isEmpty()) {
            Iterator<SparkCollection<Object>> inputCollectionIter = inputDataCollections.values().iterator();
            stageData = inputCollectionIter.next();
            // don't union input records if we're joining or if we're processing errors
            while (!BatchJoiner.PLUGIN_TYPE.equals(pluginType) && !ErrorTransform.PLUGIN_TYPE.equals(pluginType) && inputCollectionIter.hasNext()) {
                stageData = stageData.union(inputCollectionIter.next());
            }
        }
        SparkCollection<ErrorRecord<Object>> stageErrors = null;
        PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageInfo, sec);
        if (stageData == null) {
            // this check is nested here instead of outside the if-else chain to avoid
            // compiler warnings that stageData could be null in the other else-if branches
            if (sourcePluginType.equals(pluginType)) {
                SparkCollection<Tuple2<Boolean, Object>> combinedData = getSource(stageInfo);
                if (hasErrorOutput) {
                    // need to cache, otherwise the stage can be computed twice, once for output and once for errors.
                    combinedData.cache();
                    stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
                }
                stageData = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
            } else {
                throw new IllegalStateException(String.format("Stage '%s' has no input and is not a source.", stageName));
            }
        } else if (BatchSink.PLUGIN_TYPE.equals(pluginType)) {
            stageData.store(stageInfo, Compat.convert(new BatchSinkFunction(pluginFunctionContext)));
        } else if (Transform.PLUGIN_TYPE.equals(pluginType)) {
            SparkCollection<Tuple2<Boolean, Object>> combinedData = stageData.transform(stageInfo);
            if (hasErrorOutput) {
                // need to cache, otherwise the stage can be computed twice, once for output and once for errors.
                combinedData.cache();
                stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
            }
            stageData = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
        } else if (ErrorTransform.PLUGIN_TYPE.equals(pluginType)) {
            // union all the errors coming into this stage
            SparkCollection<ErrorRecord<Object>> inputErrors = null;
            for (String inputStage : stageInputs) {
                SparkCollection<ErrorRecord<Object>> inputErrorsFromStage = stageErrorCollections.get(inputStage);
                if (inputErrorsFromStage == null) {
                    continue;
                }
                if (inputErrors == null) {
                    inputErrors = inputErrorsFromStage;
                } else {
                    inputErrors = inputErrors.union(inputErrorsFromStage);
                }
            }
            if (inputErrors != null) {
                SparkCollection<Tuple2<Boolean, Object>> combinedData = inputErrors.flatMap(stageInfo, Compat.convert(new ErrorTransformFunction<>(pluginFunctionContext)));
                if (hasErrorOutput) {
                    // need to cache, otherwise the stage can be computed twice, once for output and once for errors.
                    combinedData.cache();
                    stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
                }
                stageData = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
            }
        } else if (SparkCompute.PLUGIN_TYPE.equals(pluginType)) {
            SparkCompute<Object, Object> sparkCompute = pluginContext.newPluginInstance(stageName, macroEvaluator);
            stageData = stageData.compute(stageInfo, sparkCompute);
        } else if (SparkSink.PLUGIN_TYPE.equals(pluginType)) {
            SparkSink<Object> sparkSink = pluginContext.newPluginInstance(stageName, macroEvaluator);
            stageData.store(stageInfo, sparkSink);
        } else if (BatchAggregator.PLUGIN_TYPE.equals(pluginType)) {
            Integer partitions = stagePartitions.get(stageName);
            SparkCollection<Tuple2<Boolean, Object>> combinedData = stageData.aggregate(stageInfo, partitions);
            if (hasErrorOutput) {
                // need to cache, otherwise the stage can be computed twice, once for output and once for errors.
                combinedData.cache();
                stageErrors = combinedData.flatMap(stageInfo, Compat.convert(new ErrorFilter<>()));
            }
            stageData = combinedData.flatMap(stageInfo, Compat.convert(new OutputFilter<>()));
        } else if (BatchJoiner.PLUGIN_TYPE.equals(pluginType)) {
            BatchJoiner<Object, Object, Object> joiner = pluginContext.newPluginInstance(stageName, macroEvaluator);
            BatchJoinerRuntimeContext joinerRuntimeContext = pluginFunctionContext.createBatchRuntimeContext();
            joiner.initialize(joinerRuntimeContext);
            Map<String, SparkPairCollection<Object, Object>> preJoinStreams = new HashMap<>();
            for (Map.Entry<String, SparkCollection<Object>> inputStreamEntry : inputDataCollections.entrySet()) {
                String inputStage = inputStreamEntry.getKey();
                SparkCollection<Object> inputStream = inputStreamEntry.getValue();
                preJoinStreams.put(inputStage, addJoinKey(stageInfo, inputStage, inputStream));
            }
            Set<String> remainingInputs = new HashSet<>();
            remainingInputs.addAll(inputDataCollections.keySet());
            Integer numPartitions = stagePartitions.get(stageName);
            SparkPairCollection<Object, List<JoinElement<Object>>> joinedInputs = null;
            // inner join on required inputs
            for (final String inputStageName : joiner.getJoinConfig().getRequiredInputs()) {
                SparkPairCollection<Object, Object> preJoinCollection = preJoinStreams.get(inputStageName);
                if (joinedInputs == null) {
                    joinedInputs = preJoinCollection.mapValues(new InitialJoinFunction<>(inputStageName));
                } else {
                    JoinFlattenFunction<Object> joinFlattenFunction = new JoinFlattenFunction<>(inputStageName);
                    joinedInputs = numPartitions == null ? joinedInputs.join(preJoinCollection).mapValues(joinFlattenFunction) : joinedInputs.join(preJoinCollection, numPartitions).mapValues(joinFlattenFunction);
                }
                remainingInputs.remove(inputStageName);
            }
            // outer join on non-required inputs
            boolean isFullOuter = joinedInputs == null;
            for (final String inputStageName : remainingInputs) {
                SparkPairCollection<Object, Object> preJoinStream = preJoinStreams.get(inputStageName);
                if (joinedInputs == null) {
                    joinedInputs = preJoinStream.mapValues(new InitialJoinFunction<>(inputStageName));
                } else {
                    if (isFullOuter) {
                        OuterJoinFlattenFunction<Object> flattenFunction = new OuterJoinFlattenFunction<>(inputStageName);
                        joinedInputs = numPartitions == null ? joinedInputs.fullOuterJoin(preJoinStream).mapValues(flattenFunction) : joinedInputs.fullOuterJoin(preJoinStream, numPartitions).mapValues(flattenFunction);
                    } else {
                        LeftJoinFlattenFunction<Object> flattenFunction = new LeftJoinFlattenFunction<>(inputStageName);
                        joinedInputs = numPartitions == null ? joinedInputs.leftOuterJoin(preJoinStream).mapValues(flattenFunction) : joinedInputs.leftOuterJoin(preJoinStream, numPartitions).mapValues(flattenFunction);
                    }
                }
            }
            // should never happen, but removes warnings
            if (joinedInputs == null) {
                throw new IllegalStateException("There are no inputs into join stage " + stageName);
            }
            stageData = mergeJoinResults(stageInfo, joinedInputs).cache();
        } else if (Windower.PLUGIN_TYPE.equals(pluginType)) {
            Windower windower = pluginContext.newPluginInstance(stageName, macroEvaluator);
            stageData = stageData.window(stageInfo, windower);
        } else {
            throw new IllegalStateException(String.format("Stage %s is of unsupported plugin type %s.", stageName, pluginType));
        }
        if (shouldCache(pipelinePhase, stageInfo)) {
            stageData = stageData.cache();
            if (stageErrors != null) {
                stageErrors = stageErrors.cache();
            }
        }
        stageDataCollections.put(stageName, stageData);
        stageErrorCollections.put(stageName, stageErrors);
    }
}
Also used: OutputFilter(co.cask.cdap.etl.spark.function.OutputFilter) MacroEvaluator(co.cask.cdap.api.macro.MacroEvaluator) DefaultMacroEvaluator(co.cask.cdap.etl.common.DefaultMacroEvaluator) HashSet(java.util.HashSet) Set(java.util.Set) HashMap(java.util.HashMap) StageInfo(co.cask.cdap.etl.planner.StageInfo) PluginFunctionContext(co.cask.cdap.etl.spark.function.PluginFunctionContext) BatchJoinerRuntimeContext(co.cask.cdap.etl.api.batch.BatchJoinerRuntimeContext) Windower(co.cask.cdap.etl.api.streaming.Windower) ErrorFilter(co.cask.cdap.etl.spark.function.ErrorFilter) JoinFlattenFunction(co.cask.cdap.etl.spark.function.JoinFlattenFunction) OuterJoinFlattenFunction(co.cask.cdap.etl.spark.function.OuterJoinFlattenFunction) LeftJoinFlattenFunction(co.cask.cdap.etl.spark.function.LeftJoinFlattenFunction) BatchJoiner(co.cask.cdap.etl.api.batch.BatchJoiner) JoinElement(co.cask.cdap.etl.api.JoinElement) BatchSinkFunction(co.cask.cdap.etl.spark.function.BatchSinkFunction) SparkSink(co.cask.cdap.etl.api.batch.SparkSink) Tuple2(scala.Tuple2) Map(java.util.Map) ErrorRecord(co.cask.cdap.etl.api.ErrorRecord)
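
The runPipeline method above boils down to a single traversal: visit each stage in topological order, union the collections produced by its input stages, then branch on the plugin type. The following standalone sketch (a hypothetical Stage class and string "records", not CDAP APIs) illustrates that shape, assuming the stage list is already topologically sorted.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Standalone sketch (hypothetical Stage class, string "records"): visit stages in
// topological order, union the outputs of all input stages, then branch on plugin type.
public class TopologicalDispatchSketch {

    static class Stage {
        final String name;
        final String pluginType;
        final List<String> inputs;

        Stage(String name, String pluginType, List<String> inputs) {
            this.name = name;
            this.pluginType = pluginType;
            this.inputs = inputs;
        }
    }

    public static void main(String[] args) {
        // assumed to already be in topological order, like getTopologicalOrder() above
        List<Stage> stages = Arrays.asList(
            new Stage("source", "batchsource", Collections.<String>emptyList()),
            new Stage("parse", "transform", Arrays.asList("source")),
            new Stage("sink", "batchsink", Arrays.asList("parse")));

        Map<String, List<String>> stageOutputs = new HashMap<>();
        for (Stage stage : stages) {
            // union the records produced by every input stage
            List<String> input = new ArrayList<>();
            for (String inputStage : stage.inputs) {
                input.addAll(stageOutputs.get(inputStage));
            }
            List<String> output;
            if ("batchsource".equals(stage.pluginType)) {
                output = Arrays.asList("record1", "record2");
            } else if ("transform".equals(stage.pluginType)) {
                output = new ArrayList<>();
                for (String record : input) {
                    output.add(record.toUpperCase());
                }
            } else if ("batchsink".equals(stage.pluginType)) {
                System.out.println("writing " + input);
                output = Collections.emptyList();
            } else {
                throw new IllegalStateException("Unsupported plugin type " + stage.pluginType);
            }
            stageOutputs.put(stage.name, output);
        }
    }
}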

Example 2 with PluginFunctionContext

use of co.cask.cdap.etl.spark.function.PluginFunctionContext in project cdap by caskdata.

the class SparkStreamingPipelineRunner method getSource.

@Override
protected SparkCollection<RecordInfo<Object>> getSource(StageSpec stageSpec, StageStatisticsCollector collector) throws Exception {
    StreamingSource<Object> source;
    if (checkpointsDisabled) {
        PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageSpec, sec, collector);
        source = pluginFunctionContext.createPlugin();
    } else {
        // check for macros in any StreamingSource. If checkpoints are enabled,
        // SparkStreaming will serialize all InputDStreams created in the checkpoint, which means
        // the InputDStream is deserialized directly from the checkpoint instead of instantiated through CDAP.
        // This means there isn't any way for us to perform macro evaluation on sources when they are loaded from
        // checkpoints. We can work around this in all other pipeline stages by dynamically instantiating the
        // plugin in all DStream functions, but can't for InputDStreams because the InputDStream constructor
        // adds itself to the context dag. Yay for constructors with global side effects.
        // TODO: (HYDRATOR-1030) figure out how to do this at configure time instead of run time
        MacroEvaluator macroEvaluator = new ErrorMacroEvaluator("Due to spark limitations, macro evaluation is not allowed in streaming sources when checkpointing " + "is enabled.");
        PluginContext pluginContext = new SparkPipelinePluginContext(sec.getPluginContext(), sec.getMetrics(), spec.isStageLoggingEnabled(), spec.isProcessTimingEnabled());
        source = pluginContext.newPluginInstance(stageSpec.getName(), macroEvaluator);
    }
    DataTracer dataTracer = sec.getDataTracer(stageSpec.getName());
    StreamingContext sourceContext = new DefaultStreamingContext(stageSpec, sec, streamingContext);
    JavaDStream<Object> javaDStream = source.getStream(sourceContext);
    if (dataTracer.isEnabled()) {
        // transform() creates a new function for each RDD, so the preview limit applies to each RDD rather than to the entire DStream.
        javaDStream = javaDStream.transform(new LimitingFunction<>(spec.getNumOfRecordsPreview()));
    }
    JavaDStream<RecordInfo<Object>> outputDStream = javaDStream.transform(new CountingTransformFunction<>(stageSpec.getName(), sec.getMetrics(), "records.out", dataTracer)).map(new WrapOutputTransformFunction<>(stageSpec.getName()));
    return new DStreamCollection<>(sec, outputDStream);
}
Also used: PairDStreamCollection(co.cask.cdap.etl.spark.streaming.PairDStreamCollection) DStreamCollection(co.cask.cdap.etl.spark.streaming.DStreamCollection) StreamingContext(co.cask.cdap.etl.api.streaming.StreamingContext) JavaStreamingContext(org.apache.spark.streaming.api.java.JavaStreamingContext) DefaultStreamingContext(co.cask.cdap.etl.spark.streaming.DefaultStreamingContext) MacroEvaluator(co.cask.cdap.api.macro.MacroEvaluator) SparkPipelinePluginContext(co.cask.cdap.etl.spark.plugin.SparkPipelinePluginContext) PluginContext(co.cask.cdap.api.plugin.PluginContext) RecordInfo(co.cask.cdap.etl.common.RecordInfo) CountingTransformFunction(co.cask.cdap.etl.spark.streaming.function.CountingTransformFunction) PluginFunctionContext(co.cask.cdap.etl.spark.function.PluginFunctionContext) DataTracer(co.cask.cdap.api.preview.DataTracer) LimitingFunction(co.cask.cdap.etl.spark.streaming.function.preview.LimitingFunction)
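
When checkpointing is enabled, the code above installs an ErrorMacroEvaluator so that any macro lookup in a streaming source fails immediately with an explanatory message rather than resolving against stale checkpointed state. A minimal sketch of that fail-fast idea, using a hypothetical Evaluator interface instead of the CDAP MacroEvaluator API:

// Minimal sketch (hypothetical Evaluator interface, not the CDAP MacroEvaluator API):
// when macro evaluation cannot be supported, fail fast with a clear message instead of
// silently returning stale or unevaluated values.
interface Evaluator {
    String lookup(String macroName);
}

final class FailFastEvaluator implements Evaluator {
    private final String message;

    FailFastEvaluator(String message) {
        this.message = message;
    }

    @Override
    public String lookup(String macroName) {
        // any attempt to resolve a macro surfaces the limitation immediately
        throw new IllegalStateException(message + " (macro: " + macroName + ")");
    }
}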

Example 3 with PluginFunctionContext

use of co.cask.cdap.etl.spark.function.PluginFunctionContext in project cdap by caskdata.

the class StreamingBatchSinkFunction method call.

@Override
public Void call(JavaRDD<T> data, Time batchTime) throws Exception {
    if (data.isEmpty()) {
        return null;
    }
    final long logicalStartTime = batchTime.milliseconds();
    MacroEvaluator evaluator = new DefaultMacroEvaluator(new BasicArguments(sec), logicalStartTime, sec.getSecureStore(), sec.getNamespace());
    PluginContext pluginContext = new SparkPipelinePluginContext(sec.getPluginContext(), sec.getMetrics(), stageSpec.isStageLoggingEnabled(), stageSpec.isProcessTimingEnabled());
    final SparkBatchSinkFactory sinkFactory = new SparkBatchSinkFactory();
    final String stageName = stageSpec.getName();
    final BatchSink<Object, Object, Object> batchSink = pluginContext.newPluginInstance(stageName, evaluator);
    final PipelineRuntime pipelineRuntime = new SparkPipelineRuntime(sec, logicalStartTime);
    boolean isPrepared = false;
    boolean isDone = false;
    try {
        sec.execute(new TxRunnable() {

            @Override
            public void run(DatasetContext datasetContext) throws Exception {
                SparkBatchSinkContext sinkContext = new SparkBatchSinkContext(sinkFactory, sec, datasetContext, pipelineRuntime, stageSpec);
                batchSink.prepareRun(sinkContext);
            }
        });
        isPrepared = true;
        PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageSpec, sec, pipelineRuntime.getArguments().asMap(), batchTime.milliseconds(), new NoopStageStatisticsCollector());
        PairFlatMapFunc<T, Object, Object> sinkFunction = new BatchSinkFunction<T, Object, Object>(pluginFunctionContext);
        sinkFactory.writeFromRDD(data.flatMapToPair(Compat.convert(sinkFunction)), sec, stageName, Object.class, Object.class);
        isDone = true;
        sec.execute(new TxRunnable() {

            @Override
            public void run(DatasetContext datasetContext) throws Exception {
                SparkBatchSinkContext sinkContext = new SparkBatchSinkContext(sinkFactory, sec, datasetContext, pipelineRuntime, stageSpec);
                batchSink.onRunFinish(true, sinkContext);
            }
        });
    } catch (Exception e) {
        LOG.error("Error writing to sink {} for the batch for time {}.", stageName, logicalStartTime, e);
    } finally {
        if (isPrepared && !isDone) {
            sec.execute(new TxRunnable() {

                @Override
                public void run(DatasetContext datasetContext) throws Exception {
                    SparkBatchSinkContext sinkContext = new SparkBatchSinkContext(sinkFactory, sec, datasetContext, pipelineRuntime, stageSpec);
                    batchSink.onRunFinish(false, sinkContext);
                }
            });
        }
    }
    return null;
}
Also used: NoopStageStatisticsCollector(co.cask.cdap.etl.common.NoopStageStatisticsCollector) MacroEvaluator(co.cask.cdap.api.macro.MacroEvaluator) DefaultMacroEvaluator(co.cask.cdap.etl.common.DefaultMacroEvaluator) SparkPipelineRuntime(co.cask.cdap.etl.spark.SparkPipelineRuntime) PipelineRuntime(co.cask.cdap.etl.common.PipelineRuntime) SparkPipelinePluginContext(co.cask.cdap.etl.spark.plugin.SparkPipelinePluginContext) PluginContext(co.cask.cdap.api.plugin.PluginContext) SparkBatchSinkContext(co.cask.cdap.etl.spark.batch.SparkBatchSinkContext) BatchSinkFunction(co.cask.cdap.etl.spark.function.BatchSinkFunction) PluginFunctionContext(co.cask.cdap.etl.spark.function.PluginFunctionContext) SparkBatchSinkFactory(co.cask.cdap.etl.spark.batch.SparkBatchSinkFactory) TxRunnable(co.cask.cdap.api.TxRunnable) BasicArguments(co.cask.cdap.etl.common.BasicArguments) DatasetContext(co.cask.cdap.api.data.DatasetContext)
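
The call method above follows a prepare / write / finish lifecycle: prepareRun and onRunFinish run in transactions, and the finally block guarantees that a failed write is still reported to the sink exactly once. A minimal sketch of that control flow, with a hypothetical Sink interface standing in for the CDAP BatchSink and transaction machinery:

// Minimal sketch (hypothetical Sink interface, no CDAP or transaction classes): the
// prepare / write / finish control flow used by StreamingBatchSinkFunction.call above.
interface Sink {
    void prepareRun() throws Exception;
    void write() throws Exception;
    void onRunFinish(boolean succeeded);
}

final class SinkLifecycle {
    static void runBatch(Sink sink) {
        boolean isPrepared = false;
        boolean isDone = false;
        try {
            sink.prepareRun();
            isPrepared = true;
            sink.write();
            isDone = true;
            sink.onRunFinish(true);
        } catch (Exception e) {
            System.err.println("Error writing batch: " + e.getMessage());
        } finally {
            // if preparation succeeded but the write did not, the sink still gets
            // exactly one failure callback so it can clean up
            if (isPrepared && !isDone) {
                sink.onRunFinish(false);
            }
        }
    }
}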

Example 4 with PluginFunctionContext

use of co.cask.cdap.etl.spark.function.PluginFunctionContext in project cdap by caskdata.

the class RDDCollection method aggregate.

@Override
public SparkCollection<RecordInfo<Object>> aggregate(StageSpec stageSpec, @Nullable Integer partitions, StageStatisticsCollector collector) {
    PluginFunctionContext pluginFunctionContext = new PluginFunctionContext(stageSpec, sec, collector);
    PairFlatMapFunc<T, Object, T> groupByFunction = new AggregatorGroupByFunction<>(pluginFunctionContext);
    PairFlatMapFunction<T, Object, T> sparkGroupByFunction = Compat.convert(groupByFunction);
    JavaPairRDD<Object, T> keyedCollection = rdd.flatMapToPair(sparkGroupByFunction);
    JavaPairRDD<Object, Iterable<T>> groupedCollection = partitions == null ? keyedCollection.groupByKey() : keyedCollection.groupByKey(partitions);
    FlatMapFunc<Tuple2<Object, Iterable<T>>, RecordInfo<Object>> aggregateFunction = new AggregatorAggregateFunction<>(pluginFunctionContext);
    FlatMapFunction<Tuple2<Object, Iterable<T>>, RecordInfo<Object>> sparkAggregateFunction = Compat.convert(aggregateFunction);
    return wrap(groupedCollection.flatMap(sparkAggregateFunction));
}
Also used: RecordInfo(co.cask.cdap.etl.common.RecordInfo) AggregatorAggregateFunction(co.cask.cdap.etl.spark.function.AggregatorAggregateFunction) PluginFunctionContext(co.cask.cdap.etl.spark.function.PluginFunctionContext) Tuple2(scala.Tuple2) AggregatorGroupByFunction(co.cask.cdap.etl.spark.function.AggregatorGroupByFunction)
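
Stripped of the CDAP wrappers, aggregate is the classic key-group-reduce shape: flatMapToPair assigns a group key to each record, groupByKey collects each group, and a final flatMap (here mapValues) emits the aggregated records. A rough equivalent in plain Spark, assuming a local master and made-up colon-delimited "key:value" strings purely for illustration:

import java.util.Arrays;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

// Rough equivalent in plain Spark (local master, made-up "key:value" input records):
// key each record, group by key, then reduce each group to a single aggregated record.
public class GroupAggregateSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("group-aggregate-sketch").setMaster("local[2]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            JavaRDD<String> records = jsc.parallelize(Arrays.asList("a:1", "a:2", "b:5"));
            // analogous to the group-by function: emit (groupKey, record) pairs
            JavaPairRDD<String, Integer> keyed = records.mapToPair(
                record -> new Tuple2<>(record.split(":")[0], Integer.parseInt(record.split(":")[1])));
            // analogous to the aggregate function: collapse each group into one output record
            List<Tuple2<String, Integer>> sums = keyed.groupByKey()
                .mapValues(values -> {
                    int sum = 0;
                    for (int value : values) {
                        sum += value;
                    }
                    return sum;
                })
                .collect();
            for (Tuple2<String, Integer> sum : sums) {
                System.out.println(sum._1() + " -> " + sum._2());
            }
        } finally {
            jsc.stop();
        }
    }
}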

Example 5 with PluginFunctionContext

use of co.cask.cdap.etl.spark.function.PluginFunctionContext in project cdap by caskdata.

the class DynamicDriverContext method readExternal.

@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
    serializationVersion = in.readUTF();
    stageSpec = (StageSpec) in.readObject();
    sec = (JavaSparkExecutionContext) in.readObject();
    // we intentionally do not serialize this context in order to ensure that the runtime arguments
    // and logical start time are picked up from the JavaSparkExecutionContext. If we serialized it,
    // the arguments and start time of the very first pipeline run would get serialized, then
    // used for every subsequent run that loads from the checkpoint.
    pluginFunctionContext = new PluginFunctionContext(stageSpec, sec, new NoopStageStatisticsCollector());
}
Also used: PluginFunctionContext(co.cask.cdap.etl.spark.function.PluginFunctionContext) NoopStageStatisticsCollector(co.cask.cdap.etl.common.NoopStageStatisticsCollector)
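
readExternal above restores only the stable fields and then rebuilds the PluginFunctionContext from the freshly deserialized stage spec and execution context, so runtime arguments and the logical start time are never frozen into the checkpoint. A minimal sketch of that rebuild-on-read Externalizable pattern, with hypothetical fields in place of the CDAP types:

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

// Minimal sketch (hypothetical fields, not the CDAP DynamicDriverContext): serialize only
// the stable identifiers and rebuild the runtime-dependent object in readExternal, so the
// deserialized instance reflects the current run rather than the checkpointed one.
public class RebuildOnReadSketch implements Externalizable {
    private String serializationVersion;
    private String stageName;
    private transient Object runtimeContext;  // deliberately never written to the stream

    public RebuildOnReadSketch() {
        // public no-arg constructor required by Externalizable
    }

    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        out.writeUTF(serializationVersion);
        out.writeObject(stageName);
    }

    @Override
    public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        serializationVersion = in.readUTF();
        stageName = (String) in.readObject();
        // rebuilt from whatever the current runtime provides instead of being restored
        runtimeContext = new Object();
    }
}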

Aggregations

PluginFunctionContext (co.cask.cdap.etl.spark.function.PluginFunctionContext) - 10 uses
MacroEvaluator (co.cask.cdap.api.macro.MacroEvaluator) - 5 uses
Tuple2 (scala.Tuple2) - 4 uses
PluginContext (co.cask.cdap.api.plugin.PluginContext) - 3 uses
DefaultMacroEvaluator (co.cask.cdap.etl.common.DefaultMacroEvaluator) - 3 uses
NoopStageStatisticsCollector (co.cask.cdap.etl.common.NoopStageStatisticsCollector) - 3 uses
PipelineRuntime (co.cask.cdap.etl.common.PipelineRuntime) - 3 uses
RecordInfo (co.cask.cdap.etl.common.RecordInfo) - 3 uses
SparkPipelineRuntime (co.cask.cdap.etl.spark.SparkPipelineRuntime) - 3 uses
BatchSinkFunction (co.cask.cdap.etl.spark.function.BatchSinkFunction) - 3 uses
SparkPipelinePluginContext (co.cask.cdap.etl.spark.plugin.SparkPipelinePluginContext) - 3 uses
TxRunnable (co.cask.cdap.api.TxRunnable) - 2 uses
DatasetContext (co.cask.cdap.api.data.DatasetContext) - 2 uses
DataTracer (co.cask.cdap.api.preview.DataTracer) - 2 uses
Alert (co.cask.cdap.etl.api.Alert) - 2 uses
ErrorRecord (co.cask.cdap.etl.api.ErrorRecord) - 2 uses
BatchJoinerRuntimeContext (co.cask.cdap.etl.api.batch.BatchJoinerRuntimeContext) - 2 uses
StreamingContext (co.cask.cdap.etl.api.streaming.StreamingContext) - 2 uses
Windower (co.cask.cdap.etl.api.streaming.Windower) - 2 uses
BasicArguments (co.cask.cdap.etl.common.BasicArguments) - 2 uses