
Example 6 with JoinField

use of io.cdap.cdap.etl.api.join.JoinField in project cdap by caskdata.

The join method of the class RDDCollection.

@SuppressWarnings("unchecked")
@Override
public SparkCollection<T> join(JoinRequest joinRequest) {
    Map<String, Dataset> collections = new HashMap<>();
    String stageName = joinRequest.getStageName();
    Function<StructuredRecord, StructuredRecord> recordsInCounter = new CountingFunction<>(stageName, sec.getMetrics(), Constants.Metrics.RECORDS_IN, sec.getDataTracer(stageName));
    StructType leftSparkSchema = DataFrames.toDataType(joinRequest.getLeftSchema());
    Dataset<Row> left = toDataset(((JavaRDD<StructuredRecord>) rdd).map(recordsInCounter), leftSparkSchema);
    collections.put(joinRequest.getLeftStage(), left);
    List<Column> leftJoinColumns = joinRequest.getLeftKey().stream().map(left::col).collect(Collectors.toList());
    /*
        This flag keeps track of whether there is at least one required stage in the join.
        This is needed in case there is a join like:

        A (optional), B (required), C (optional), D (required)

        The correct thing to do here is:

        1. A right outer join B as TMP1
        2. TMP1 left outer join C as TMP2
        3. TMP2 inner join D

        Join #1 is a straightforward join between 2 sides.
        Join #2 is a left outer because TMP1 becomes 'required', since it uses required input B.
        Join #3 is an inner join even though it contains 2 optional datasets, because 'B' is still required.
     */
    Integer joinPartitions = joinRequest.getNumPartitions();
    boolean seenRequired = joinRequest.isLeftRequired();
    Dataset<Row> joined = left;
    List<List<Column>> listOfListOfLeftCols = new ArrayList<>();
    for (JoinCollection toJoin : joinRequest.getToJoin()) {
        SparkCollection<StructuredRecord> data = (SparkCollection<StructuredRecord>) toJoin.getData();
        StructType sparkSchema = DataFrames.toDataType(toJoin.getSchema());
        Dataset<Row> right = toDataset(((JavaRDD<StructuredRecord>) data.getUnderlying()).map(recordsInCounter), sparkSchema);
        collections.put(toJoin.getStage(), right);
        List<Column> rightJoinColumns = toJoin.getKey().stream().map(right::col).collect(Collectors.toList());
        // UUID for salt column name to avoid name collisions
        String saltColumn = UUID.randomUUID().toString();
        if (joinRequest.isDistributionEnabled()) {
            boolean isLeftStageSkewed = joinRequest.getLeftStage().equals(joinRequest.getDistribution().getSkewedStageName());
            // Apply salt/explode transformations to each Dataset
            if (isLeftStageSkewed) {
                left = saltDataset(left, saltColumn, joinRequest.getDistribution().getDistributionFactor());
                right = explodeDataset(right, saltColumn, joinRequest.getDistribution().getDistributionFactor());
            } else {
                left = explodeDataset(left, saltColumn, joinRequest.getDistribution().getDistributionFactor());
                right = saltDataset(right, saltColumn, joinRequest.getDistribution().getDistributionFactor());
            }
            // Add the salt column to the join key
            leftJoinColumns.add(left.col(saltColumn));
            rightJoinColumns.add(right.col(saltColumn));
            // Updating other values that will be used later in join
            joined = left;
            sparkSchema = sparkSchema.add(saltColumn, DataTypes.IntegerType, false);
            leftSparkSchema = leftSparkSchema.add(saltColumn, DataTypes.IntegerType, false);
        }
        Column joinOn;
        // Making effectively final to use in streams
        List<Column> finalLeftJoinColumns = leftJoinColumns;
        if (seenRequired) {
            joinOn = IntStream.range(0, leftJoinColumns.size()).mapToObj(i -> eq(finalLeftJoinColumns.get(i), rightJoinColumns.get(i), joinRequest.isNullSafe())).reduce((a, b) -> a.and(b)).get();
        } else {
            // For the case when all joins so far are outer: collect the left key columns at each level
            // (each iteration), coalesce them, and compare the coalesced column with the right key
            joinOn = IntStream.range(0, leftJoinColumns.size()).mapToObj(i -> {
                collectLeftJoinOnCols(listOfListOfLeftCols, i, finalLeftJoinColumns.get(i));
                return eq(getLeftJoinOnCoalescedColumn(finalLeftJoinColumns.get(i), i, listOfListOfLeftCols), rightJoinColumns.get(i), joinRequest.isNullSafe());
            }).reduce((a, b) -> a.and(b)).get();
        }
        String joinType;
        if (seenRequired && toJoin.isRequired()) {
            joinType = "inner";
        } else if (seenRequired && !toJoin.isRequired()) {
            joinType = "leftouter";
        } else if (!seenRequired && toJoin.isRequired()) {
            joinType = "rightouter";
        } else {
            joinType = "outer";
        }
        seenRequired = seenRequired || toJoin.isRequired();
        if (toJoin.isBroadcast()) {
            right = functions.broadcast(right);
        }
        // repartition on the join keys when a partition count was requested and we are not told to ignore it
        // (spark.cdap.pipeline.aggregate.dataset.partitions.ignore = false), and the right side is not broadcast
        if (!ignorePartitionsDuringDatasetAggregation && joinPartitions != null && !toJoin.isBroadcast()) {
            List<String> rightKeys = new ArrayList<>(toJoin.getKey());
            List<String> leftKeys = new ArrayList<>(joinRequest.getLeftKey());
            // when distribution is enabled, the salt column must also be part of the keys used for partitioning
            if (joinRequest.isDistributionEnabled()) {
                rightKeys.add(saltColumn);
                leftKeys.add(saltColumn);
            }
            right = partitionOnKey(right, rightKeys, joinRequest.isNullSafe(), sparkSchema, joinPartitions);
            // only repartition the left side for the first join, as intermediate joins will already be partitioned on the key
            if (joined == left) {
                joined = partitionOnKey(joined, leftKeys, joinRequest.isNullSafe(), leftSparkSchema, joinPartitions);
            }
        }
        joined = joined.join(right, joinOn, joinType);
        /*
           Additionally, if none of the datasets seen so far are required, all of the joins are outer joins.
           In that case we also need to pass the join columns along, since the next iteration must compare
           the coalesce of all previous key columns with the right dataset's key.
         */
        if (toJoin.isRequired() || !seenRequired) {
            leftJoinColumns = rightJoinColumns;
        }
    }
    // select and alias fields in the expected order
    List<Column> outputColumns = new ArrayList<>(joinRequest.getFields().size());
    for (JoinField field : joinRequest.getFields()) {
        Column column = collections.get(field.getStageName()).col(field.getFieldName());
        if (field.getAlias() != null) {
            column = column.alias(field.getAlias());
        }
        outputColumns.add(column);
    }
    Seq<Column> outputColumnSeq = JavaConversions.asScalaBuffer(outputColumns).toSeq();
    joined = joined.select(outputColumnSeq);
    Schema outputSchema = joinRequest.getOutputSchema();
    JavaRDD<StructuredRecord> output = joined.javaRDD().map(r -> DataFrames.fromRow(r, outputSchema)).map(new CountingFunction<>(stageName, sec.getMetrics(), Constants.Metrics.RECORDS_OUT, sec.getDataTracer(stageName)));
    return (SparkCollection<T>) wrap(output);
}
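
The decision between inner, leftouter, rightouter and outer above can be illustrated with a small standalone sketch for the A (optional), B (required), C (optional), D (required) case from the comment. The stage names and the chooseJoinType helper below are hypothetical, but the branching mirrors the seenRequired logic in the method.

import java.util.LinkedHashMap;
import java.util.Map;

public class JoinTypeSketch {

    // Same decision table as above: once a required stage has been seen,
    // the accumulated left side is treated as required.
    static String chooseJoinType(boolean seenRequired, boolean rightRequired) {
        if (seenRequired && rightRequired) {
            return "inner";
        } else if (seenRequired) {
            return "leftouter";
        } else if (rightRequired) {
            return "rightouter";
        }
        return "outer";
    }

    public static void main(String[] args) {
        // A is the left side and is optional; B, C and D are joined in order.
        Map<String, Boolean> required = new LinkedHashMap<>();
        required.put("B", true);
        required.put("C", false);
        required.put("D", true);

        boolean seenRequired = false; // the left side A is optional
        for (Map.Entry<String, Boolean> stage : required.entrySet()) {
            System.out.println(stage.getKey() + " -> " + chooseJoinType(seenRequired, stage.getValue()));
            seenRequired = seenRequired || stage.getValue();
        }
        // Prints B -> rightouter, C -> leftouter, D -> inner, matching the comment in the method.
    }
}
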
Also used : DataType(org.apache.spark.sql.types.DataType) org.apache.spark.sql.functions.coalesce(org.apache.spark.sql.functions.coalesce) Arrays(java.util.Arrays) DataFrames(io.cdap.cdap.api.spark.sql.DataFrames) DatasetAggregationReduceFunction(io.cdap.cdap.etl.spark.function.DatasetAggregationReduceFunction) JoinExpressionRequest(io.cdap.cdap.etl.spark.join.JoinExpressionRequest) PluginFunctionContext(io.cdap.cdap.etl.spark.function.PluginFunctionContext) LoggerFactory(org.slf4j.LoggerFactory) CountingFunction(io.cdap.cdap.etl.spark.function.CountingFunction) Constants(io.cdap.cdap.etl.common.Constants) StructuredRecord(io.cdap.cdap.api.data.format.StructuredRecord) JavaSparkExecutionContext(io.cdap.cdap.api.spark.JavaSparkExecutionContext) DatasetContext(io.cdap.cdap.api.data.DatasetContext) StorageLevel(org.apache.spark.storage.StorageLevel) Map(java.util.Map) MapFunction(org.apache.spark.api.java.function.MapFunction) FunctionCache(io.cdap.cdap.etl.spark.function.FunctionCache) DataTypes(org.apache.spark.sql.types.DataTypes) StructType(org.apache.spark.sql.types.StructType) JoinField(io.cdap.cdap.etl.api.join.JoinField) Seq(scala.collection.Seq) RecordInfo(io.cdap.cdap.etl.common.RecordInfo) UUID(java.util.UUID) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) StageStatisticsCollector(io.cdap.cdap.etl.common.StageStatisticsCollector) SparkCollection(io.cdap.cdap.etl.spark.SparkCollection) List(java.util.List) JoinRequest(io.cdap.cdap.etl.spark.join.JoinRequest) Encoder(org.apache.spark.sql.Encoder) DatasetAggregationFinalizeFunction(io.cdap.cdap.etl.spark.function.DatasetAggregationFinalizeFunction) Function(org.apache.spark.api.java.function.Function) org.apache.spark.sql.functions(org.apache.spark.sql.functions) IntStream(java.util.stream.IntStream) LiteralsBridge(io.cdap.cdap.etl.spark.plugin.LiteralsBridge) Dataset(org.apache.spark.sql.Dataset) DatasetAggregationAccumulator(io.cdap.cdap.etl.spark.function.DatasetAggregationAccumulator) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) DatasetAggregationGetKeyFunction(io.cdap.cdap.etl.spark.function.DatasetAggregationGetKeyFunction) JavaRDD(org.apache.spark.api.java.JavaRDD) Nullable(javax.annotation.Nullable) JavaConversions(scala.collection.JavaConversions) Logger(org.slf4j.Logger) Iterator(java.util.Iterator) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) Column(org.apache.spark.sql.Column) SQLContext(org.apache.spark.sql.SQLContext) Row(org.apache.spark.sql.Row) Schema(io.cdap.cdap.api.data.schema.Schema) Encoders(org.apache.spark.sql.Encoders) org.apache.spark.sql.functions.floor(org.apache.spark.sql.functions.floor) StageSpec(io.cdap.cdap.etl.proto.v2.spec.StageSpec) StructType(org.apache.spark.sql.types.StructType) HashMap(java.util.HashMap) Schema(io.cdap.cdap.api.data.schema.Schema) ArrayList(java.util.ArrayList) JoinField(io.cdap.cdap.etl.api.join.JoinField) CountingFunction(io.cdap.cdap.etl.spark.function.CountingFunction) StructuredRecord(io.cdap.cdap.api.data.format.StructuredRecord) Column(org.apache.spark.sql.Column) List(java.util.List) ArrayList(java.util.ArrayList) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) Dataset(org.apache.spark.sql.Dataset) SparkCollection(io.cdap.cdap.etl.spark.SparkCollection) Row(org.apache.spark.sql.Row)

Example 7 with JoinField

use of io.cdap.cdap.etl.api.join.JoinField in project cdap by caskdata.

The handleAutoJoinOnKeys method of the class SparkPipelineRunner.

/**
 * The purpose of this method is to collect various pieces of information together into a JoinRequest.
 * This amounts to gathering the SparkCollection, schema, join key, and join type for each stage involved in the join.
 */
private SparkCollection<Object> handleAutoJoinOnKeys(String stageName, JoinDefinition joinDefinition, Map<String, SparkCollection<Object>> inputDataCollections, @Nullable Integer numPartitions) {
    // sort stages to join so that broadcasts happen last. This is to ensure that the left side is not a broadcast
    // so that we don't try to broadcast both sides of the join. It also causes less data to be shuffled for the
    // non-broadcast joins.
    List<JoinStage> joinOrder = new ArrayList<>(joinDefinition.getStages());
    joinOrder.sort((s1, s2) -> {
        if (s1.isBroadcast() && !s2.isBroadcast()) {
            return 1;
        } else if (!s1.isBroadcast() && s2.isBroadcast()) {
            return -1;
        }
        return 0;
    });
    Iterator<JoinStage> stageIter = joinOrder.iterator();
    JoinStage left = stageIter.next();
    String leftName = left.getStageName();
    SparkCollection<Object> leftCollection = inputDataCollections.get(left.getStageName());
    Schema leftSchema = left.getSchema();
    JoinCondition condition = joinDefinition.getCondition();
    JoinCondition.OnKeys onKeys = (JoinCondition.OnKeys) condition;
    // If this is a join on A.x = B.y = C.z and A.k = B.k = C.k, then stageKeys will look like:
    // A -> [x, k]
    // B -> [y, k]
    // C -> [z, k]
    Map<String, List<String>> stageKeys = onKeys.getKeys().stream().collect(Collectors.toMap(JoinKey::getStageName, JoinKey::getFields));
    Schema outputSchema = joinDefinition.getOutputSchema();
    // when we properly propagate schema at runtime, this condition should no longer happen
    if (outputSchema == null) {
        throw new IllegalArgumentException(String.format("Joiner stage '%s' cannot calculate its output schema because " + "one or more inputs have dynamic or unknown schema. " + "An output schema must be directly provided.", stageName));
    }
    List<JoinCollection> toJoin = new ArrayList<>();
    List<Schema> keySchema = null;
    while (stageIter.hasNext()) {
        // in this loop, information for each stage to be joined is gathered together into a JoinCollection
        JoinStage right = stageIter.next();
        String rightName = right.getStageName();
        Schema rightSchema = right.getSchema();
        List<String> key = stageKeys.get(rightName);
        if (rightSchema == null) {
            if (keySchema == null) {
                keySchema = deriveKeySchema(stageName, stageKeys, joinDefinition);
            }
            // if the schema is not known, generate it from the provided output schema and the selected fields
            rightSchema = deriveInputSchema(stageName, rightName, key, keySchema, joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema());
        } else {
            // drop fields that aren't included in the final output
            // don't need to do this if the schema was derived, since it will already only contain
            // fields in the output schema
            Set<String> requiredFields = new HashSet<>(key);
            for (JoinField joinField : joinDefinition.getSelectedFields()) {
                if (!joinField.getStageName().equals(rightName)) {
                    continue;
                }
                requiredFields.add(joinField.getFieldName());
            }
            rightSchema = trimSchema(rightSchema, requiredFields);
        }
        // JoinCollection contains the stage name, SparkCollection, schema, join key,
        // whether it's required, and whether to broadcast
        toJoin.add(new JoinCollection(rightName, inputDataCollections.get(rightName), rightSchema, key, right.isRequired(), right.isBroadcast()));
    }
    List<String> leftKey = stageKeys.get(leftName);
    if (leftSchema == null) {
        if (keySchema == null) {
            keySchema = deriveKeySchema(stageName, stageKeys, joinDefinition);
        }
        leftSchema = deriveInputSchema(stageName, leftName, leftKey, keySchema, joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema());
    }
    // JoinRequest contains the left side of the join, plus 1 or more other stages to join to.
    JoinRequest joinRequest = new JoinRequest(stageName, leftName, leftKey, leftSchema, left.isRequired(), onKeys.isNullSafe(), joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema(), toJoin, numPartitions, joinDefinition.getDistribution(), joinDefinition);
    return leftCollection.join(joinRequest);
}
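
As a concrete illustration of the stageKeys shape described in the comment above, a condition like A.x = B.y = C.z and A.k = B.k = C.k would produce a map roughly like the one below. The stage and field names are the hypothetical ones from the comment; position i in every list belongs to the same equality.

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Each stage maps to the ordered list of its key fields.
Map<String, List<String>> stageKeys = new HashMap<>();
stageKeys.put("A", Arrays.asList("x", "k"));
stageKeys.put("B", Arrays.asList("y", "k"));
stageKeys.put("C", Arrays.asList("z", "k"));
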
Also used : JoinStage(io.cdap.cdap.etl.api.join.JoinStage) Schema(io.cdap.cdap.api.data.schema.Schema) ArrayList(java.util.ArrayList) JoinField(io.cdap.cdap.etl.api.join.JoinField) JoinRequest(io.cdap.cdap.etl.spark.join.JoinRequest) JoinCondition(io.cdap.cdap.etl.api.join.JoinCondition) List(java.util.List) ArrayList(java.util.ArrayList) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) HashSet(java.util.HashSet)

Example 8 with JoinField

use of io.cdap.cdap.etl.api.join.JoinField in project cdap by cdapio.

The getSQL method of the class BaseRDDCollection.

static String getSQL(JoinExpressionRequest join) {
    JoinCondition.OnExpression condition = join.getCondition();
    Map<String, String> datasetAliases = condition.getDatasetAliases();
    String leftName = join.getLeft().getStage();
    String leftAlias = datasetAliases.getOrDefault(leftName, leftName);
    String rightName = join.getRight().getStage();
    String rightAlias = datasetAliases.getOrDefault(rightName, rightName);
    StringBuilder query = new StringBuilder("SELECT ");
    // see https://spark.apache.org/docs/3.0.0/sql-ref-syntax-qry-select-hints.html for more info on join hints
    if (join.getLeft().isBroadcast() && join.getRight().isBroadcast()) {
        query.append("/*+ BROADCAST(").append(leftAlias).append("), BROADCAST(").append(rightAlias).append(") */ ");
    } else if (join.getLeft().isBroadcast()) {
        query.append("/*+ BROADCAST(").append(leftAlias).append(") */ ");
    } else if (join.getRight().isBroadcast()) {
        query.append("/*+ BROADCAST(").append(rightAlias).append(") */ ");
    }
    for (JoinField field : join.getFields()) {
        String outputName = field.getAlias() == null ? field.getFieldName() : field.getAlias();
        String datasetName = datasetAliases.getOrDefault(field.getStageName(), field.getStageName());
        // `datasetName`.`fieldName` as outputName
        query.append("`").append(datasetName).append("`.`").append(field.getFieldName()).append("` as ").append(outputName).append(", ");
    }
    // remove trailing ', '
    query.setLength(query.length() - 2);
    String joinType;
    boolean leftRequired = join.getLeft().isRequired();
    boolean rightRequired = join.getRight().isRequired();
    if (leftRequired && rightRequired) {
        joinType = "JOIN";
    } else if (leftRequired && !rightRequired) {
        joinType = "LEFT OUTER JOIN";
    } else if (!leftRequired && rightRequired) {
        joinType = "RIGHT OUTER JOIN";
    } else {
        joinType = "FULL OUTER JOIN";
    }
    // FROM `leftDataset` as `leftAlias` JOIN `rightDataset` as `rightAlias`
    query.append(" FROM `").append(leftName).append("` as `").append(leftAlias).append("` ");
    query.append(joinType).append(" `").append(rightName).append("` as `").append(rightAlias).append("` ");
    // ON [expr]
    query.append(" ON ").append(condition.getExpression());
    return query.toString();
}
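
For a hypothetical join of a required "users" stage to an optional "purchases" stage, with no dataset aliases, the expression users.id = purchases.user_id, and selected fields users.id and purchases.total (aliased to amount), the method above would return a query shaped roughly as follows. The stage and field names are invented for illustration, and the query is wrapped over several lines here although getSQL builds it as a single string.

    SELECT `users`.`id` as id, `purchases`.`total` as amount
    FROM `users` as `users` LEFT OUTER JOIN `purchases` as `purchases`
    ON users.id = purchases.user_id
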
Also used : JoinField(io.cdap.cdap.etl.api.join.JoinField) JoinCondition(io.cdap.cdap.etl.api.join.JoinCondition)

Example 9 with JoinField

use of io.cdap.cdap.etl.api.join.JoinField in project cdap by cdapio.

The handleAutoJoinOnKeys method of the class SparkPipelineRunner.

/**
 * The purpose of this method is to collect various pieces of information together into a JoinRequest.
 * This amounts to gathering the SparkCollection, schema, join key, and join type for each stage involved in the join.
 */
private SparkCollection<Object> handleAutoJoinOnKeys(String stageName, JoinDefinition joinDefinition, Map<String, SparkCollection<Object>> inputDataCollections, @Nullable Integer numPartitions) {
    // sort stages to join so that broadcasts happen last. This is to ensure that the left side is not a broadcast
    // so that we don't try to broadcast both sides of the join. It also causes less data to be shuffled for the
    // non-broadcast joins.
    List<JoinStage> joinOrder = new ArrayList<>(joinDefinition.getStages());
    joinOrder.sort((s1, s2) -> {
        if (s1.isBroadcast() && !s2.isBroadcast()) {
            return 1;
        } else if (!s1.isBroadcast() && s2.isBroadcast()) {
            return -1;
        }
        return 0;
    });
    Iterator<JoinStage> stageIter = joinOrder.iterator();
    JoinStage left = stageIter.next();
    String leftName = left.getStageName();
    SparkCollection<Object> leftCollection = inputDataCollections.get(left.getStageName());
    Schema leftSchema = left.getSchema();
    JoinCondition condition = joinDefinition.getCondition();
    JoinCondition.OnKeys onKeys = (JoinCondition.OnKeys) condition;
    // If this is a join on A.x = B.y = C.z and A.k = B.k = C.k, then stageKeys will look like:
    // A -> [x, k]
    // B -> [y, k]
    // C -> [z, k]
    Map<String, List<String>> stageKeys = onKeys.getKeys().stream().collect(Collectors.toMap(JoinKey::getStageName, JoinKey::getFields));
    Schema outputSchema = joinDefinition.getOutputSchema();
    // when we properly propagate schema at runtime, this condition should no longer happen
    if (outputSchema == null) {
        throw new IllegalArgumentException(String.format("Joiner stage '%s' cannot calculate its output schema because " + "one or more inputs have dynamic or unknown schema. " + "An output schema must be directly provided.", stageName));
    }
    List<JoinCollection> toJoin = new ArrayList<>();
    List<Schema> keySchema = null;
    while (stageIter.hasNext()) {
        // in this loop, information for each stage to be joined is gathered together into a JoinCollection
        JoinStage right = stageIter.next();
        String rightName = right.getStageName();
        Schema rightSchema = right.getSchema();
        List<String> key = stageKeys.get(rightName);
        if (rightSchema == null) {
            if (keySchema == null) {
                keySchema = deriveKeySchema(stageName, stageKeys, joinDefinition);
            }
            // if the schema is not known, generate it from the provided output schema and the selected fields
            rightSchema = deriveInputSchema(stageName, rightName, key, keySchema, joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema());
        } else {
            // drop fields that aren't included in the final output
            // don't need to do this if the schema was derived, since it will already only contain
            // fields in the output schema
            Set<String> requiredFields = new HashSet<>(key);
            for (JoinField joinField : joinDefinition.getSelectedFields()) {
                if (!joinField.getStageName().equals(rightName)) {
                    continue;
                }
                requiredFields.add(joinField.getFieldName());
            }
            rightSchema = trimSchema(rightSchema, requiredFields);
        }
        // JoinCollection contains the stage name, SparkCollection, schema, join key,
        // whether it's required, and whether to broadcast
        toJoin.add(new JoinCollection(rightName, inputDataCollections.get(rightName), rightSchema, key, right.isRequired(), right.isBroadcast()));
    }
    List<String> leftKey = stageKeys.get(leftName);
    if (leftSchema == null) {
        if (keySchema == null) {
            keySchema = deriveKeySchema(stageName, stageKeys, joinDefinition);
        }
        leftSchema = deriveInputSchema(stageName, leftName, leftKey, keySchema, joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema());
    }
    // JoinRequest contains the left side of the join, plus 1 or more other stages to join to.
    JoinRequest joinRequest = new JoinRequest(stageName, leftName, leftKey, leftSchema, left.isRequired(), onKeys.isNullSafe(), joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema(), toJoin, numPartitions, joinDefinition.getDistribution(), joinDefinition);
    return leftCollection.join(joinRequest);
}
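
The broadcast-last ordering at the top of the method can also be sketched in isolation. The stage names and broadcast flags below are hypothetical; sorting on the boolean flag (false before true) has the same effect as the 1 / -1 / 0 comparator used above, so broadcast stages end up last and never become the left side of the join.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stages: only "small" is marked for broadcast.
Map<String, Boolean> broadcast = new HashMap<>();
broadcast.put("small", true);
broadcast.put("users", false);
broadcast.put("purchases", false);

List<String> joinOrder = new ArrayList<>(Arrays.asList("small", "users", "purchases"));
// The sort is stable, so non-broadcast stages keep their relative order.
joinOrder.sort(Comparator.comparing(broadcast::get));
// joinOrder is now [users, purchases, small]
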
Also used : JoinStage(io.cdap.cdap.etl.api.join.JoinStage) Schema(io.cdap.cdap.api.data.schema.Schema) ArrayList(java.util.ArrayList) JoinField(io.cdap.cdap.etl.api.join.JoinField) JoinRequest(io.cdap.cdap.etl.spark.join.JoinRequest) JoinCondition(io.cdap.cdap.etl.api.join.JoinCondition) List(java.util.List) ArrayList(java.util.ArrayList) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) HashSet(java.util.HashSet)

Example 10 with JoinField

use of io.cdap.cdap.etl.api.join.JoinField in project cdap by cdapio.

The join method of the class RDDCollection.

@SuppressWarnings("unchecked")
@Override
public SparkCollection<T> join(JoinRequest joinRequest) {
    Map<String, Dataset> collections = new HashMap<>();
    String stageName = joinRequest.getStageName();
    Function<StructuredRecord, StructuredRecord> recordsInCounter = new CountingFunction<>(stageName, sec.getMetrics(), Constants.Metrics.RECORDS_IN, sec.getDataTracer(stageName));
    StructType leftSparkSchema = DataFrames.toDataType(joinRequest.getLeftSchema());
    Dataset<Row> left = toDataset(((JavaRDD<StructuredRecord>) rdd).map(recordsInCounter), leftSparkSchema);
    collections.put(joinRequest.getLeftStage(), left);
    List<Column> leftJoinColumns = joinRequest.getLeftKey().stream().map(left::col).collect(Collectors.toList());
    /*
        This flag keeps track of whether there is at least one required stage in the join.
        This is needed in case there is a join like:

        A (optional), B (required), C (optional), D (required)

        The correct thing to do here is:

        1. A right outer join B as TMP1
        2. TMP1 left outer join C as TMP2
        3. TMP2 inner join D

        Join #1 is a straightforward join between 2 sides.
        Join #2 is a left outer because TMP1 becomes 'required', since it uses required input B.
        Join #3 is an inner join even though it contains 2 optional datasets, because 'B' is still required.
     */
    Integer joinPartitions = joinRequest.getNumPartitions();
    boolean seenRequired = joinRequest.isLeftRequired();
    Dataset<Row> joined = left;
    List<List<Column>> listOfListOfLeftCols = new ArrayList<>();
    for (JoinCollection toJoin : joinRequest.getToJoin()) {
        SparkCollection<StructuredRecord> data = (SparkCollection<StructuredRecord>) toJoin.getData();
        StructType sparkSchema = DataFrames.toDataType(toJoin.getSchema());
        Dataset<Row> right = toDataset(((JavaRDD<StructuredRecord>) data.getUnderlying()).map(recordsInCounter), sparkSchema);
        collections.put(toJoin.getStage(), right);
        List<Column> rightJoinColumns = toJoin.getKey().stream().map(right::col).collect(Collectors.toList());
        // UUID for salt column name to avoid name collisions
        String saltColumn = UUID.randomUUID().toString();
        if (joinRequest.isDistributionEnabled()) {
            boolean isLeftStageSkewed = joinRequest.getLeftStage().equals(joinRequest.getDistribution().getSkewedStageName());
            // Apply salt/explode transformations to each Dataset
            if (isLeftStageSkewed) {
                left = saltDataset(left, saltColumn, joinRequest.getDistribution().getDistributionFactor());
                right = explodeDataset(right, saltColumn, joinRequest.getDistribution().getDistributionFactor());
            } else {
                left = explodeDataset(left, saltColumn, joinRequest.getDistribution().getDistributionFactor());
                right = saltDataset(right, saltColumn, joinRequest.getDistribution().getDistributionFactor());
            }
            // Add the salt column to the join key
            leftJoinColumns.add(left.col(saltColumn));
            rightJoinColumns.add(right.col(saltColumn));
            // Updating other values that will be used later in join
            joined = left;
            sparkSchema = sparkSchema.add(saltColumn, DataTypes.IntegerType, false);
            leftSparkSchema = leftSparkSchema.add(saltColumn, DataTypes.IntegerType, false);
        }
        Column joinOn;
        // Making effectively final to use in streams
        List<Column> finalLeftJoinColumns = leftJoinColumns;
        if (seenRequired) {
            joinOn = IntStream.range(0, leftJoinColumns.size()).mapToObj(i -> eq(finalLeftJoinColumns.get(i), rightJoinColumns.get(i), joinRequest.isNullSafe())).reduce((a, b) -> a.and(b)).get();
        } else {
            // For the case when all joins so far are outer: collect the left key columns at each level
            // (each iteration), coalesce them, and compare the coalesced column with the right key
            joinOn = IntStream.range(0, leftJoinColumns.size()).mapToObj(i -> {
                collectLeftJoinOnCols(listOfListOfLeftCols, i, finalLeftJoinColumns.get(i));
                return eq(getLeftJoinOnCoalescedColumn(finalLeftJoinColumns.get(i), i, listOfListOfLeftCols), rightJoinColumns.get(i), joinRequest.isNullSafe());
            }).reduce((a, b) -> a.and(b)).get();
        }
        String joinType;
        if (seenRequired && toJoin.isRequired()) {
            joinType = "inner";
        } else if (seenRequired && !toJoin.isRequired()) {
            joinType = "leftouter";
        } else if (!seenRequired && toJoin.isRequired()) {
            joinType = "rightouter";
        } else {
            joinType = "outer";
        }
        seenRequired = seenRequired || toJoin.isRequired();
        if (toJoin.isBroadcast()) {
            right = functions.broadcast(right);
        }
        // repartition on the join keys when a partition count was requested and we are not told to ignore it
        // (spark.cdap.pipeline.aggregate.dataset.partitions.ignore = false), and the right side is not broadcast
        if (!ignorePartitionsDuringDatasetAggregation && joinPartitions != null && !toJoin.isBroadcast()) {
            List<String> rightKeys = new ArrayList<>(toJoin.getKey());
            List<String> leftKeys = new ArrayList<>(joinRequest.getLeftKey());
            // when distribution is enabled, the salt column must also be part of the keys used for partitioning
            if (joinRequest.isDistributionEnabled()) {
                rightKeys.add(saltColumn);
                leftKeys.add(saltColumn);
            }
            right = partitionOnKey(right, rightKeys, joinRequest.isNullSafe(), sparkSchema, joinPartitions);
            // only repartition the left side for the first join, as intermediate joins will already be partitioned on the key
            if (joined == left) {
                joined = partitionOnKey(joined, leftKeys, joinRequest.isNullSafe(), leftSparkSchema, joinPartitions);
            }
        }
        joined = joined.join(right, joinOn, joinType);
        /*
           Additionally, if none of the datasets seen so far are required, all of the joins are outer joins.
           In that case we also need to pass the join columns along, since the next iteration must compare
           the coalesce of all previous key columns with the right dataset's key.
         */
        if (toJoin.isRequired() || !seenRequired) {
            leftJoinColumns = rightJoinColumns;
        }
    }
    // select and alias fields in the expected order
    List<Column> outputColumns = new ArrayList<>(joinRequest.getFields().size());
    for (JoinField field : joinRequest.getFields()) {
        Column column = collections.get(field.getStageName()).col(field.getFieldName());
        if (field.getAlias() != null) {
            column = column.alias(field.getAlias());
        }
        outputColumns.add(column);
    }
    Seq<Column> outputColumnSeq = JavaConversions.asScalaBuffer(outputColumns).toSeq();
    joined = joined.select(outputColumnSeq);
    Schema outputSchema = joinRequest.getOutputSchema();
    JavaRDD<StructuredRecord> output = joined.javaRDD().map(r -> DataFrames.fromRow(r, outputSchema)).map(new CountingFunction<>(stageName, sec.getMetrics(), Constants.Metrics.RECORDS_OUT, sec.getDataTracer(stageName)));
    return (SparkCollection<T>) wrap(output);
}
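
The saltDataset and explodeDataset helpers called above are not part of this example. A minimal sketch of the usual salting pattern, assuming hypothetical inputs (skewed and other Datasets, a salt column name, and an integer distribution factor), could look like the following; the real helpers may differ in detail.

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;

int factor = 8;                 // hypothetical distribution factor
String saltColumn = "salt";     // hypothetical salt column name

// The skewed side gets a random salt value in [0, factor)...
Dataset<Row> salted = skewed.withColumn(
    saltColumn,
    functions.floor(functions.rand().multiply(factor)).cast(DataTypes.IntegerType));

// ...while the other side is replicated once per salt value, so matching rows still meet
// on (key, salt) and the hot key is spread across 'factor' partitions.
Column[] saltValues = new Column[factor];
for (int i = 0; i < factor; i++) {
    saltValues[i] = functions.lit(i);
}
Dataset<Row> exploded = other.withColumn(saltColumn, functions.explode(functions.array(saltValues)));
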
Also used : DataType(org.apache.spark.sql.types.DataType) org.apache.spark.sql.functions.coalesce(org.apache.spark.sql.functions.coalesce) Arrays(java.util.Arrays) DataFrames(io.cdap.cdap.api.spark.sql.DataFrames) DatasetAggregationReduceFunction(io.cdap.cdap.etl.spark.function.DatasetAggregationReduceFunction) JoinExpressionRequest(io.cdap.cdap.etl.spark.join.JoinExpressionRequest) PluginFunctionContext(io.cdap.cdap.etl.spark.function.PluginFunctionContext) LoggerFactory(org.slf4j.LoggerFactory) CountingFunction(io.cdap.cdap.etl.spark.function.CountingFunction) Constants(io.cdap.cdap.etl.common.Constants) StructuredRecord(io.cdap.cdap.api.data.format.StructuredRecord) JavaSparkExecutionContext(io.cdap.cdap.api.spark.JavaSparkExecutionContext) DatasetContext(io.cdap.cdap.api.data.DatasetContext) StorageLevel(org.apache.spark.storage.StorageLevel) Map(java.util.Map) MapFunction(org.apache.spark.api.java.function.MapFunction) FunctionCache(io.cdap.cdap.etl.spark.function.FunctionCache) DataTypes(org.apache.spark.sql.types.DataTypes) StructType(org.apache.spark.sql.types.StructType) JoinField(io.cdap.cdap.etl.api.join.JoinField) Seq(scala.collection.Seq) RecordInfo(io.cdap.cdap.etl.common.RecordInfo) UUID(java.util.UUID) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) StageStatisticsCollector(io.cdap.cdap.etl.common.StageStatisticsCollector) SparkCollection(io.cdap.cdap.etl.spark.SparkCollection) List(java.util.List) JoinRequest(io.cdap.cdap.etl.spark.join.JoinRequest) Encoder(org.apache.spark.sql.Encoder) DatasetAggregationFinalizeFunction(io.cdap.cdap.etl.spark.function.DatasetAggregationFinalizeFunction) Function(org.apache.spark.api.java.function.Function) org.apache.spark.sql.functions(org.apache.spark.sql.functions) IntStream(java.util.stream.IntStream) LiteralsBridge(io.cdap.cdap.etl.spark.plugin.LiteralsBridge) Dataset(org.apache.spark.sql.Dataset) DatasetAggregationAccumulator(io.cdap.cdap.etl.spark.function.DatasetAggregationAccumulator) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) DatasetAggregationGetKeyFunction(io.cdap.cdap.etl.spark.function.DatasetAggregationGetKeyFunction) JavaRDD(org.apache.spark.api.java.JavaRDD) Nullable(javax.annotation.Nullable) JavaConversions(scala.collection.JavaConversions) Logger(org.slf4j.Logger) Iterator(java.util.Iterator) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) Column(org.apache.spark.sql.Column) SQLContext(org.apache.spark.sql.SQLContext) Row(org.apache.spark.sql.Row) Schema(io.cdap.cdap.api.data.schema.Schema) Encoders(org.apache.spark.sql.Encoders) org.apache.spark.sql.functions.floor(org.apache.spark.sql.functions.floor) StageSpec(io.cdap.cdap.etl.proto.v2.spec.StageSpec) StructType(org.apache.spark.sql.types.StructType) HashMap(java.util.HashMap) Schema(io.cdap.cdap.api.data.schema.Schema) ArrayList(java.util.ArrayList) JoinField(io.cdap.cdap.etl.api.join.JoinField) CountingFunction(io.cdap.cdap.etl.spark.function.CountingFunction) StructuredRecord(io.cdap.cdap.api.data.format.StructuredRecord) Column(org.apache.spark.sql.Column) List(java.util.List) ArrayList(java.util.ArrayList) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) Dataset(org.apache.spark.sql.Dataset) SparkCollection(io.cdap.cdap.etl.spark.SparkCollection) Row(org.apache.spark.sql.Row)

Aggregations

JoinField (io.cdap.cdap.etl.api.join.JoinField): 40
ArrayList (java.util.ArrayList): 30
Schema (io.cdap.cdap.api.data.schema.Schema): 20
HashSet (java.util.HashSet): 19
Test (org.junit.Test): 17
JoinCondition (io.cdap.cdap.etl.api.join.JoinCondition): 15
StructuredRecord (io.cdap.cdap.api.data.format.StructuredRecord): 14
ETLBatchConfig (io.cdap.cdap.etl.proto.v2.ETLBatchConfig): 12
ETLStage (io.cdap.cdap.etl.proto.v2.ETLStage): 12
HashMap (java.util.HashMap): 12
JoinKey (io.cdap.cdap.etl.api.join.JoinKey): 10
Table (io.cdap.cdap.api.dataset.table.Table): 8
ETLPlugin (io.cdap.cdap.etl.proto.v2.ETLPlugin): 8
AppRequest (io.cdap.cdap.proto.artifact.AppRequest): 8
ApplicationId (io.cdap.cdap.proto.id.ApplicationId): 8
ApplicationManager (io.cdap.cdap.test.ApplicationManager): 8
WorkflowManager (io.cdap.cdap.test.WorkflowManager): 8
List (java.util.List): 8
FieldOperation (io.cdap.cdap.etl.api.lineage.field.FieldOperation): 7
FieldTransformOperation (io.cdap.cdap.etl.api.lineage.field.FieldTransformOperation): 7