Use of io.cdap.cdap.etl.spark.join.JoinExpressionRequest in project cdap by caskdata.
Example from the class RDDCollection, method join:
@SuppressWarnings("unchecked")
@Override
public SparkCollection<T> join(JoinExpressionRequest joinRequest) {
  Function<StructuredRecord, StructuredRecord> recordsInCounter =
    new CountingFunction<>(joinRequest.getStageName(), sec.getMetrics(), Constants.Metrics.RECORDS_IN,
                           sec.getDataTracer(joinRequest.getStageName()));

  JoinCollection leftInfo = joinRequest.getLeft();
  StructType leftSchema = DataFrames.toDataType(leftInfo.getSchema());
  Dataset<Row> leftDF = toDataset(((JavaRDD<StructuredRecord>) rdd).map(recordsInCounter), leftSchema);

  JoinCollection rightInfo = joinRequest.getRight();
  SparkCollection<?> rightData = rightInfo.getData();
  StructType rightSchema = DataFrames.toDataType(rightInfo.getSchema());
  Dataset<Row> rightDF = toDataset(((JavaRDD<StructuredRecord>) rightData.getUnderlying()).map(recordsInCounter),
                                   rightSchema);

  // If this is not a broadcast join, Spark will reprocess each side multiple times, depending on the number
  // of partitions. If the left side has N partitions and the right side has M partitions,
  // the left side gets reprocessed M times and the right side gets reprocessed N times.
  // Cache the input to prevent confusing metrics and potential source re-reading.
  // This is only necessary for inner joins, since outer joins are automatically changed to
  // BroadcastNestedLoopJoins by Spark.
  boolean isInner = joinRequest.getLeft().isRequired() && joinRequest.getRight().isRequired();
  boolean isBroadcast = joinRequest.getLeft().isBroadcast() || joinRequest.getRight().isBroadcast();
  if (isInner && !isBroadcast) {
    leftDF = leftDF.persist(StorageLevel.DISK_ONLY());
    rightDF = rightDF.persist(StorageLevel.DISK_ONLY());
  }

  // Register using unique names to avoid collisions.
  String leftId = UUID.randomUUID().toString().replaceAll("-", "");
  String rightId = UUID.randomUUID().toString().replaceAll("-", "");
  leftDF.registerTempTable(leftId);
  rightDF.registerTempTable(rightId);

  /*
     Suppose the join was originally:

       select P.id as id, users.name as username
       from purchases as P join users
       on P.user_id = users.id or P.user_id = 0

     After registering purchases as uuid0 and users as uuid1, the query has to be rewritten
     to replace the original names with the generated ids:

       select P.id as id, uuid1.name as username
       from uuid0 as P join uuid1
       on P.user_id = uuid1.id or P.user_id = 0
   */
  String sql = getSQL(joinRequest.rename(leftId, rightId));
  LOG.debug("Executing join stage {} using SQL: \n{}", joinRequest.getStageName(), sql);

  Dataset<Row> joined = sqlContext.sql(sql);
  Schema outputSchema = joinRequest.getOutputSchema();
  JavaRDD<StructuredRecord> output = joined.javaRDD()
    .map(r -> DataFrames.fromRow(r, outputSchema))
    .map(new CountingFunction<>(joinRequest.getStageName(), sec.getMetrics(), Constants.Metrics.RECORDS_OUT,
                                sec.getDataTracer(joinRequest.getStageName())));
  return (SparkCollection<T>) wrap(output);
}
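The core of this method is plain Spark SQL: each side becomes a DataFrame, gets registered under a generated collision-free name, and the join expression is evaluated by a rewritten query; the non-broadcast inner-join case is also persisted so neither input is recomputed. The following is a minimal standalone sketch of that technique, not CDAP's implementation: the SparkSession, the purchases/users datasets, the id/name/user_id columns, the "t" prefix on the generated names, and the hard-coded query (standing in for what getSQL would produce) are all assumptions, and it uses createOrReplaceTempView in place of the registerTempTable call shown above.

  // Standalone sketch (hypothetical names throughout): register both sides under
  // generated view names, rewrite the query to use those names, and let Spark SQL
  // evaluate the join expression.
  import java.util.UUID;
  import org.apache.spark.sql.Dataset;
  import org.apache.spark.sql.Row;
  import org.apache.spark.sql.SparkSession;
  import org.apache.spark.storage.StorageLevel;

  public class SqlExpressionJoinSketch {
    public static Dataset<Row> join(SparkSession spark, Dataset<Row> purchases, Dataset<Row> users,
                                    boolean isInner, boolean isBroadcast) {
      // Cache both sides of a non-broadcast inner join so neither input is recomputed
      // while Spark evaluates the expression join across partitions.
      if (isInner && !isBroadcast) {
        purchases = purchases.persist(StorageLevel.DISK_ONLY());
        users = users.persist(StorageLevel.DISK_ONLY());
      }

      // Generated, hyphen-free names avoid collisions with other temp views in the session.
      String leftId = "t" + UUID.randomUUID().toString().replaceAll("-", "");
      String rightId = "t" + UUID.randomUUID().toString().replaceAll("-", "");
      purchases.createOrReplaceTempView(leftId);
      users.createOrReplaceTempView(rightId);

      // The original query referenced the datasets as "purchases" and "users";
      // after registration it must reference the generated view names instead.
      String sql = "SELECT P.id AS id, " + rightId + ".name AS username"
          + " FROM " + leftId + " AS P JOIN " + rightId
          + " ON P.user_id = " + rightId + ".id OR P.user_id = 0";
      return spark.sql(sql);
    }
  }

Registering under generated names rather than the original dataset names means the rewritten SQL can never collide with views or aliases already present in the session, which is the same reason the CDAP code above strips the hyphens out of a fresh UUID for each side.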
Use of io.cdap.cdap.etl.spark.join.JoinExpressionRequest in project cdap by caskdata.
Example from the class SparkPipelineRunner, method handleAutoJoinWithSQL:
/*
   Implement a join by generating a SQL query that Spark will execute.
   Joins on key equality are not implemented this way because they have special repartitioning
   that allows a different number of partitions to be specified for different joins in the same pipeline.
   When Spark handles SQL queries, it uses spark.sql.shuffle.partitions partitions, which is a global
   setting that applies to every SQL join in the pipeline.
 */
private SparkCollection<Object> handleAutoJoinWithSQL(String stageName, JoinDefinition joinDefinition,
                                                      Map<String, SparkCollection<Object>> inputDataCollections) {
  JoinCondition.OnExpression condition = (JoinCondition.OnExpression) joinDefinition.getCondition();
  Map<String, String> aliases = condition.getDatasetAliases();

  // Earlier validation ensures there are exactly 2 inputs being joined.
  JoinStage leftStage = joinDefinition.getStages().get(0);
  JoinStage rightStage = joinDefinition.getStages().get(1);
  String leftStageName = leftStage.getStageName();
  String rightStageName = rightStage.getStageName();

  SparkCollection<Object> leftData = inputDataCollections.get(leftStageName);
  JoinCollection leftCollection = new JoinCollection(leftStageName, inputDataCollections.get(leftStageName),
                                                     leftStage.getSchema(), Collections.emptyList(),
                                                     leftStage.isRequired(), leftStage.isBroadcast());
  JoinCollection rightCollection = new JoinCollection(rightStageName, inputDataCollections.get(rightStageName),
                                                      rightStage.getSchema(), Collections.emptyList(),
                                                      rightStage.isRequired(), rightStage.isBroadcast());
  JoinExpressionRequest joinRequest = new JoinExpressionRequest(stageName, joinDefinition.getSelectedFields(),
                                                                leftCollection, rightCollection, condition,
                                                                joinDefinition.getOutputSchema(), joinDefinition);
  return leftData.join(joinRequest);
}
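The leading comment hinges on spark.sql.shuffle.partitions being a session-wide setting: every SQL join shares it, so there is no per-join partition count the way the key-equality path provides. Below is a small hedged sketch of that point, assuming a local SparkSession; the class name, app name, and partition count are made up, and adaptive execution and broadcast joins are disabled only so the shuffle is visible in this toy example.

  // Hypothetical illustration (not CDAP code): the shuffle-partition count for SQL joins
  // comes from one session-wide setting rather than a per-join argument.
  import org.apache.spark.sql.Dataset;
  import org.apache.spark.sql.Row;
  import org.apache.spark.sql.SparkSession;

  public class ShufflePartitionsSketch {
    public static void main(String[] args) {
      SparkSession spark = SparkSession.builder()
          .master("local[*]")
          .appName("shuffle-partitions-demo")
          // Global setting: every SQL join in this session shuffles into 64 partitions.
          .config("spark.sql.shuffle.partitions", "64")
          // Disable optimizations that would hide the shuffle in this small demo.
          .config("spark.sql.adaptive.enabled", "false")
          .config("spark.sql.autoBroadcastJoinThreshold", "-1")
          .getOrCreate();

      Dataset<Row> left = spark.range(1000).toDF("id");
      Dataset<Row> right = spark.range(1000).toDF("id");

      // This SQL join has no per-join partition argument; the RDD API's
      // JavaPairRDD.join(other, numPartitions) does, which is the kind of
      // per-join control the comment above refers to for key-equality joins.
      Dataset<Row> joined = left.join(right, "id");
      System.out.println(joined.rdd().getNumPartitions());  // expected: 64

      spark.stop();
    }
  }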