use of io.cdap.cdap.etl.spark.SparkCollection in project cdap by cdapio.
the class RDDCollection method join.
@SuppressWarnings("unchecked")
@Override
public SparkCollection<T> join(JoinRequest joinRequest) {
  Map<String, Dataset> collections = new HashMap<>();
  String stageName = joinRequest.getStageName();
  Function<StructuredRecord, StructuredRecord> recordsInCounter =
    new CountingFunction<>(stageName, sec.getMetrics(), Constants.Metrics.RECORDS_IN, sec.getDataTracer(stageName));

  StructType leftSparkSchema = DataFrames.toDataType(joinRequest.getLeftSchema());
  Dataset<Row> left = toDataset(((JavaRDD<StructuredRecord>) rdd).map(recordsInCounter), leftSparkSchema);
  collections.put(joinRequest.getLeftStage(), left);

  List<Column> leftJoinColumns = joinRequest.getLeftKey().stream()
    .map(left::col)
    .collect(Collectors.toList());

  /*
     This flag keeps track of whether there is at least one required stage in the join so far.
     This is needed in case there is a join like:

       A (optional), B (required), C (optional), D (required)

     The correct thing to do here is:

       1. A right outer join B as TMP1
       2. TMP1 left outer join C as TMP2
       3. TMP2 inner join D

     Join #1 is a straightforward join between 2 sides.
     Join #2 is a left outer join because TMP1 becomes 'required', since it uses required input B.
     Join #3 is an inner join even though it contains 2 optional datasets, because 'B' is still required.
   */
  Integer joinPartitions = joinRequest.getNumPartitions();
  boolean seenRequired = joinRequest.isLeftRequired();
  Dataset<Row> joined = left;
  List<List<Column>> listOfListOfLeftCols = new ArrayList<>();
  for (JoinCollection toJoin : joinRequest.getToJoin()) {
    SparkCollection<StructuredRecord> data = (SparkCollection<StructuredRecord>) toJoin.getData();
    StructType sparkSchema = DataFrames.toDataType(toJoin.getSchema());
    Dataset<Row> right = toDataset(((JavaRDD<StructuredRecord>) data.getUnderlying()).map(recordsInCounter), sparkSchema);
    collections.put(toJoin.getStage(), right);

    List<Column> rightJoinColumns = toJoin.getKey().stream()
      .map(right::col)
      .collect(Collectors.toList());

    // UUID for the salt column name, to avoid name collisions
    String saltColumn = UUID.randomUUID().toString();
    if (joinRequest.isDistributionEnabled()) {
      boolean isLeftStageSkewed =
        joinRequest.getLeftStage().equals(joinRequest.getDistribution().getSkewedStageName());
      // Apply salt/explode transformations to each Dataset
      if (isLeftStageSkewed) {
        left = saltDataset(left, saltColumn, joinRequest.getDistribution().getDistributionFactor());
        right = explodeDataset(right, saltColumn, joinRequest.getDistribution().getDistributionFactor());
      } else {
        left = explodeDataset(left, saltColumn, joinRequest.getDistribution().getDistributionFactor());
        right = saltDataset(right, saltColumn, joinRequest.getDistribution().getDistributionFactor());
      }

      // Add the salt column to the join key
      leftJoinColumns.add(left.col(saltColumn));
      rightJoinColumns.add(right.col(saltColumn));

      // Update other values that will be used later in the join
      joined = left;
      sparkSchema = sparkSchema.add(saltColumn, DataTypes.IntegerType, false);
      leftSparkSchema = leftSparkSchema.add(saltColumn, DataTypes.IntegerType, false);
    }

    Column joinOn;
    // Make effectively final to use in streams
    List<Column> finalLeftJoinColumns = leftJoinColumns;
    if (seenRequired) {
      joinOn = IntStream.range(0, leftJoinColumns.size())
        .mapToObj(i -> eq(finalLeftJoinColumns.get(i), rightJoinColumns.get(i), joinRequest.isNullSafe()))
        .reduce((a, b) -> a.and(b))
        .get();
    } else {
      // For the case when all joins so far are outer joins: collect the left keys at each level (each iteration),
      // coalesce these keys at each level, and compare the result with the right keys
      joinOn = IntStream.range(0, leftJoinColumns.size())
        .mapToObj(i -> {
          collectLeftJoinOnCols(listOfListOfLeftCols, i, finalLeftJoinColumns.get(i));
          return eq(getLeftJoinOnCoalescedColumn(finalLeftJoinColumns.get(i), i, listOfListOfLeftCols),
                    rightJoinColumns.get(i), joinRequest.isNullSafe());
        })
        .reduce((a, b) -> a.and(b))
        .get();
    }

    String joinType;
    if (seenRequired && toJoin.isRequired()) {
      joinType = "inner";
    } else if (seenRequired && !toJoin.isRequired()) {
      joinType = "leftouter";
    } else if (!seenRequired && toJoin.isRequired()) {
      joinType = "rightouter";
    } else {
      joinType = "outer";
    }
    seenRequired = seenRequired || toJoin.isRequired();

    if (toJoin.isBroadcast()) {
      right = functions.broadcast(right);
    }

    // Repartition on the join key, unless spark.cdap.pipeline.aggregate.dataset.partitions.ignore = true tells us
    // to skip it, no explicit number of partitions was requested, or the right side is broadcast
    if (!ignorePartitionsDuringDatasetAggregation && joinPartitions != null && !toJoin.isBroadcast()) {
      List<String> rightKeys = new ArrayList<>(toJoin.getKey());
      List<String> leftKeys = new ArrayList<>(joinRequest.getLeftKey());
      // When distribution is enabled, the salt column is part of the partition key
      if (joinRequest.isDistributionEnabled()) {
        rightKeys.add(saltColumn);
        leftKeys.add(saltColumn);
      }
      right = partitionOnKey(right, rightKeys, joinRequest.isNullSafe(), sparkSchema, joinPartitions);
      // Only repartition the left side on the first join, as intermediate joins will already be partitioned on the key
      if (joined == left) {
        joined = partitionOnKey(joined, leftKeys, joinRequest.isNullSafe(), leftSparkSchema, joinPartitions);
      }
    }

    joined = joined.join(right, joinOn, joinType);
    /*
       Additionally, if none of the datasets seen so far are required, all of the joins so far are outer joins.
       In that case we also need to pass on the join columns, since the join condition compares a coalesce of all
       previous left columns with the right dataset.
     */
    if (toJoin.isRequired() || !seenRequired) {
      leftJoinColumns = rightJoinColumns;
    }
  }

  // select and alias fields in the expected order
  List<Column> outputColumns = new ArrayList<>(joinRequest.getFields().size());
  for (JoinField field : joinRequest.getFields()) {
    Column column = collections.get(field.getStageName()).col(field.getFieldName());
    if (field.getAlias() != null) {
      column = column.alias(field.getAlias());
    }
    outputColumns.add(column);
  }

  Seq<Column> outputColumnSeq = JavaConversions.asScalaBuffer(outputColumns).toSeq();
  joined = joined.select(outputColumnSeq);
  Schema outputSchema = joinRequest.getOutputSchema();
  JavaRDD<StructuredRecord> output = joined.javaRDD()
    .map(r -> DataFrames.fromRow(r, outputSchema))
    .map(new CountingFunction<>(stageName, sec.getMetrics(), Constants.Metrics.RECORDS_OUT,
                                sec.getDataTracer(stageName)));
  return (SparkCollection<T>) wrap(output);
}
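The join-type selection described in the block comment above can be illustrated in isolation. The following is a minimal, self-contained sketch (a hypothetical standalone class, not CDAP code) that replays the A (optional), B (required), C (optional), D (required) scenario and prints the join type chosen at each step, using the same seenRequired logic as the method above.

import java.util.Arrays;
import java.util.List;

// Illustration only: mirrors the join-type selection in RDDCollection.join, but is not part of CDAP.
public class JoinTypeSelection {

  static String selectJoinType(boolean seenRequired, boolean rightRequired) {
    if (seenRequired && rightRequired) {
      return "inner";
    } else if (seenRequired) {
      return "leftouter";
    } else if (rightRequired) {
      return "rightouter";
    }
    return "outer";
  }

  public static void main(String[] args) {
    // A (optional) is the left side; B, C, D are folded in one at a time
    List<String> names = Arrays.asList("B", "C", "D");
    List<Boolean> required = Arrays.asList(true, false, true);

    boolean seenRequired = false; // A, the left side, is optional
    for (int i = 0; i < names.size(); i++) {
      String joinType = selectJoinType(seenRequired, required.get(i));
      System.out.println("join with " + names.get(i) + ": " + joinType);
      seenRequired = seenRequired || required.get(i);
    }
    // Prints: rightouter, leftouter, inner -- matching joins #1, #2, #3 described in the comment above
  }
}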
use of io.cdap.cdap.etl.spark.SparkCollection in project cdap by cdapio.
the class BatchSparkPipelineDriverTest method testSQLEngineDoesNotSupportJoin.
@Test
public void testSQLEngineDoesNotSupportJoin() {
  when(adapter.canJoin(anyString(), any(JoinDefinition.class))).thenReturn(false);
  List<JoinStage> noneBroadcast = Arrays.asList(
    JoinStage.builder("a", null).setBroadcast(false).build(),
    JoinStage.builder("b", null).setBroadcast(false).build(),
    JoinStage.builder("c", null).setBroadcast(false).build());
  JoinDefinition joinDefinition = mock(JoinDefinition.class);
  doReturn(noneBroadcast).when(joinDefinition).getStages();

  Map<String, SparkCollection<Object>> collections = new HashMap<>();
  collections.put("a", mock(RDDCollection.class));
  collections.put("b", mock(RDDCollection.class));
  collections.put("c", mock(RDDCollection.class));

  Assert.assertFalse(driver.canJoinOnSQLEngine(STAGE_NAME, joinDefinition, collections));
}
use of io.cdap.cdap.etl.spark.SparkCollection in project cdap by caskdata.
the class BatchSQLEngineAdapter method tryRelationalTransform.
/**
 * This method is called when an engine is present and willing to try performing a relational transform.
 *
 * @param stageSpec stage specification
 * @param transform transform plugin
 * @param input     input collections
 * @return the resulting collection, or an empty optional if the transform can't be done with this engine
 */
public Optional<SQLEngineJob<SQLDataset>> tryRelationalTransform(StageSpec stageSpec,
                                                                 RelationalTransform transform,
                                                                 Map<String, SparkCollection<Object>> input) {
  String stageName = stageSpec.getName();
  Map<String, Relation> inputRelations = input.entrySet().stream()
    .collect(Collectors.toMap(
      Map.Entry::getKey,
      e -> sqlEngine.getRelation(new SQLRelationDefinition(e.getKey(), stageSpec.getInputSchemas().get(e.getKey())))));
  BasicRelationalTransformContext pluginContext = new BasicRelationalTransformContext(
    getSQLRelationalEngine(), inputRelations, stageSpec.getInputSchemas(), stageSpec.getOutputSchema());

  if (!transform.transform(pluginContext)) {
    // The plugin was not able to do a relational transform with this engine
    return Optional.empty();
  }
  if (pluginContext.getOutputRelation() == null) {
    // The plugin reported a successful transformation but did not set an output
    throw new IllegalStateException("Plugin " + transform + " did not produce a relational output");
  }
  if (!pluginContext.getOutputRelation().isValid()) {
    // The output is an invalid relation, most likely because some of the transforms are not supported by the engine
    return Optional.empty();
  }

  // Ensure the input and output schemas for this stage are supported by the engine
  if (stageSpec.getInputSchemas().values().stream().anyMatch(s -> !sqlEngine.supportsInputSchema(s))) {
    return Optional.empty();
  }
  if (!sqlEngine.supportsOutputSchema(stageSpec.getOutputSchema())) {
    return Optional.empty();
  }

  // Validate the transformation definition with the engine
  SQLTransformDefinition transformDefinition = new SQLTransformDefinition(
    stageName, pluginContext.getOutputRelation(), stageSpec.getOutputSchema(),
    Collections.emptyMap(), Collections.emptyMap());
  if (!sqlEngine.canTransform(transformDefinition)) {
    return Optional.empty();
  }

  return Optional.of(runJob(stageSpec.getName(), SQLEngineJobType.EXECUTE, () -> {
    // Push all stages that need to be pushed in order to execute this transformation
    input.forEach((name, collection) -> {
      if (!exists(name)) {
        push(name, stageSpec.getInputSchemas().get(name), collection);
      }
    });

    // Initialize the metrics collector
    DefaultStageMetrics stageMetrics = new DefaultStageMetrics(metrics, stageName);
    StageStatisticsCollector statisticsCollector = statsCollectors.get(stageName);

    // Collect the input datasets and execute the transformation
    Map<String, SQLDataset> inputDatasets = input.keySet().stream()
      .collect(Collectors.toMap(Function.identity(), this::getDatasetForStage));

    // Count input records
    for (SQLDataset inputDataset : inputDatasets.values()) {
      countRecordsIn(inputDataset, statisticsCollector, stageMetrics);
    }

    // Execute the transform
    SQLTransformRequest sqlContext = new SQLTransformRequest(
      inputDatasets, stageSpec.getName(), pluginContext.getOutputRelation(), stageSpec.getOutputSchema());
    SQLDataset transformed = sqlEngine.transform(sqlContext);

    // Count output records
    countRecordsOut(transformed, statisticsCollector, stageMetrics);
    return transformed;
  }));
}
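A caller would typically treat an empty Optional as a signal to fall back to regular Spark execution. A minimal sketch of that pattern follows (imports omitted; adapter, stageSpec, transform, input, and the two helper methods are placeholder names, not taken from the CDAP source).

// Hypothetical call site; names and helpers are placeholders, not CDAP code.
Optional<SQLEngineJob<SQLDataset>> job = adapter.tryRelationalTransform(stageSpec, transform, input);
if (job.isPresent()) {
  // The engine accepted the transform; the returned job runs the stage on the SQL engine.
  useEngineResult(job.get());   // hypothetical helper
} else {
  // The engine declined (unsupported schema, invalid output relation, or unsupported
  // transform definition), so the stage is executed in Spark instead.
  runStageInSpark(stageSpec);   // hypothetical helper
}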
use of io.cdap.cdap.etl.spark.SparkCollection in project cdap by caskdata.
the class BatchSparkPipelineDriverTest method testShouldNotJoinOnSQLEngineWithBroadcast.
@Test
public void testShouldNotJoinOnSQLEngineWithBroadcast() {
  // stage 'c' is marked as broadcast, which disqualifies the join from running on the SQL engine
  List<JoinStage> noneBroadcast = Arrays.asList(
    JoinStage.builder("a", null).setBroadcast(false).build(),
    JoinStage.builder("b", null).setBroadcast(false).build(),
    JoinStage.builder("c", null).setBroadcast(true).build());
  JoinDefinition joinDefinition = mock(JoinDefinition.class);
  doReturn(noneBroadcast).when(joinDefinition).getStages();

  Map<String, SparkCollection<Object>> collections = new HashMap<>();
  collections.put("a", mock(RDDCollection.class));
  collections.put("b", mock(RDDCollection.class));
  collections.put("c", mock(RDDCollection.class));

  Assert.assertFalse(driver.canJoinOnSQLEngine(STAGE_NAME, joinDefinition, collections));
}
use of io.cdap.cdap.etl.spark.SparkCollection in project cdap by caskdata.
the class BatchSparkPipelineDriverTest method testShouldJoinOnSQLEngineWithoutBroadcast.
@Test
public void testShouldJoinOnSQLEngineWithoutBroadcast() {
  List<JoinStage> noneBroadcast = Arrays.asList(
    JoinStage.builder("a", null).setBroadcast(false).build(),
    JoinStage.builder("b", null).setBroadcast(false).build(),
    JoinStage.builder("c", null).setBroadcast(false).build());
  JoinDefinition joinDefinition = mock(JoinDefinition.class);
  doReturn(noneBroadcast).when(joinDefinition).getStages();

  Map<String, SparkCollection<Object>> collections = new HashMap<>();
  collections.put("a", mock(RDDCollection.class));
  collections.put("b", mock(RDDCollection.class));
  collections.put("c", mock(RDDCollection.class));

  Assert.assertTrue(driver.canJoinOnSQLEngine(STAGE_NAME, joinDefinition, collections));
}
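Taken together, these tests imply that canJoinOnSQLEngine rejects a join whenever any participating stage is marked as broadcast, in addition to deferring to the engine's own canJoin check. A minimal sketch of that broadcast gate, assuming JoinStage exposes an isBroadcast() accessor matching its builder's setBroadcast, could look like the following. This is an illustration, not the actual BatchSparkPipelineDriver implementation.

// Sketch of the broadcast check implied by the tests above (illustrative helper, not CDAP code).
static boolean hasBroadcastStage(JoinDefinition joinDefinition) {
  return joinDefinition.getStages().stream().anyMatch(JoinStage::isBroadcast);
}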