
Example 1 with JoinExpressionRequest

use of io.cdap.cdap.etl.spark.join.JoinExpressionRequest in project cdap by caskdata.

In class RDDCollection, the join method:

@SuppressWarnings("unchecked")
@Override
public SparkCollection<T> join(JoinExpressionRequest joinRequest) {
    Function<StructuredRecord, StructuredRecord> recordsInCounter = new CountingFunction<>(joinRequest.getStageName(), sec.getMetrics(), Constants.Metrics.RECORDS_IN, sec.getDataTracer(joinRequest.getStageName()));
    JoinCollection leftInfo = joinRequest.getLeft();
    StructType leftSchema = DataFrames.toDataType(leftInfo.getSchema());
    Dataset<Row> leftDF = toDataset(((JavaRDD<StructuredRecord>) rdd).map(recordsInCounter), leftSchema);
    JoinCollection rightInfo = joinRequest.getRight();
    SparkCollection<?> rightData = rightInfo.getData();
    StructType rightSchema = DataFrames.toDataType(rightInfo.getSchema());
    Dataset<Row> rightDF = toDataset(((JavaRDD<StructuredRecord>) rightData.getUnderlying()).map(recordsInCounter), rightSchema);
    // if this is not a broadcast join, Spark will reprocess each side multiple times, depending on the number
    // of partitions. If the left side has N partitions and the right side has M partitions,
    // the left side gets reprocessed M times and the right side gets reprocessed N times.
    // Cache the input to prevent confusing metrics and potential source re-reading.
    // this is only necessary for inner joins, since outer joins are automatically changed to
    // BroadcastNestedLoopJoins by Spark
    boolean isInner = joinRequest.getLeft().isRequired() && joinRequest.getRight().isRequired();
    boolean isBroadcast = joinRequest.getLeft().isBroadcast() || joinRequest.getRight().isBroadcast();
    if (isInner && !isBroadcast) {
        leftDF = leftDF.persist(StorageLevel.DISK_ONLY());
        rightDF = rightDF.persist(StorageLevel.DISK_ONLY());
    }
    // register using unique names to avoid collisions.
    String leftId = UUID.randomUUID().toString().replaceAll("-", "");
    String rightId = UUID.randomUUID().toString().replaceAll("-", "");
    leftDF.registerTempTable(leftId);
    rightDF.registerTempTable(rightId);
    /*
        Suppose the join was originally:

          select P.id as id, users.name as username
          from purchases as P join users
          on P.user_id = users.id or P.user_id = 0

        After registering purchases as uuid0 and users as uuid1,
        the query must be rewritten to replace the original names with the generated ids,
        so it becomes:

          select P.id as id, uuid1.name as username
          from uuid0 as P join uuid1
          on P.user_id = uuid1.id or P.user_id = 0
     */
    String sql = getSQL(joinRequest.rename(leftId, rightId));
    LOG.debug("Executing join stage {} using SQL: \n{}", joinRequest.getStageName(), sql);
    Dataset<Row> joined = sqlContext.sql(sql);
    Schema outputSchema = joinRequest.getOutputSchema();
    JavaRDD<StructuredRecord> output = joined.javaRDD().map(r -> DataFrames.fromRow(r, outputSchema)).map(new CountingFunction<>(joinRequest.getStageName(), sec.getMetrics(), Constants.Metrics.RECORDS_OUT, sec.getDataTracer(joinRequest.getStageName())));
    return (SparkCollection<T>) wrap(output);
}
Also used : DataType(org.apache.spark.sql.types.DataType) org.apache.spark.sql.functions.coalesce(org.apache.spark.sql.functions.coalesce) Arrays(java.util.Arrays) DataFrames(io.cdap.cdap.api.spark.sql.DataFrames) DatasetAggregationReduceFunction(io.cdap.cdap.etl.spark.function.DatasetAggregationReduceFunction) JoinExpressionRequest(io.cdap.cdap.etl.spark.join.JoinExpressionRequest) PluginFunctionContext(io.cdap.cdap.etl.spark.function.PluginFunctionContext) LoggerFactory(org.slf4j.LoggerFactory) CountingFunction(io.cdap.cdap.etl.spark.function.CountingFunction) Constants(io.cdap.cdap.etl.common.Constants) StructuredRecord(io.cdap.cdap.api.data.format.StructuredRecord) JavaSparkExecutionContext(io.cdap.cdap.api.spark.JavaSparkExecutionContext) DatasetContext(io.cdap.cdap.api.data.DatasetContext) StorageLevel(org.apache.spark.storage.StorageLevel) Map(java.util.Map) MapFunction(org.apache.spark.api.java.function.MapFunction) FunctionCache(io.cdap.cdap.etl.spark.function.FunctionCache) DataTypes(org.apache.spark.sql.types.DataTypes) StructType(org.apache.spark.sql.types.StructType) JoinField(io.cdap.cdap.etl.api.join.JoinField) Seq(scala.collection.Seq) RecordInfo(io.cdap.cdap.etl.common.RecordInfo) UUID(java.util.UUID) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) StageStatisticsCollector(io.cdap.cdap.etl.common.StageStatisticsCollector) SparkCollection(io.cdap.cdap.etl.spark.SparkCollection) List(java.util.List) JoinRequest(io.cdap.cdap.etl.spark.join.JoinRequest) Encoder(org.apache.spark.sql.Encoder) DatasetAggregationFinalizeFunction(io.cdap.cdap.etl.spark.function.DatasetAggregationFinalizeFunction) Function(org.apache.spark.api.java.function.Function) org.apache.spark.sql.functions(org.apache.spark.sql.functions) IntStream(java.util.stream.IntStream) LiteralsBridge(io.cdap.cdap.etl.spark.plugin.LiteralsBridge) Dataset(org.apache.spark.sql.Dataset) DatasetAggregationAccumulator(io.cdap.cdap.etl.spark.function.DatasetAggregationAccumulator) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) DatasetAggregationGetKeyFunction(io.cdap.cdap.etl.spark.function.DatasetAggregationGetKeyFunction) JavaRDD(org.apache.spark.api.java.JavaRDD) Nullable(javax.annotation.Nullable) JavaConversions(scala.collection.JavaConversions) Logger(org.slf4j.Logger) Iterator(java.util.Iterator) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) Column(org.apache.spark.sql.Column) SQLContext(org.apache.spark.sql.SQLContext) Row(org.apache.spark.sql.Row) Schema(io.cdap.cdap.api.data.schema.Schema) Encoders(org.apache.spark.sql.Encoders) org.apache.spark.sql.functions.floor(org.apache.spark.sql.functions.floor) StageSpec(io.cdap.cdap.etl.proto.v2.spec.StageSpec)
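The heart of the method above, stripped of CDAP's plumbing, is the pattern of registering each side of the join under a collision-free generated name and running the rewritten ON expression through Spark SQL. The following is a minimal standalone sketch of that pattern, using the purchases/users example from the block comment in the code; the class and variable names are illustrative and not part of the CDAP API, and it uses the newer createOrReplaceTempView where the snippet above uses the older registerTempTable.

import java.util.UUID;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.storage.StorageLevel;

public class SqlExpressionJoinSketch {

    // Joins "purchases" and "users" on an arbitrary expression by generating SQL,
    // mirroring the temp-table-plus-rewritten-query approach in RDDCollection#join.
    public static Dataset<Row> join(SparkSession spark, Dataset<Row> purchases, Dataset<Row> users) {
        // Cache both sides so the sources are not re-read once per partition of the
        // other side (the inner, non-broadcast case handled above).
        purchases = purchases.persist(StorageLevel.DISK_ONLY());
        users = users.persist(StorageLevel.DISK_ONLY());

        // Register under unique generated names to avoid collisions between stages.
        // A letter prefix keeps the generated name a valid unquoted SQL identifier.
        String leftId = "t" + UUID.randomUUID().toString().replaceAll("-", "");
        String rightId = "t" + UUID.randomUUID().toString().replaceAll("-", "");
        purchases.createOrReplaceTempView(leftId);
        users.createOrReplaceTempView(rightId);

        // Rewrite the dataset aliases in the original query to the generated ids,
        // exactly as described in the block comment inside RDDCollection#join.
        String sql = "SELECT P.id AS id, " + rightId + ".name AS username"
            + " FROM " + leftId + " AS P JOIN " + rightId
            + " ON P.user_id = " + rightId + ".id OR P.user_id = 0";
        return spark.sql(sql);
    }
}

Disk-only persistence is the conservative choice made in the snippet above; a memory-and-disk storage level would also work when executor memory allows.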

Example 2 with JoinExpressionRequest

use of io.cdap.cdap.etl.spark.join.JoinExpressionRequest in project cdap by caskdata.

In class SparkPipelineRunner, the handleAutoJoinWithSQL method:

/*
      Implements a join by generating a SQL query that Spark will execute.
      Joins on key equality are not implemented this way because they use special repartitioning
      that allows a different number of partitions to be specified for different joins in the same pipeline.

      When Spark handles SQL queries, it uses the spark.sql.shuffle.partitions setting to determine the number
      of partitions, which is a global setting applied to every SQL join in the pipeline.
   */
private SparkCollection<Object> handleAutoJoinWithSQL(String stageName, JoinDefinition joinDefinition, Map<String, SparkCollection<Object>> inputDataCollections) {
    JoinCondition.OnExpression condition = (JoinCondition.OnExpression) joinDefinition.getCondition();
    Map<String, String> aliases = condition.getDatasetAliases();
    // earlier validation ensures there are exactly 2 inputs being joined
    JoinStage leftStage = joinDefinition.getStages().get(0);
    JoinStage rightStage = joinDefinition.getStages().get(1);
    String leftStageName = leftStage.getStageName();
    String rightStageName = rightStage.getStageName();
    SparkCollection<Object> leftData = inputDataCollections.get(leftStageName);
    JoinCollection leftCollection = new JoinCollection(leftStageName, inputDataCollections.get(leftStageName), leftStage.getSchema(), Collections.emptyList(), leftStage.isRequired(), leftStage.isBroadcast());
    JoinCollection rightCollection = new JoinCollection(rightStageName, inputDataCollections.get(rightStageName), rightStage.getSchema(), Collections.emptyList(), rightStage.isRequired(), rightStage.isBroadcast());
    JoinExpressionRequest joinRequest = new JoinExpressionRequest(stageName, joinDefinition.getSelectedFields(), leftCollection, rightCollection, condition, joinDefinition.getOutputSchema(), joinDefinition);
    return leftData.join(joinRequest);
}
Also used : JoinStage(io.cdap.cdap.etl.api.join.JoinStage) JoinExpressionRequest(io.cdap.cdap.etl.spark.join.JoinExpressionRequest) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) JoinCondition(io.cdap.cdap.etl.api.join.JoinCondition)
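The comment at the top of handleAutoJoinWithSQL points out that SQL-generated joins always run with the global spark.sql.shuffle.partitions setting, with no per-join override. A small, hedged illustration of that behavior in plain Spark follows; the names and values are illustrative and this is not CDAP code.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ShufflePartitionsSketch {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("shuffle-partitions-sketch")
            .master("local[*]")
            .getOrCreate();

        // One global setting governs every SQL shuffle join in this session;
        // there is no per-join partition count, which is why key-equality joins
        // in the pipeline use their own repartitioning path instead of generated SQL.
        spark.conf().set("spark.sql.shuffle.partitions", "64");
        // Force a shuffle join for the demonstration; otherwise these tiny inputs
        // would be broadcast and no shuffle would happen at all.
        spark.conf().set("spark.sql.autoBroadcastJoinThreshold", "-1");

        Dataset<Row> purchases = spark.range(1000).toDF("user_id");
        Dataset<Row> users = spark.range(100).toDF("id");

        Dataset<Row> joined = purchases.join(users,
            purchases.col("user_id").equalTo(users.col("id")));

        // With adaptive query execution disabled, this prints 64; with AQE enabled
        // (the default on recent Spark versions), small partitions may be coalesced.
        System.out.println(joined.rdd().getNumPartitions());

        spark.stop();
    }
}

Key-equality joins in the pipeline can instead repartition each join independently, which is the special handling the comment refers to.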

Aggregations

JoinCollection (io.cdap.cdap.etl.spark.join.JoinCollection) 2
JoinExpressionRequest (io.cdap.cdap.etl.spark.join.JoinExpressionRequest) 2
DatasetContext (io.cdap.cdap.api.data.DatasetContext) 1
StructuredRecord (io.cdap.cdap.api.data.format.StructuredRecord) 1
Schema (io.cdap.cdap.api.data.schema.Schema) 1
JavaSparkExecutionContext (io.cdap.cdap.api.spark.JavaSparkExecutionContext) 1
DataFrames (io.cdap.cdap.api.spark.sql.DataFrames) 1
JoinCondition (io.cdap.cdap.etl.api.join.JoinCondition) 1
JoinField (io.cdap.cdap.etl.api.join.JoinField) 1
JoinStage (io.cdap.cdap.etl.api.join.JoinStage) 1
Constants (io.cdap.cdap.etl.common.Constants) 1
RecordInfo (io.cdap.cdap.etl.common.RecordInfo) 1
StageStatisticsCollector (io.cdap.cdap.etl.common.StageStatisticsCollector) 1
StageSpec (io.cdap.cdap.etl.proto.v2.spec.StageSpec) 1
SparkCollection (io.cdap.cdap.etl.spark.SparkCollection) 1
CountingFunction (io.cdap.cdap.etl.spark.function.CountingFunction) 1
DatasetAggregationAccumulator (io.cdap.cdap.etl.spark.function.DatasetAggregationAccumulator) 1
DatasetAggregationFinalizeFunction (io.cdap.cdap.etl.spark.function.DatasetAggregationFinalizeFunction) 1
DatasetAggregationGetKeyFunction (io.cdap.cdap.etl.spark.function.DatasetAggregationGetKeyFunction) 1
DatasetAggregationReduceFunction (io.cdap.cdap.etl.spark.function.DatasetAggregationReduceFunction) 1