
Example 1 with JoinStage

Use of io.cdap.cdap.etl.api.join.JoinStage in project cdap by caskdata.

From class BatchSparkPipelineDriverTest, method testShouldJoinOnSQLEngineWithBroadcastAndAlreadyPushedCollection.

@Test
public void testShouldJoinOnSQLEngineWithBroadcastAndAlreadyPushedCollection() {
    // Stage "c" is broadcast, and stage "a" has already been pushed to the SQL engine.
    List<JoinStage> stages = Arrays.asList(
        JoinStage.builder("a", null).setBroadcast(false).build(),
        JoinStage.builder("b", null).setBroadcast(false).build(),
        JoinStage.builder("c", null).setBroadcast(true).build());
    JoinDefinition joinDefinition = mock(JoinDefinition.class);
    doReturn(stages).when(joinDefinition).getStages();
    Map<String, SparkCollection<Object>> collections = new HashMap<>();
    collections.put("a", mock(SQLEngineCollection.class));
    collections.put("b", mock(RDDCollection.class));
    collections.put("c", mock(RDDCollection.class));
    Assert.assertTrue(driver.canJoinOnSQLEngine(STAGE_NAME, joinDefinition, collections));
}
Also used: SparkCollection(io.cdap.cdap.etl.spark.SparkCollection) JoinStage(io.cdap.cdap.etl.api.join.JoinStage) JoinDefinition(io.cdap.cdap.etl.api.join.JoinDefinition) HashMap(java.util.HashMap) Matchers.anyString(org.mockito.Matchers.anyString) Test(org.junit.Test)

Example 2 with JoinStage

Use of io.cdap.cdap.etl.api.join.JoinStage in project cdap by caskdata.

From class BatchSparkPipelineDriverTest, method testSQLEngineDoesNotSupportJoin.

@Test
public void testSQLEngineDoesNotSupportJoin() {
    when(adapter.canJoin(anyString(), any(JoinDefinition.class))).thenReturn(false);
    List<JoinStage> noneBroadcast = Arrays.asList(
        JoinStage.builder("a", null).setBroadcast(false).build(),
        JoinStage.builder("b", null).setBroadcast(false).build(),
        JoinStage.builder("c", null).setBroadcast(false).build());
    JoinDefinition joinDefinition = mock(JoinDefinition.class);
    doReturn(noneBroadcast).when(joinDefinition).getStages();
    Map<String, SparkCollection<Object>> collections = new HashMap<>();
    collections.put("a", mock(RDDCollection.class));
    collections.put("b", mock(RDDCollection.class));
    collections.put("c", mock(RDDCollection.class));
    Assert.assertFalse(driver.canJoinOnSQLEngine(STAGE_NAME, joinDefinition, collections));
}
Also used: SparkCollection(io.cdap.cdap.etl.spark.SparkCollection) JoinStage(io.cdap.cdap.etl.api.join.JoinStage) JoinDefinition(io.cdap.cdap.etl.api.join.JoinDefinition) HashMap(java.util.HashMap) Matchers.anyString(org.mockito.Matchers.anyString) Test(org.junit.Test)
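
Read together, these two tests pin down the observable contract of canJoinOnSQLEngine without showing its body. The following is a minimal sketch consistent with both assertions; it is a hypothetical reconstruction, not the actual CDAP implementation, and it assumes the test fixture stubs adapter.canJoin to return true unless a test overrides it (as the second test does):

// Hypothetical reconstruction of canJoinOnSQLEngine, inferred only from the two
// tests above; the real BatchSparkPipelineDriver implementation may differ.
private boolean canJoinOnSQLEngineSketch(String stageName, JoinDefinition joinDefinition,
                                         Map<String, SparkCollection<Object>> collections) {
    boolean hasBroadcast = joinDefinition.getStages().stream()
        .anyMatch(JoinStage::isBroadcast);
    boolean alreadyPushed = joinDefinition.getStages().stream()
        .anyMatch(s -> collections.get(s.getStageName()) instanceof SQLEngineCollection);
    // Example 1: broadcast stages are tolerated when some input is already on the engine.
    if (hasBroadcast && !alreadyPushed) {
        return false;
    }
    // Example 2: otherwise the decision is delegated to the engine adapter.
    return adapter.canJoin(stageName, joinDefinition);
}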

Example 3 with JoinStage

Use of io.cdap.cdap.etl.api.join.JoinStage in project cdap by caskdata.

From class SparkPipelineRunner, method handleAutoJoinOnKeys.

/**
 * The purpose of this method is to collect various pieces of information together into a JoinRequest.
 * This amounts to gathering the SparkCollection, schema, join key, and join type for each stage involved in the join.
 */
private SparkCollection<Object> handleAutoJoinOnKeys(String stageName, JoinDefinition joinDefinition, Map<String, SparkCollection<Object>> inputDataCollections, @Nullable Integer numPartitions) {
    // sort stages to join so that broadcasts happen last. This is to ensure that the left side is not a broadcast
    // so that we don't try to broadcast both sides of the join. It also causes less data to be shuffled for the
    // non-broadcast joins.
    List<JoinStage> joinOrder = new ArrayList<>(joinDefinition.getStages());
    joinOrder.sort((s1, s2) -> {
        if (s1.isBroadcast() && !s2.isBroadcast()) {
            return 1;
        } else if (!s1.isBroadcast() && s2.isBroadcast()) {
            return -1;
        }
        return 0;
    });
    Iterator<JoinStage> stageIter = joinOrder.iterator();
    JoinStage left = stageIter.next();
    String leftName = left.getStageName();
    SparkCollection<Object> leftCollection = inputDataCollections.get(left.getStageName());
    Schema leftSchema = left.getSchema();
    JoinCondition condition = joinDefinition.getCondition();
    JoinCondition.OnKeys onKeys = (JoinCondition.OnKeys) condition;
    // If this is a join on A.x = B.y = C.z and A.k = B.k = C.k, then stageKeys will look like:
    // A -> [x, k]
    // B -> [y, k]
    // C -> [z, k]
    Map<String, List<String>> stageKeys = onKeys.getKeys().stream().collect(Collectors.toMap(JoinKey::getStageName, JoinKey::getFields));
    Schema outputSchema = joinDefinition.getOutputSchema();
    // when we properly propagate schema at runtime, this condition should no longer happen
    if (outputSchema == null) {
        throw new IllegalArgumentException(String.format(
            "Joiner stage '%s' cannot calculate its output schema because "
                + "one or more inputs have dynamic or unknown schema. "
                + "An output schema must be directly provided.", stageName));
    }
    List<JoinCollection> toJoin = new ArrayList<>();
    List<Schema> keySchema = null;
    while (stageIter.hasNext()) {
        // in this loop, information for each stage to be joined is gathered together into a JoinCollection
        JoinStage right = stageIter.next();
        String rightName = right.getStageName();
        Schema rightSchema = right.getSchema();
        List<String> key = stageKeys.get(rightName);
        if (rightSchema == null) {
            if (keySchema == null) {
                keySchema = deriveKeySchema(stageName, stageKeys, joinDefinition);
            }
            // if the schema is not known, generate it from the provided output schema and the selected fields
            rightSchema = deriveInputSchema(stageName, rightName, key, keySchema, joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema());
        } else {
            // drop fields that aren't included in the final output
            // don't need to do this if the schema was derived, since it will already only contain
            // fields in the output schema
            Set<String> requiredFields = new HashSet<>(key);
            for (JoinField joinField : joinDefinition.getSelectedFields()) {
                if (!joinField.getStageName().equals(rightName)) {
                    continue;
                }
                requiredFields.add(joinField.getFieldName());
            }
            rightSchema = trimSchema(rightSchema, requiredFields);
        }
        // JoinCollection contains the stage name, SparkCollection, schema, join key,
        // whether it's required, and whether to broadcast
        toJoin.add(new JoinCollection(rightName, inputDataCollections.get(rightName), rightSchema,
                                      key, right.isRequired(), right.isBroadcast()));
    }
    List<String> leftKey = stageKeys.get(leftName);
    if (leftSchema == null) {
        if (keySchema == null) {
            keySchema = deriveKeySchema(stageName, stageKeys, joinDefinition);
        }
        leftSchema = deriveInputSchema(stageName, leftName, leftKey, keySchema, joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema());
    }
    // JoinRequest contains the left side of the join, plus 1 or more other stages to join to.
    JoinRequest joinRequest = new JoinRequest(
        stageName, leftName, leftKey, leftSchema, left.isRequired(), onKeys.isNullSafe(),
        joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema(), toJoin,
        numPartitions, joinDefinition.getDistribution(), joinDefinition);
    return leftCollection.join(joinRequest);
}
Also used: JoinStage(io.cdap.cdap.etl.api.join.JoinStage) Schema(io.cdap.cdap.api.data.schema.Schema) ArrayList(java.util.ArrayList) JoinField(io.cdap.cdap.etl.api.join.JoinField) JoinRequest(io.cdap.cdap.etl.spark.join.JoinRequest) JoinCondition(io.cdap.cdap.etl.api.join.JoinCondition) List(java.util.List) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) HashSet(java.util.HashSet)
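
The broadcast-last ordering at the top of handleAutoJoinOnKeys can be checked in isolation. Here is a minimal standalone sketch, using only the JoinStage builder calls already confirmed by the tests in Examples 1 and 2 (the stage names are invented):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import io.cdap.cdap.etl.api.join.JoinStage;

List<JoinStage> stages = new ArrayList<>(Arrays.asList(
    JoinStage.builder("broadcastSide", null).setBroadcast(true).build(),
    JoinStage.builder("bigSide", null).setBroadcast(false).build()));
// Comparing on the boxed boolean sorts false before true, i.e. broadcast stages
// last, which is the same ordering as the hand-written comparator above.
stages.sort(Comparator.comparing(JoinStage::isBroadcast));
// stages.get(0) is now "bigSide": the left (driving) side of the join is never
// a broadcast stage, so both sides of a join can't end up broadcast.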

Example 4 with JoinStage

Use of io.cdap.cdap.etl.api.join.JoinStage in project cdap by caskdata.

From class BatchSparkPipelineDriver, method handleAutoJoin.

@Override
@SuppressWarnings("unchecked")
protected SparkCollection<Object> handleAutoJoin(String stageName, JoinDefinition joinDefinition, Map<String, SparkCollection<Object>> inputDataCollections, @Nullable Integer numPartitions) {
    if (sqlEngineAdapter != null && canJoinOnSQLEngine(stageName, joinDefinition, inputDataCollections)) {
        // Push all join inputs to the SQL engine, skipping collections that already
        // represent data on the engine.
        for (JoinStage joinStage : joinDefinition.getStages()) {
            String joinStageName = joinStage.getStageName();
            // If the input collection is already a SQL Engine collection, there's no need to push.
            if (inputDataCollections.get(joinStageName) instanceof SQLBackedCollection) {
                continue;
            }
            SparkCollection<Object> collection = inputDataCollections.get(joinStage.getStageName());
            SQLEngineJob<SQLDataset> pushJob = sqlEngineAdapter.push(joinStageName, joinStage.getSchema(), collection);
            inputDataCollections.put(joinStageName,
                new SQLEngineCollection<>(sec, functionCacheFactory, jsc, new SQLContext(jsc),
                                          datasetContext, sinkFactory, collection, joinStageName,
                                          sqlEngineAdapter, pushJob));
        }
    }
    return super.handleAutoJoin(stageName, joinDefinition, inputDataCollections, numPartitions);
}
Also used: JoinStage(io.cdap.cdap.etl.api.join.JoinStage) SQLDataset(io.cdap.cdap.etl.api.engine.sql.dataset.SQLDataset) SQLContext(org.apache.spark.sql.SQLContext)
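
The effect of this override is to rewrite inputDataCollections in place before delegating: every plain RDD-backed input is replaced with a SQLEngineCollection wrapping the pending push job, so by the time super.handleAutoJoin runs, all join inputs are engine-backed and the join itself can execute inside the SQL engine rather than in Spark.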

Example 5 with JoinStage

Use of io.cdap.cdap.etl.api.join.JoinStage in project cdap by caskdata.

From class JoinOnFunction, method createInitializedJoinOnTransform.

private JoinOnTransform<INPUT_RECORD, JOIN_KEY> createInitializedJoinOnTransform() throws Exception {
    Object plugin = pluginFunctionContext.createPlugin();
    BatchJoiner<JOIN_KEY, INPUT_RECORD, Object> joiner;
    boolean filterNullKeys = false;
    if (plugin instanceof BatchAutoJoiner) {
        BatchAutoJoiner autoJoiner = (BatchAutoJoiner) plugin;
        AutoJoinerContext autoJoinerContext = pluginFunctionContext.createAutoJoinerContext();
        JoinDefinition joinDefinition = autoJoiner.define(autoJoinerContext);
        autoJoinerContext.getFailureCollector().getOrThrowException();
        String stageName = pluginFunctionContext.getStageName();
        if (joinDefinition == null) {
            throw new IllegalStateException(String.format(
                "Join stage '%s' did not specify a join definition. "
                    + "Check with the plugin developer to ensure it is implemented correctly.",
                stageName));
        }
        JoinCondition condition = joinDefinition.getCondition();
        /*
         Filter out the record if it comes from an optional stage
         and the key is null, or if any of the fields in the key is null.
         For example, suppose we are performing a left outer join on:

          A (id, name) = (0, alice), (null, bob)
          B (id, email) = (0, alice@example.com), (null, placeholder@example.com)

         The final output should be:

          joined (A.id, A.name, B.email) = (0, alice, alice@example.com), (null, bob, null)

         That is, the bob record should not be joined to the placeholder@example.com record,
         even though both their ids are null.
        */
        if (condition.getOp() == JoinCondition.Op.KEY_EQUALITY && !((JoinCondition.OnKeys) condition).isNullSafe()) {
            filterNullKeys = joinDefinition.getStages().stream()
                .filter(s -> !s.isRequired())
                .map(JoinStage::getStageName)
                .anyMatch(s -> s.equals(inputStageName));
        }
        joiner = new JoinerBridge(stageName, autoJoiner, joinDefinition);
    } else {
        joiner = (BatchJoiner<JOIN_KEY, INPUT_RECORD, Object>) plugin;
        BatchJoinerRuntimeContext context = pluginFunctionContext.createBatchRuntimeContext();
        joiner.initialize(context);
    }
    return new JoinOnTransform<>(joiner, inputStageName, filterNullKeys);
}
Also used: BatchJoiner(io.cdap.cdap.etl.api.batch.BatchJoiner) Transformation(io.cdap.cdap.etl.api.Transformation) BatchJoinerRuntimeContext(io.cdap.cdap.etl.api.batch.BatchJoinerRuntimeContext) PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) Iterator(java.util.Iterator) JoinStage(io.cdap.cdap.etl.api.join.JoinStage) JoinerBridge(io.cdap.cdap.etl.common.plugin.JoinerBridge) Tuple2(scala.Tuple2) Schema(io.cdap.cdap.api.data.schema.Schema) Constants(io.cdap.cdap.etl.common.Constants) StructuredRecord(io.cdap.cdap.api.data.format.StructuredRecord) TrackedTransform(io.cdap.cdap.etl.common.TrackedTransform) Emitter(io.cdap.cdap.etl.api.Emitter) DefaultEmitter(io.cdap.cdap.etl.common.DefaultEmitter) JoinDefinition(io.cdap.cdap.etl.api.join.JoinDefinition) AutoJoinerContext(io.cdap.cdap.etl.api.join.AutoJoinerContext) JoinCondition(io.cdap.cdap.etl.api.join.JoinCondition) BatchAutoJoiner(io.cdap.cdap.etl.api.batch.BatchAutoJoiner)
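
The block comment in this example states the filtering rule in prose. A hypothetical standalone predicate restating it (the helper name and shape are invented for illustration; in the real code the check happens inside the Spark transform using the filterNullKeys flag computed above):

import java.util.List;
import java.util.Objects;

// Hypothetical restatement of the rule described in the comment above: for a
// non-null-safe key-equality join, a record from an optional (outer) stage is
// dropped when any field of its join key is null; records from required stages,
// and all records in null-safe joins, are kept.
static boolean dropRecord(boolean stageIsRequired, boolean nullSafe, List<Object> keyValues) {
    if (stageIsRequired || nullSafe) {
        return false;
    }
    return keyValues.stream().anyMatch(Objects::isNull);
}

Applied to the comment's example, bob's record (from required stage A) is kept while the placeholder record (from optional stage B) is dropped, yielding (null, bob, null) rather than a spurious match on null ids.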

Aggregations

JoinStage (io.cdap.cdap.etl.api.join.JoinStage): 9 uses
JoinDefinition (io.cdap.cdap.etl.api.join.JoinDefinition): 6 uses
JoinCondition (io.cdap.cdap.etl.api.join.JoinCondition): 4 uses
SparkCollection (io.cdap.cdap.etl.spark.SparkCollection): 4 uses
HashMap (java.util.HashMap): 4 uses
Test (org.junit.Test): 4 uses
Matchers.anyString (org.mockito.Matchers.anyString): 4 uses
Schema (io.cdap.cdap.api.data.schema.Schema): 3 uses
JoinField (io.cdap.cdap.etl.api.join.JoinField): 2 uses
JoinCollection (io.cdap.cdap.etl.spark.join.JoinCollection): 2 uses
ArrayList (java.util.ArrayList): 2 uses
HashSet (java.util.HashSet): 2 uses
StructuredRecord (io.cdap.cdap.api.data.format.StructuredRecord): 1 use
Emitter (io.cdap.cdap.etl.api.Emitter): 1 use
Transformation (io.cdap.cdap.etl.api.Transformation): 1 use
BatchAutoJoiner (io.cdap.cdap.etl.api.batch.BatchAutoJoiner): 1 use
BatchJoiner (io.cdap.cdap.etl.api.batch.BatchJoiner): 1 use
BatchJoinerRuntimeContext (io.cdap.cdap.etl.api.batch.BatchJoinerRuntimeContext): 1 use
SQLDataset (io.cdap.cdap.etl.api.engine.sql.dataset.SQLDataset): 1 use
AutoJoinerContext (io.cdap.cdap.etl.api.join.AutoJoinerContext): 1 use