Example 6 with JoinStage

use of io.cdap.cdap.etl.api.join.JoinStage in project cdap by cdapio.

the class SparkPipelineRunner method handleAutoJoinOnKeys.

/**
 * The purpose of this method is to collect various pieces of information together into a JoinRequest.
 * This amounts to gathering the SparkCollection, schema, join key, and join type for each stage involved in the join.
 */
private SparkCollection<Object> handleAutoJoinOnKeys(String stageName, JoinDefinition joinDefinition,
                                                     Map<String, SparkCollection<Object>> inputDataCollections,
                                                     @Nullable Integer numPartitions) {
    // sort stages to join so that broadcasts happen last. This is to ensure that the left side is not a broadcast
    // so that we don't try to broadcast both sides of the join. It also causes less data to be shuffled for the
    // non-broadcast joins.
    List<JoinStage> joinOrder = new ArrayList<>(joinDefinition.getStages());
    joinOrder.sort((s1, s2) -> {
        if (s1.isBroadcast() && !s2.isBroadcast()) {
            return 1;
        } else if (!s1.isBroadcast() && s2.isBroadcast()) {
            return -1;
        }
        return 0;
    });
    Iterator<JoinStage> stageIter = joinOrder.iterator();
    JoinStage left = stageIter.next();
    String leftName = left.getStageName();
    SparkCollection<Object> leftCollection = inputDataCollections.get(left.getStageName());
    Schema leftSchema = left.getSchema();
    JoinCondition condition = joinDefinition.getCondition();
    JoinCondition.OnKeys onKeys = (JoinCondition.OnKeys) condition;
    // If this is a join on A.x = B.y = C.z and A.k = B.k = C.k, then stageKeys will look like:
    // A -> [x, k]
    // B -> [y, k]
    // C -> [z, k]
    Map<String, List<String>> stageKeys = onKeys.getKeys().stream().collect(Collectors.toMap(JoinKey::getStageName, JoinKey::getFields));
    Schema outputSchema = joinDefinition.getOutputSchema();
    // when we properly propagate schema at runtime, this condition should no longer happen
    if (outputSchema == null) {
        throw new IllegalArgumentException(String.format(
            "Joiner stage '%s' cannot calculate its output schema because "
                + "one or more inputs have dynamic or unknown schema. "
                + "An output schema must be directly provided.", stageName));
    }
    List<JoinCollection> toJoin = new ArrayList<>();
    List<Schema> keySchema = null;
    while (stageIter.hasNext()) {
        // in this loop, information for each stage to be joined is gathered together into a JoinCollection
        JoinStage right = stageIter.next();
        String rightName = right.getStageName();
        Schema rightSchema = right.getSchema();
        List<String> key = stageKeys.get(rightName);
        if (rightSchema == null) {
            if (keySchema == null) {
                keySchema = deriveKeySchema(stageName, stageKeys, joinDefinition);
            }
            // if the schema is not known, generate it from the provided output schema and the selected fields
            rightSchema = deriveInputSchema(stageName, rightName, key, keySchema, joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema());
        } else {
            // drop fields that aren't included in the final output
            // don't need to do this if the schema was derived, since it will already only contain
            // fields in the output schema
            Set<String> requiredFields = new HashSet<>(key);
            for (JoinField joinField : joinDefinition.getSelectedFields()) {
                if (!joinField.getStageName().equals(rightName)) {
                    continue;
                }
                requiredFields.add(joinField.getFieldName());
            }
            rightSchema = trimSchema(rightSchema, requiredFields);
        }
        // JoinCollection contains the stage name, SparkCollection, schema, join key,
        // whether it's required, and whether to broadcast
        toJoin.add(new JoinCollection(rightName, inputDataCollections.get(rightName), rightSchema, key, right.isRequired(), right.isBroadcast()));
    }
    List<String> leftKey = stageKeys.get(leftName);
    if (leftSchema == null) {
        if (keySchema == null) {
            keySchema = deriveKeySchema(stageName, stageKeys, joinDefinition);
        }
        leftSchema = deriveInputSchema(stageName, leftName, leftKey, keySchema, joinDefinition.getSelectedFields(), joinDefinition.getOutputSchema());
    }
    // JoinRequest contains the left side of the join, plus 1 or more other stages to join to.
    JoinRequest joinRequest = new JoinRequest(stageName, leftName, leftKey, leftSchema,
        left.isRequired(), onKeys.isNullSafe(), joinDefinition.getSelectedFields(),
        joinDefinition.getOutputSchema(), toJoin, numPartitions,
        joinDefinition.getDistribution(), joinDefinition);
    return leftCollection.join(joinRequest);
}
Also used : JoinStage(io.cdap.cdap.etl.api.join.JoinStage) Schema(io.cdap.cdap.api.data.schema.Schema) ArrayList(java.util.ArrayList) JoinField(io.cdap.cdap.etl.api.join.JoinField) JoinRequest(io.cdap.cdap.etl.spark.join.JoinRequest) JoinCondition(io.cdap.cdap.etl.api.join.JoinCondition) List(java.util.List) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) HashSet(java.util.HashSet)
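
The broadcast-last ordering at the top of the method is the subtle part, so here is a minimal, self-contained sketch of just that sort. SimpleStage is a hypothetical stand-in for JoinStage; only the comparator mirrors the code above.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BroadcastOrderSketch {

    // Hypothetical stand-in for JoinStage: just a stage name and a broadcast flag.
    static final class SimpleStage {
        final String name;
        final boolean broadcast;
        SimpleStage(String name, boolean broadcast) {
            this.name = name;
            this.broadcast = broadcast;
        }
    }

    public static void main(String[] args) {
        List<SimpleStage> stages = new ArrayList<>(Arrays.asList(
            new SimpleStage("A", true),
            new SimpleStage("B", false),
            new SimpleStage("C", true)));
        // Same rule as handleAutoJoinOnKeys: broadcast stages sort last, so the first
        // (left) stage is never a broadcast as long as one non-broadcast stage exists.
        stages.sort((s1, s2) -> Boolean.compare(s1.broadcast, s2.broadcast));
        // Prints B, A, C: List.sort is stable, so tied stages keep their input order.
        stages.forEach(s -> System.out.println(s.name + " broadcast=" + s.broadcast));
    }
}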

Example 7 with JoinStage

use of io.cdap.cdap.etl.api.join.JoinStage in project cdap by cdapio.

the class SparkPipelineRunner method handleAutoJoinWithSQL.

/*
      Implement a join by generating a SQL query that Spark will execute.
      Joins on key equality are not implemented this way because they have special repartitioning
      that allows them to specify a different number of partitions for different joins in the same pipeline.

      When Spark handles SQL queries, it uses spark.sql.shuffle.partitions number of partitions, which is a global
      setting that applies to any SQL join in the pipeline.
   */
private SparkCollection<Object> handleAutoJoinWithSQL(String stageName, JoinDefinition joinDefinition, Map<String, SparkCollection<Object>> inputDataCollections) {
    JoinCondition.OnExpression condition = (JoinCondition.OnExpression) joinDefinition.getCondition();
    Map<String, String> aliases = condition.getDatasetAliases();
    // earlier validation ensures there are exactly 2 inputs being joined
    JoinStage leftStage = joinDefinition.getStages().get(0);
    JoinStage rightStage = joinDefinition.getStages().get(1);
    String leftStageName = leftStage.getStageName();
    String rightStageName = rightStage.getStageName();
    SparkCollection<Object> leftData = inputDataCollections.get(leftStageName);
    JoinCollection leftCollection = new JoinCollection(leftStageName, inputDataCollections.get(leftStageName),
        leftStage.getSchema(), Collections.emptyList(), leftStage.isRequired(), leftStage.isBroadcast());
    JoinCollection rightCollection = new JoinCollection(rightStageName, inputDataCollections.get(rightStageName),
        rightStage.getSchema(), Collections.emptyList(), rightStage.isRequired(), rightStage.isBroadcast());
    JoinExpressionRequest joinRequest = new JoinExpressionRequest(stageName, joinDefinition.getSelectedFields(), leftCollection, rightCollection, condition, joinDefinition.getOutputSchema(), joinDefinition);
    return leftData.join(joinRequest);
}
Also used : JoinStage(io.cdap.cdap.etl.api.join.JoinStage) JoinExpressionRequest(io.cdap.cdap.etl.spark.join.JoinExpressionRequest) JoinCollection(io.cdap.cdap.etl.spark.join.JoinCollection) JoinCondition(io.cdap.cdap.etl.api.join.JoinCondition)
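
The comment about spark.sql.shuffle.partitions is easy to verify outside CDAP. Below is a minimal sketch, assuming a plain local Spark session with made-up tables: the join's shuffle parallelism comes from the session-wide setting, with adaptive execution and auto-broadcast disabled so the raw value is visible. There is no per-join override, unlike the key-equality path above.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SqlJoinPartitionsSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("sql-join-partitions")
            .master("local[*]")
            // Session-wide: every SQL join shuffle in this session uses 8 partitions.
            .config("spark.sql.shuffle.partitions", "8")
            // Disable features that would hide the shuffle in this tiny example.
            .config("spark.sql.adaptive.enabled", "false")
            .config("spark.sql.autoBroadcastJoinThreshold", "-1")
            .getOrCreate();

        spark.range(100).selectExpr("id", "id % 10 AS k").createOrReplaceTempView("big");
        spark.range(10).selectExpr("id AS k", "id * 2 AS v").createOrReplaceTempView("small");

        Dataset<Row> joined = spark.sql(
            "SELECT big.id, small.v FROM big JOIN small ON big.k = small.k");
        // Expected to print 8, the global setting, for this and any other SQL join here.
        System.out.println(joined.javaRDD().getNumPartitions());
        spark.stop();
    }
}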

Example 8 with JoinStage

use of io.cdap.cdap.etl.api.join.JoinStage in project cdap by cdapio.

the class JoinOnFunction method createInitializedJoinOnTransform.

private JoinOnTransform<INPUT_RECORD, JOIN_KEY> createInitializedJoinOnTransform() throws Exception {
    Object plugin = pluginFunctionContext.createPlugin();
    BatchJoiner<JOIN_KEY, INPUT_RECORD, Object> joiner;
    boolean filterNullKeys = false;
    if (plugin instanceof BatchAutoJoiner) {
        BatchAutoJoiner autoJoiner = (BatchAutoJoiner) plugin;
        AutoJoinerContext autoJoinerContext = pluginFunctionContext.createAutoJoinerContext();
        JoinDefinition joinDefinition = autoJoiner.define(autoJoinerContext);
        autoJoinerContext.getFailureCollector().getOrThrowException();
        String stageName = pluginFunctionContext.getStageName();
        if (joinDefinition == null) {
            throw new IllegalStateException(String.format("Join stage '%s' did not specify a join definition. " + "Check with the plugin developer to ensure it is implemented correctly.", stageName));
        }
        JoinCondition condition = joinDefinition.getCondition();
        /*
         Filter out the record if it comes from an optional stage
         and the key is null, or if any of the fields in the key is null.
         For example, suppose we are performing a left outer join on:

          A (id, name) = (0, alice), (null, bob)
          B (id, email) = (0, alice@example.com), (null, placeholder@example.com)

         The final output should be:

         joined (A.id, A.name, B.email) = (0, alice, alice@example.com), (null, bob, null)

         that is, the bob record should not be joined to the placeholder@example.com email, even though both their
         ids are null.
       */
        if (condition.getOp() == JoinCondition.Op.KEY_EQUALITY && !((JoinCondition.OnKeys) condition).isNullSafe()) {
            filterNullKeys = joinDefinition.getStages().stream()
                .filter(s -> !s.isRequired())
                .map(JoinStage::getStageName)
                .anyMatch(s -> s.equals(inputStageName));
        }
        joiner = new JoinerBridge(stageName, autoJoiner, joinDefinition);
    } else {
        joiner = (BatchJoiner<JOIN_KEY, INPUT_RECORD, Object>) plugin;
        BatchJoinerRuntimeContext context = pluginFunctionContext.createBatchRuntimeContext();
        joiner.initialize(context);
    }
    return new JoinOnTransform<>(joiner, inputStageName, filterNullKeys);
}
Also used : BatchJoiner(io.cdap.cdap.etl.api.batch.BatchJoiner) Transformation(io.cdap.cdap.etl.api.Transformation) BatchJoinerRuntimeContext(io.cdap.cdap.etl.api.batch.BatchJoinerRuntimeContext) PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) Iterator(java.util.Iterator) JoinStage(io.cdap.cdap.etl.api.join.JoinStage) JoinerBridge(io.cdap.cdap.etl.common.plugin.JoinerBridge) Tuple2(scala.Tuple2) Schema(io.cdap.cdap.api.data.schema.Schema) Constants(io.cdap.cdap.etl.common.Constants) StructuredRecord(io.cdap.cdap.api.data.format.StructuredRecord) TrackedTransform(io.cdap.cdap.etl.common.TrackedTransform) Emitter(io.cdap.cdap.etl.api.Emitter) DefaultEmitter(io.cdap.cdap.etl.common.DefaultEmitter) JoinDefinition(io.cdap.cdap.etl.api.join.JoinDefinition) AutoJoinerContext(io.cdap.cdap.etl.api.join.AutoJoinerContext) JoinCondition(io.cdap.cdap.etl.api.join.JoinCondition) BatchAutoJoiner(io.cdap.cdap.etl.api.batch.BatchAutoJoiner)
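
At runtime, the filterNullKeys flag boils down to a simple predicate on the optional stage's records. The sketch below is a simplified illustration of that rule, not the actual JoinOnTransform code: a record is dropped when its join key is missing or any field in the key is null.

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

public class NullKeyFilterSketch {

    // Keep a record only if its join key exists and contains no null fields.
    static boolean keep(List<Object> joinKey) {
        return joinKey != null && joinKey.stream().allMatch(Objects::nonNull);
    }

    public static void main(String[] args) {
        // Using B (id, email) from the comment above, keyed on id:
        System.out.println(keep(Arrays.asList(0)));              // true: (0, alice@example.com) survives
        System.out.println(keep(Arrays.asList((Object) null)));  // false: the placeholder row is dropped
                                                                 // and never pairs with A's (null, bob)
    }
}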

Example 9 with JoinStage

use of io.cdap.cdap.etl.api.join.JoinStage in project cdap by cdapio.

the class BatchSparkPipelineDriverTest method testSQLEngineDoesNotSupportJoin.

@Test
public void testSQLEngineDoesNotSupportJoin() {
    when(adapter.canJoin(anyString(), any(JoinDefinition.class))).thenReturn(false);
    List<JoinStage> noneBroadcast = Arrays.asList(
        JoinStage.builder("a", null).setBroadcast(false).build(),
        JoinStage.builder("b", null).setBroadcast(false).build(),
        JoinStage.builder("c", null).setBroadcast(false).build());
    JoinDefinition joinDefinition = mock(JoinDefinition.class);
    doReturn(noneBroadcast).when(joinDefinition).getStages();
    Map<String, SparkCollection<Object>> collections = new HashMap<>();
    collections.put("a", mock(RDDCollection.class));
    collections.put("b", mock(RDDCollection.class));
    collections.put("c", mock(RDDCollection.class));
    Assert.assertFalse(driver.canJoinOnSQLEngine(STAGE_NAME, joinDefinition, collections));
}
Also used : SparkCollection(io.cdap.cdap.etl.spark.SparkCollection) JoinStage(io.cdap.cdap.etl.api.join.JoinStage) JoinDefinition(io.cdap.cdap.etl.api.join.JoinDefinition) HashMap(java.util.HashMap) Matchers.anyString(org.mockito.Matchers.anyString) Test(org.junit.Test)
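
For contrast, a sketch of the complementary case using the same fixture (driver, adapter, STAGE_NAME, and the Mockito setup from the surrounding test class): when the adapter reports it can run the join and no stage is broadcast, canJoinOnSQLEngine should return true. The asserted outcome is inferred from the negative test above, not copied from the project's test suite.

@Test
public void testSQLEngineSupportsJoin() {
    when(adapter.canJoin(anyString(), any(JoinDefinition.class))).thenReturn(true);
    List<JoinStage> noneBroadcast = Arrays.asList(
        JoinStage.builder("a", null).setBroadcast(false).build(),
        JoinStage.builder("b", null).setBroadcast(false).build());
    JoinDefinition joinDefinition = mock(JoinDefinition.class);
    doReturn(noneBroadcast).when(joinDefinition).getStages();
    Map<String, SparkCollection<Object>> collections = new HashMap<>();
    collections.put("a", mock(RDDCollection.class));
    collections.put("b", mock(RDDCollection.class));
    // Assumed outcome: a willing engine plus no broadcast stages makes the SQL path usable.
    Assert.assertTrue(driver.canJoinOnSQLEngine(STAGE_NAME, joinDefinition, collections));
}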

Example 10 with JoinStage

use of io.cdap.cdap.etl.api.join.JoinStage in project hydrator-plugins by cdapio.

the class JoinerConfigTest method testAdvancedWithTooManyInputs.

@Test
public void testAdvancedWithTooManyInputs() {
    JoinerConfig conf = new JoinerConfig("users.id, emails.email", "users.id = emails.userid", new HashSet<>(Arrays.asList("users", "emails")));
    Joiner joiner = new Joiner(conf);
    FailureCollector collector = new MockFailureCollector();
    Schema userSchema = Schema.recordOf("user", Schema.Field.of("id", Schema.of(Schema.Type.INT)));
    Schema emailSchema = Schema.recordOf("email", Schema.Field.of("email", Schema.of(Schema.Type.STRING)), Schema.Field.of("userid", Schema.of(Schema.Type.INT)));
    Map<String, JoinStage> inputStages = new HashMap<>();
    inputStages.put("users", JoinStage.builder("users", userSchema).build());
    inputStages.put("emails", JoinStage.builder("emails", emailSchema).build());
    inputStages.put("users2", JoinStage.builder("users2", userSchema).build());
    AutoJoinerContext autoJoinerContext = new MockAutoJoinerContext(inputStages, collector);
    try {
        joiner.define(autoJoinerContext);
        Assert.fail("Advanced join did not fail with 3 inputs as expected.");
    } catch (ValidationException e) {
        List<ValidationFailure> failures = e.getFailures();
        Assert.assertEquals(1, failures.size());
        List<ValidationFailure.Cause> causes = failures.get(0).getCauses();
        Assert.assertEquals(1, causes.size());
        Assert.assertEquals(JoinerConfig.CONDITION_TYPE, causes.get(0).getAttribute(CauseAttributes.STAGE_CONFIG));
    }
}
Also used : JoinStage(io.cdap.cdap.etl.api.join.JoinStage) ValidationException(io.cdap.cdap.etl.api.validation.ValidationException) HashMap(java.util.HashMap) Schema(io.cdap.cdap.api.data.schema.Schema) ValidationFailure(io.cdap.cdap.etl.api.validation.ValidationFailure) AutoJoinerContext(io.cdap.cdap.etl.api.join.AutoJoinerContext) MockFailureCollector(io.cdap.cdap.etl.mock.validation.MockFailureCollector) List(java.util.List) FailureCollector(io.cdap.cdap.etl.api.FailureCollector) Test(org.junit.Test)
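
As a counterpoint to the failure case, here is a sketch of the happy path under the same configuration: with exactly the two declared inputs, define() should pass validation and produce a JoinDefinition. It reuses only classes that appear in the test above; the non-null assertion is an assumption about the expected behavior, not an existing test.

@Test
public void testAdvancedWithTwoInputs() {
    JoinerConfig conf = new JoinerConfig("users.id, emails.email",
        "users.id = emails.userid", new HashSet<>(Arrays.asList("users", "emails")));
    Joiner joiner = new Joiner(conf);
    FailureCollector collector = new MockFailureCollector();
    Schema userSchema = Schema.recordOf("user",
        Schema.Field.of("id", Schema.of(Schema.Type.INT)));
    Schema emailSchema = Schema.recordOf("email",
        Schema.Field.of("email", Schema.of(Schema.Type.STRING)),
        Schema.Field.of("userid", Schema.of(Schema.Type.INT)));
    Map<String, JoinStage> inputStages = new HashMap<>();
    inputStages.put("users", JoinStage.builder("users", userSchema).build());
    inputStages.put("emails", JoinStage.builder("emails", emailSchema).build());
    // Assumed outcome: no ValidationException, and a definition is returned.
    JoinDefinition joinDefinition = joiner.define(new MockAutoJoinerContext(inputStages, collector));
    Assert.assertNotNull(joinDefinition);
}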

Aggregations

JoinStage (io.cdap.cdap.etl.api.join.JoinStage)24 JoinDefinition (io.cdap.cdap.etl.api.join.JoinDefinition)14 HashMap (java.util.HashMap)13 Test (org.junit.Test)13 Schema (io.cdap.cdap.api.data.schema.Schema)11 JoinCondition (io.cdap.cdap.etl.api.join.JoinCondition)10 AutoJoinerContext (io.cdap.cdap.etl.api.join.AutoJoinerContext)8 SparkCollection (io.cdap.cdap.etl.spark.SparkCollection)8 Matchers.anyString (org.mockito.Matchers.anyString)8 FailureCollector (io.cdap.cdap.etl.api.FailureCollector)6 JoinField (io.cdap.cdap.etl.api.join.JoinField)5 MockFailureCollector (io.cdap.cdap.etl.mock.validation.MockFailureCollector)5 ArrayList (java.util.ArrayList)5 HashSet (java.util.HashSet)5 List (java.util.List)5 ValidationException (io.cdap.cdap.etl.api.validation.ValidationException)4 JoinCollection (io.cdap.cdap.etl.spark.join.JoinCollection)4 BatchAutoJoiner (io.cdap.cdap.etl.api.batch.BatchAutoJoiner)3 BatchJoiner (io.cdap.cdap.etl.api.batch.BatchJoiner)3 JoinKey (io.cdap.cdap.etl.api.join.JoinKey)3